更改日志#

建议在此处查看:此处。有关 Pallas API 的特定更改,请参阅Pallas Changelog

JAX 遵循基于工作量的版本控制;有关此以及 JAX API 兼容性策略的讨论,请参阅API 兼容性。有关 Python 和 NumPy 版本支持策略,请参阅Python 和 NumPy 版本支持策略

未发布#

  • 更改

    • jax.lax.linalg.eigh() 现在接受一个 implementation 参数来选择 QR (CPU/GPU)、Jacobi (GPU/TPU) 和 QDWH (TPU) 实现。 EighImplementation 枚举从 jax.lax.linalg 公开导出。

JAX 0.8.0 (2025年10月15日)#

  • 重大更改

    • JAX 将默认的 jax.pmap 实现更改为基于 jax.jitjax.shard_map 实现的。 jax.pmap 处于维护模式,我们鼓励所有新代码直接使用 jax.shard_map。有关更多信息,请参阅迁移指南

    • jax.experimental.shard_map.shard_mapauto= 参数已被移除。这意味着 jax.experimental.shard_map.shard_map 不再支持嵌套。如果要嵌套 shard_map 调用,请使用 jax.shard_map

    • JAX 不再允许将支持 __jax_array__ 的对象直接传递给(例如)jit 编译的函数。请先在它们上调用 jax.numpy.asarray

    • jax.numpy.cov() 现在返回 NaN 用于空数组(#32305),并且在单行设计矩阵方面与 NumPy 2.2 的行为一致(#32308)。

    • JAX 不再接受 Array 值,而是在需要 dtype 值的地方。请先调用这些值的 .dtype

    • 已移除已弃用的函数 jax.interpreters.mlir.custom_call()

    • 已移除 jax.utiljax.extend.ffijax.experimental.host_callback 模块。这些模块中的所有公共 API 在 v0.7.0 或更早版本中已被弃用并移除。

    • 已移除已弃用的符号 jax.custom_derivatives.custom_jvp_call_jaxpr_p

    • jax.experimental.multihost_utils.process_allgather 当输入是 jax.Array 且非完全可寻址且 tiled=False 时会引发错误。要修复此问题,请将 tiled=True 传递给您的 process_allgather 调用。

    • jax.experimental.compilation_cache 中,已移除已弃用的符号 is_initializedinitialize_cache

    • 已移除已弃用的函数 jax.interpreters.xla.canonicalize_dtype()

    • jaxlib.hlo_helpers 已被移除。请使用 jax.ffi 代替。

    • 选项 jax_cpu_enable_gloo_collectives 已被移除。请使用 jax_cpu_collectives_implementation 代替。

    • 已移除 jax.numpy.percentile()jax.numpy.quantile() 中先前已弃用的 interpolation 参数;请使用 method 代替。

    • 已移除 JAX 内部的 for_loop 原语。现在 jax.lax.fori_loop() 直接支持其功能,即在循环体内读写 ref。如果您需要帮助更新代码,请提交 bug。

    • jax.numpy.trimzeros() 现在会因非一维输入而报错。

    • jax.numpy.sum() 和其他归约操作的 where 参数现在要求为布尔值。非布尔值自 JAX v0.5.0 起已导致 DeprecationWarning

    • 已移除 {mod} jax.dlpack, {mod} jax.errors, {mod} jax.lib.xla_bridge, {mod} jax.lib.xla_client, 和 {mod} jax.lib.xla_extension 中的已弃用函数。

    • jax.interpreters.mlir.dense_bool_array 已移除。请使用 MLIR API 来构造属性。

  • 更改

    • jax.numpy.linalg.eig() 现在返回一个命名元组(带有 eigenvalueseigenvectors 属性),而不是一个普通元组。

    • jax.grad()jax.vjp() 现在始终将主元舍入到 float32,除非启用了 float64 模式。

    • jax.dlpack.from_dlpack() 现在接受具有非默认布局的数组,例如转置的数组。

    • NVIDIA GPU 上的默认非对称特征值分解现在使用 cusolver。通过 jax.lax.linalg.eig() 的新 implementation 参数,仍然可以使用 magma 和 LAPACK 实现(#27265)。 use_magma 参数现在已弃用,改用 implementation

    • jax.numpy.trim_zeros() 现在遵循 NumPy 2.2,支持多维输入。

  • 弃用

JAX 0.7.2 (2025年9月16日)#

  • 重大更改

    • jax.dlpack.from_dlpack() 不再接受 DLPack 胶囊。此行为已被弃用并已移除。该函数必须使用实现 __dlpack____dlpack_device__ 的数组进行调用。

  • 更改

    • 现在支持的最低 NumPy 版本为 2.0。由于 NumPy 2.0 支持需要 SciPy 1.13,因此现在支持的最低 SciPy 版本为 1.13。

    • JAX 现在在其内部 jaxpr 表示中将常量表示为 TypedNdArray,这是一个私有的 JAX 类型,可模拟 numpy.ndarray。此类型可能会通过 custom_jvp 规则等暴露给用户,并可能破坏使用 isinstance(x, np.ndarray) 的代码。如果这破坏了您的代码,您可以使用 np.asarray(x) 将这些数组转换为经典 NumPy 数组。

  • 错误修复

    • arr.view(dtype=None) 现在返回不变的数组,与 NumPy 的语义匹配。之前它返回具有 float 类型的数组。

    • jax.random.randint 现在为 8 位和 16 位整数类型生成更均匀分布的分布(#27742)。要恢复先前的有偏行为,您可以暂时将 jax_safer_randint 配置设置为 False,但请注意,这是一个临时配置,将在未来版本中移除。

  • 弃用

    • jax2tf.convertenable_xlanative_serialization 参数已被弃用,将在 JAX 的未来版本中移除。这些参数用于 jax2tf 与非原生序列化,而这部分功能已被移除。

    • 设置配置状态 jax_pmap_no_rank_reductionFalse 已被弃用。默认情况下,jax_pmap_no_rank_reduction 将设置为 True,并且 jax.pmap 分片不会降低其秩,保持与其包含数组相同的秩。

JAX 0.7.1 (2025年8月20日)#

  • 新功能

    • JAX 现在提供 Python 3.14 和 3.14t 轮子。

    • JAX 现在在 Mac 上提供 Python 3.13t 和 3.14t 轮子。之前我们只在 Linux 上提供免费线程构建。

  • 更改

    • 公开了 jax.set_mesh,它充当全局设置器和上下文管理器。移除了 jax.sharding.use_mesh,改用 jax.set_mesh

    • JAX 现在使用 CUDA 12.9 构建。所有 CUDA 12.1 或更高版本仍然受支持。

    • jax.lax.dot() 现在通过可选的 dimension_numbers 参数实现通用点积。

  • 弃用

    • jax.lax.zeros_like_array() 已被弃用。请改用 jax.numpy.zeros_like()

    • 尝试导入 jax.experimental.host_callback 现在会导致 DeprecationWarning,并将在 JAX v0.8.0 中导致 ImportError。自 JAX 0.4.35 版本以来,其 API 已引发 NotImplementedError

    • jax.lax.dot() 中,按位置传递 precisionpreferred_element_type 参数已被弃用。请改为通过显式关键字传递它们。

    • 已从 jax.interpreters.adjax.interpreters.batchingjax.interpreters.partial_eval 中弃用了几十个内部 API;它们很少或根本不在 JAX 本身之外使用,而且大多数都被弃用而没有公共替换。

JAX 0.7.0 (2025年7月22日)#

  • 新功能

  • 重大更改

    • JAX 正在默认从 GSPMD 迁移到 Shardy。有关更多信息,请参阅迁移指南

    • JAX 自动微分正在默认切换为使用直接线性化(而不是通过 JVP 和部分求值实现线性化)。有关更多信息,请参阅迁移指南

    • jax.stages.OutInfo 已被 jax.ShapeDtypeStruct 替换。

    • jax.jit() 现在要求 fun 按位置传递,其他参数按关键字传递。否则将在 v0.7.x 中引发错误。这在 v0.6.x 中引发了 DeprecationWarning。

    • 最低 Python 版本现在是 3.11。3.11 将保持最低支持版本,直到 2026 年 7 月。

    • 布局 API 重命名

      • Layout.layout.input_layouts.output_layouts 已重命名为 Format.format.input_formats.output_formats

      • DeviceLocalLayout.device_local_layout 已重命名为 Layout.layout

    • jax.experimental.shard 模块已被删除,所有 API 已移至 jax.sharding 端点。因此,请使用 jax.sharding.reshardjax.sharding.auto_axesjax.sharding.explicit_axes 而不是它们的实验性端点。

    • lax.infeedlax.outfeed 已被移除,此前已在 JAX 0.6 中弃用。 Device 对象上的 transfer_to_infeedtransfer_from_outfeed 方法也被移除。

    • jax.extend.core.primitives.pjit_p 原语已重命名为 jit_p,其 name 属性已从 "pjit" 更改为 "jit"。这会影响 jaxprs 的字符串表示。此原语不再从 jax.experimental.pjit 模块导出。

    • (未记录的)函数 jax.extend.backend.add_clear_backends_callback 已被移除。用户应改用 jax.extend.backend.register_backend_cache

    • out_sharding 参数已添加到 x.at[y].setx.at[y].add。先前的传播操作数分片行为已被移除。如果您需要保留先前的行为,请使用 x.at[y].set/add(z, out_sharding=jax.typeof(x).sharding) 以便散布操作需要集合通信。

  • 弃用

    • jax.dlpack.SUPPORTED_DTYPES 已被弃用;请使用新的 jax.dlpack.is_supported_dtype() 函数。

    • jax.scipy.special.sph_harm() 已被弃用,类似于 SciPy 中的类似弃用;请改用 jax.scipy.special.sph_harm_y()

    • jax.interpreters.xla 中,先前已弃用的符号 abstractifypytype_aval_mappings 已被移除。

    • jax.interpreters.xla.canonicalize_dtype() 已被弃用。要规范化 dtype,请优先使用 jax.dtypes.canonicalize_dtype()。要检查对象是否为有效的 jax 输入,请优先使用 jax.core.valid_jaxtype()

    • jax.core 中,先前已弃用的符号 AxisNameConcretizationTypeErroraxis_framecall_pclosed_call_pget_typetrace_state_cleantypematchtypecheck 已被移除。

    • jax.lib.xla_client 中,先前已弃用的符号 DeviceAssignmentget_topology_for_devicesmlir_api_version 已被移除。

    • jax.extend.ffi 已被移除,此前已在 v0.5.0 中弃用。请使用 jax.ffi 代替。

    • jax.lib.xla_bridge.get_compile_options() 已被弃用,并由 jax.extend.backend.get_compile_options() 替换。

JAX 0.6.2 (2025年6月17日)#

  • 新功能

  • 更改

    • 最低 NumPy 版本为 1.26,最低 SciPy 版本为 1.12。

JAX 0.6.1 (2025年5月21日)#

  • 新功能

  • 更改

    • CUDA 包依赖项的版本检查被重新启用,之前在一个之前的版本中被意外禁用。

    • JAX nightly 包现在已发布到 artifact registry。要安装这些包,请参阅JAX 安装指南

    • jax.sharding.PartitionSpec 不再继承自元组。

    • jax.ShapeDtypeStruct 现在是不可变的。请使用 .update 方法更新您的 ShapeDtypeStruct,而不是进行原地更新。

  • 弃用

    • jax.custom_derivatives.custom_jvp_call_jaxpr_p 已被弃用,并将在 JAX v0.7.0 中移除。

JAX 0.6.0 (2025年4月16日)#

  • 重大更改

    • jax.numpy.array() 不再接受 None。此行为自 2023 年 11 月起已被弃用,现已移除。

    • 移除了 config.jax_data_dependent_tracing_fallback 配置选项,该选项是在 v0.4.36 中临时添加的,允许用户选择退出新的“无堆栈”跟踪机制。

    • 移除了 config.jax_eager_pmap 配置选项。

    • 禁止对 jax.jit 的结果调用 lowertrace AOT API,如果后续应用了其他包装器。之前可以这样做,但会默默忽略包装器。解决方法是将 jax.jit 作为最后一个包装器应用,jax.pmap 同理。请参阅#27873

    • jaxcuda12_pip 附加包已移除;请改为使用 pip install jax[cuda12]

  • 更改

    • 最低 CuDNN 版本为 v9.8。

    • JAX 现在使用 CUDA 12.8 构建。所有 CUDA 12.1 或更高版本仍然受支持。

    • JAX 包附加包现已更新为使用破折号而不是下划线,以符合 PEP 685。例如,如果您之前使用 pip install jax[cuda12_local] 来安装 JAX,请改用 pip install jax[cuda12-local]

    • jax.jit() 现在要求 fun 按位置传递,其他参数按关键字传递。否则将在 v0.6.X 中引发 DeprecationWarning,并在 v0.7.X 中开始引发错误。

  • 弃用

    • jax.tree_util.build_tree() 已被弃用。请改用 jax.tree.unflatten()

    • 使用 XLA 的 FFI 实现了一系列 CPU 和 GPU 设备的主机回调处理程序,并移除了使用 XLA 自定义调用的现有 CPU/GPU 处理程序。

    • jax.lib.xla_extension 中的所有 API 现在都已弃用。

    • jax.interpreters.mlir.hlojax.interpreters.mlir.func_dialect,这些是意外导出的,已被移除。如果需要,可以从 jax.extend.mlir 获取。

    • jax.interpreters.mlir.custom_call 已被弃用。应使用 jax.ffi 提供的 API。

    • 使用内联参数的 jax.ffi.ffi_call() 的已弃用用法不再受支持。 ffi_call() 现在无条件返回一个可调用对象。

    • jax.lib.xla_client 中的以下导出已弃用:get_topology_for_devicesheap_profilemlir_api_versionClientCompileOptionsDeviceAssignmentFrameHloShardingOpShardingTraceback

    • jax.util 中的以下内部 API 已弃用:HashableFunctionas_hashable_functioncachesafe_mapsafe_zipsplit_dictsplit_listsplit_list_checkedsplit_mergesubvalstoposortunzip2wrap_namewraps

    • jax.dlpack.to_dlpack 已被弃用。通常可以将 JAX Array 直接传递给另一个框架的 from_dlpack 函数。如果您需要 to_dlpack 的功能,请使用数组的 __dlpack__ 属性。

    • jax.lax.infeedjax.lax.infeed_pjax.lax.outfeedjax.lax.outfeed_p 已被弃用,并将在 JAX v0.7.0 中移除。

    • 已移除几个先前已弃用的 API,包括

      • 来自 jax.lib.xla_clientArrayImplFftTypePaddingTypePrimitiveTypeXlaBuilderdtype_to_etypeopsregister_custom_call_targetshape_from_pyvalShapeXlaComputation

      • 来自 jax.lib.xla_extensionArrayImplXlaRuntimeError

      • 来自 jaxjax.treedef_is_leafjax.tree_flattenjax.tree_mapjax.tree_leavesjax.tree_structurejax.tree_transposejax.tree_unflatten。替换项可以在 jax.treejax.tree_util 中找到。

      • 来自 jax.coreAxisSizeClosedJaxprEvalTraceInDBIdxInputTypeJaxprJaxprEqnLiteralMapPrimitiveOpaqueTraceStateOutDBIdxPrimitiveTokenTRACER_LEAK_DEBUGGER_WARNINGVarconcrete_avaldedup_referentsescaped_tracer_errorextend_axis_env_ndfull_lowerget_referentjaxpr_as_funjoin_effectslattice_joinleaked_tracer_errormaybe_find_leaked_tracersraise_to_shapedraise_to_shaped_mappingsreset_trace_statestr_eqn_compactsubstitute_vars_in_output_tytypecompatused_axis_names_jaxpr。大多数没有公共替换项,但少数可以在 jax.extend.core 中找到。

      • vectorized 参数到 pure_callback()ffi_call()。请改用 vmap_method 参数。

jax 0.5.3 (2025年3月19日)#

jax 0.5.2 (2025年3月4日)#

0.5.1 的补丁版本

  • 错误修复

    • 修复了 TPU 指标日志记录和 tpu-info 的 bug,这些在 0.5.1 中是损坏的

jax 0.5.1 (2025年2月24日)#

  • 重大更改

    • jit 跟踪缓存现在以输入的 NamedShardings 为键。之前,跟踪缓存根本不包含分片信息(尽管后续的 jit 缓存如降低和编译缓存包含),因此不同类型的两个等效分片不会重新跟踪,但现在它们会。例如

      @jax.jit
      def f(x):
        return x
      
      # inp1.sharding is of type SingleDeviceSharding
      inp1 = jnp.arange(8)
      f(inp1)
      
      mesh = jax.make_mesh((1,), ('x',))
      # inp2.sharding is of type NamedSharding
      inp2 = jax.device_put(jnp.arange(8), NamedSharding(mesh, P('x')))
      f(inp2)  # tracing cache miss
      

      在上面的示例中,调用 f(inp1) 然后调用 f(inp2) 将导致跟踪缓存未命中,因为在跟踪过程中抽象值上的分片已更改。

  • 新功能

  • 更改

    • JAX_CPU_COLLECTIVES_IMPLEMENTATIONJAX_NUM_CPU_DEVICES 现在可用作环境变量。在此之前,它们只能通过 jax.config 或标志指定。

    • JAX_CPU_COLLECTIVES_IMPLEMENTATION 现在默认为 'gloo',这意味着多进程 CPU 通信开箱即用。

    • TPU 附加包 jax[tpu] 不再依赖于 libtpu-nightly 包。如果您的机器上存在此包,可以安全地移除;JAX 现在使用 libtpu 代替。

  • 弃用

    • 内部函数 linear_util.wrap_init 和构造函数 core.Jaxpr 现在必须接受非空的 core.DebugInfo 关键字参数。在有限的时间内,如果使用 jax.extend.linear_util.wrap_init 而不带调试信息,则会打印 DeprecationWarning。这导致了许多其他内部函数需要调试信息。此更改不影响公共 API。有关更多详细信息,请参阅 https://github.com/jax-ml/jax/issues/26480。

    • jax.numpy.ndim()jax.numpy.shape()jax.numpy.size() 中,现在已弃用非类数组输入(如列表、元组等)。

  • 错误修复

    • TPU 运行时启动和关闭时间在 TPU v5e 及更新版本上应有显著改善(从约 17 秒到约 8 秒)。如果尚未设置,您可能需要在 VM 映像中启用透明大页(sudo sh -c 'echo always > /sys/kernel/mm/transparent_hugepage/enabled')。我们希望在未来的版本中进一步改进这一点。

    • 如果未设置 JAX_COMPILATION_CACHE_MAX_SIZE 或设置为 -1(即未启用 LRU 逐出策略),则持久化编译缓存不再写入访问时间文件。这应该能提高使用大容量网络存储的缓存的性能。

jax 0.5.0 (2025年1月17日)#

自此版本起,JAX 现在使用基于工作量的版本控制。由于此版本对 PRNG 密钥语义进行了可能需要用户更新其代码的重大更改,因此我们将 JAX 的“meso”版本号升级以表示这一点。

  • 重大更改

    • 默认启用 jax_threefry_partitionable(请参阅更新说明)。

    • 此版本放弃对 Mac x86 轮子的支持。Mac ARM 当然仍然受支持。有关最近的讨论,请参阅 https://github.com/jax-ml/jax/discussions/22936。

      两个关键因素促成了这一决定:

      • Mac x86 构建(仅)存在许多测试失败和崩溃。我们宁愿不发布任何版本,也不发布一个损坏的版本。

      • Mac x86 硬件已停产,目前开发者难以获得。因此,即使我们想修复这类问题,也很难做到。

      如果我们对 Mac x86 的支持社区愿意提供帮助,我们愿意重新添加对 Mac x86 的支持:特别是,在我们再次发布版本之前,我们需要 JAX 测试套件在 Mac x86 上干净地通过。

  • 更改

    • 最低 NumPy 版本现在是 1.25。NumPy 1.25 将保持最低支持版本,直到 2025 年 6 月。

    • 最低 SciPy 版本现在是 1.11。SciPy 1.11 将保持最低支持版本,直到 2025 年 6 月。

    • jax.numpy.einsum() 现在默认使用 optimize='auto' 而不是 optimize='optimal'。这避免了在参数很多的情况下(#25214)指数级增长的跟踪时间。

    • jax.numpy.linalg.solve() 不再支持右侧的批处理一维参数。要在这些情况下恢复之前的行为,请使用 solve(a, b[..., None]).squeeze(-1)

  • 新功能

  • 弃用

    • jax.interpreters.xla 中,abstractifypytype_aval_mappings 现在已弃用,已被 jax.core 中同名符号替换。

    • jax.scipy.special.lpmn()jax.scipy.special.lpmn_values() 已被弃用,遵循其在 SciPy v1.15.0 中的弃用。没有计划用新 API 替换这些已弃用的函数。

    • jax.extend.ffi 子模块已移至 jax.ffi,并且之前的导入路径已被弃用。

  • 删除

    • jax_enable_memories 标志已被删除,并且该标志的行为默认开启。

    • jax.lib.xla_client 中,先前已弃用的 DeviceXlaRuntimeError 符号已移除;请分别使用 jax.Devicejax.errors.JaxRuntimeError

    • jax.experimental.array_api 模块已移除,此前已在 JAX v0.4.32 中弃用。自该版本以来,jax.numpy 直接支持 Array API。

jax 0.4.38 (2024年12月17日)#

  • 重大更改

    • XlaExecutable.cost_analysis 现在返回一个 dict[str, float](而不是一个单元素的 list[dict[str, float]])。

  • 更改

    • jax.tree.flatten_with_pathjax.tree.map_with_path 已添加为相应 tree_util 函数的快捷方式。

  • 弃用

    • 内部 jax.core 命名空间中的一些 API 已被弃用。大多数是无操作,很少使用,或者可以被 jax.extend.core 中同名 API 替换;有关这些半公开扩展的兼容性保证信息,请参阅 jax.extend 的文档。

    • 已移除几个先前已弃用的 API,包括

      • 来自 jax.corecheck_eqncheck_typecheck_valid_jaxtypenon_negative_dim

      • 来自 jax.lib.xla_bridgexla_clientdefault_backend

      • 来自 jax.lib.xla_client_xlabfloat16

      • 来自 jax.numpyround_

  • 新功能

jax 0.4.37 (2024年12月9日)#

这是 jax 0.4.36 的补丁版本。仅发布了“jax”此版本。

  • 错误修复

    • 修复了一个 bug,在该 bug 中 jit 会在参数命名为 f 时报错(#25329)。

    • 修复了一个 bug,该 bug 在用户使用不同辅助数据注册 pytree 节点类用于 flatten 和 flatten_with_path 时,会在 jax.lax.while_loop() 中抛出 index out of range 错误。

    • 固定了一个新的 libtpu 版本 (0.0.6),该版本修复了 TPU v6e 上的编译器 bug。

jax 0.4.36 (2024年12月5日)#

  • 重大更改

    • 此版本实现了“stackless”,这是 JAX 跟踪机制的一个内部更改。我们将跟踪分派完全改为上下文的函数,而不是上下文和数据的组合。这使我们能够删除大量用于管理数据相关跟踪的机制:级别、子级别、post_process_callnew_base_maincustom_bind 等等。

      如果您确实使用了 JAX 内部,那么您可能需要更新您的代码(有关如何操作的线索,请参阅 https://github.com/jax-ml/jax/commit/c36e1f7c1ad4782060cbc8e8c596d85dfb83986f)。也可能存在 JAX 库版本不兼容的问题。如果您发现此更改破坏了不使用 JAX 内部的代码,请尝试使用 config.jax_data_dependent_tracing_fallback 标志作为解决方法,如果您需要帮助更新代码,请提交 bug。

    • 使用 native_serialization=Falseenable_xla=Falsejax.experimental.jax2tf.convert() 自 2024 年 7 月(JAX 版本 0.4.31)起已被弃用。现在我们移除了对这些用例的支持。带有原生序列化的 jax2tf 仍将受支持。

    • jax.interpreters.xla 中,xbxcxe 符号已移除,此前已在 JAX v0.4.31 中弃用。而是使用 xb = jax.lib.xla_bridgexc = jax.lib.xla_clientxe = jax.lib.xla_extension

    • 已弃用的模块 jax.experimental.export 已被移除。它已在 JAX v0.4.30 中被 jax.export 替换。有关迁移到新 API 的信息,请参阅迁移指南

    • initial 参数到 jax.nn.softmax()jax.nn.log_softmax() 已被移除,此前已在 v0.4.27 中弃用。

    • 对类型化 PRNG 密钥(即 jax.random.key() 生成的密钥)调用 np.asarray 现在会引发错误。之前,这会返回一个标量对象数组。

    • 已移除 jax.export 中的以下已弃用方法和函数

      • jax.export.DisabledSafetyCheck.shape_assertions:它已经没有效果了。

      • jax.export.Exported.lowering_platforms:使用 platforms

      • jax.export.Exported.mlir_module_serialization_version:使用 calling_convention_version

      • jax.export.Exported.uses_shape_polymorphism:使用 uses_global_constants

      • jax.export.export()lowering_platforms 关键字参数:改用 platforms

    • jax.export.symbolic_args_specs() 中的关键字参数 symbolic_scopesymbolic_constraints 已被移除。它们已于 2024 年 6 月弃用。请改用 scopeconstraints

    • 跟踪器的哈希处理,该处理自版本 0.4.30 起已被弃用,现在会导致 TypeError

    • 重构:JAX 构建 CLI (build/build.py) 现在使用子命令结构,并取代了之前的 build.py 用法。运行 python build/build.py --help 获取更多详细信息。新子命令选项的简要概述

      • build:构建 JAX 轮子包。例如,python build/build.py build --wheels=jaxlib,jax-cuda-pjrt

      • requirements_update:更新 requirements_lock.txt 文件。

    • jax.scipy.linalg.toeplitz() 现在对多维输入进行隐式批处理。要恢复之前的行为,您可以在函数输入上调用 jax.numpy.ravel()

    • jax.scipy.special.gamma()jax.scipy.special.gammasgn() 现在对于负整数输入返回 NaN,以匹配 SciPy 的行为,源自 https://github.com/scipy/scipy/pull/21827。

    • jax.clear_backends 已被移除,此前已在 v0.4.26 中弃用。

    • 我们移除了自定义调用“__gpu$xla.gpu.triton”,将其从我们保证导出稳定性的自定义调用列表中移除。这是因为此自定义调用依赖于 Triton IR,而 Triton IR 不保证稳定。如果您需要导出使用此自定义调用的代码,可以使用 disabled_checks 参数。有关更多详细信息,请参阅文档

  • 新功能

  • 错误修复

    • 修复了一个 bug,该 bug 导致 LU 和 QR 分解的 GPU 实现会导致批次大小接近 int32 最大值时发生索引溢出。有关更多详细信息,请参阅 #24843

  • 弃用

    • jax.lib.xla_extension.ArrayImpljax.lib.xla_client.ArrayImpl 已弃用;请改用 jax.Array

    • jax.lib.xla_extension.XlaRuntimeError 已弃用;请改用 jax.errors.JaxRuntimeError

jax 0.4.35 (2024 年 10 月 22 日)#

  • 重大更改

    • jax.numpy.isscalar() 现在对于任何零维度的类数组对象都返回 True。以前,它只对具有弱 dtype 的零维类数组对象返回 True。

    • jax.experimental.host_callback 自 2024 年 3 月(JAX 版本 0.4.26)起已弃用。现在我们已将其移除。有关替代方案的讨论,请参阅 #20385

  • 更改

    • jax.lax.FftType 被引入为 FFT 操作枚举的公共名称。半公共 API jax.lib.xla_client.FftType 已弃用。

    • TPU:JAX 现在从 libtpu 包而不是 libtpu-nightly 安装 TPU 支持。在接下来的几个版本中,JAX 将将 libtpu-nightlylibtpu 的版本固定为空,以简化过渡;此依赖项将在 2025 年第一季度移除。

  • 弃用

    • 半公共 API jax.lib.xla_client.PaddingType 已弃用。没有 JAX API 使用此类型,因此没有替代品。

    • vmap 下,jax.pure_callback()jax.extend.ffi.ffi_call() 的默认行为已弃用,这些函数的 vectorized 参数也已弃用。应改用 vmap_method 参数以获得更明确的行为。有关更多详细信息,请参阅 #23881 中的讨论。

    • 半公共 API jax.lib.xla_client.register_custom_call_target 已弃用。请改用 JAX FFI。

    • 半公共 API jax.lib.xla_client.dtype_to_etypejax.lib.xla_client.opsjax.lib.xla_client.shape_from_pyvaljax.lib.xla_client.PrimitiveTypejax.lib.xla_client.Shapejax.lib.xla_client.XlaBuilderjax.lib.xla_client.XlaComputation 已弃用。请改用 StableHLO。

jax 0.4.34 (2024 年 10 月 4 日)#

  • 新功能

    • 此版本包含 Python 3.13 的 wheel。尚不支持 free-threading 模式。

    • 已添加 jax.errors.JaxRuntimeError 作为以前私有的 XlaRuntimeError 类型的公共别名。

  • 重大更改

    • 标志 jax_pmap_no_rank_reduction 默认设置为 True

      • pmap 结果上的 array[0] 现在会引入一个 reshape(请改用 array[0:1])。

      • 每个分片的形状(可通过 jax_array.addressable_shards 或 jax_array.addressable_data(0) 访问)现在具有一个前导 (1, …)。相应地更新直接访问分片的代码。每个分片形状的秩现在与全局形状的秩匹配,这与 jit 的行为相同。这避免了在将 pmap 的结果传递到 jit 时进行昂贵的 reshape。

    • jax.experimental.host_callback 自 2024 年 3 月(JAX 版本 0.4.26)起已弃用。现在我们将 --jax_host_callback_legacy 配置值的默认值设置为 True,这意味着如果您的代码使用 jax.experimental.host_callback API,那些 API 调用将基于新的 jax.experimental.io_callback API 实现。如果这破坏了您的代码,在非常短的时间内,您可以将 --jax_host_callback_legacy 设置为 True。很快我们将移除该配置选项,因此您应该改用新的 JAX 回调 API。有关讨论,请参阅 #20385

  • 弃用

    • jax.numpy.trim_zeros() 中,非类数组参数或 ndim != 1 的类数组参数现在已弃用,将来会导致错误。

    • 内部 pretty-printing 工具 jax.core.pp_* 已移除,之前在 JAX v0.4.30 中已弃用。

    • jax.lib.xla_client.Device 已弃用;请改用 jax.Device

    • jax.lib.xla_client.XlaRuntimeError 已弃用。请改用 jax.errors.JaxRuntimeError

  • 删除

    • jax.xla_computation 已删除。距离其在 0.4.30 JAX 版本中弃用已有 3 个月。请使用 AOT API 来获取与 jax.xla_computation 相同的功能。

      • jax.xla_computation(fn)(*args, **kwargs) 可以被替换为 jax.jit(fn).lower(*args, **kwargs).compiler_ir('hlo')

      • 您还可以使用 jax.stages.Lowered.out_info 属性来获取输出信息(如树结构、形状和 dtype)。

      • 对于跨后端降低,您可以将 jax.xla_computation(fn, backend='tpu')(*args, **kwargs) 替换为 jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo')

    • jax.ShapeDtypeStruct 不再接受 named_shape 参数。该参数仅由已移除的 xmap 使用(在 0.4.31 版本中移除)。

    • jax.tree.map(f, None, non-None),之前会发出 DeprecationWarning,现在在未来的 jax 版本中会引发错误。None 仅是其自身的树前缀。要保留当前行为,您可以要求 jax.tree.mapNone 视为叶子值,方法是编写:jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None)

    • jax.sharding.XLACompatibleSharding 已移除。请使用 jax.sharding.Sharding

  • 错误修复

    • 修复了一个 bug,该 bug 导致 jax.numpy.cumsum() 在提供非布尔输入并指定 dtype=bool 时产生不正确的结果。

    • 编辑 jax.numpy.ldexp() 的实现以获得正确的梯度。

jax 0.4.33 (2024 年 9 月 16 日)#

这是在 jax 0.4.32 之上的一个补丁版本,修复了该版本中的两个 bug。

在 JAX 0.4.32 固定的 libtpu 版本中发现了一个仅限 TPU 的数据损坏 bug,该 bug 仅在同一作业中有多个 TPU 切片时出现,例如,在多个 v5e 切片上训练时。此版本通过固定一个特定版本的 libtpu 来修复此问题。

此版本修复了 CPU 上 F64 tanh 的不准确结果 (#23590)。

jax 0.4.32 (2024 年 9 月 11 日)#

注意:此版本因 TPU 数据损坏 bug 已从 PyPi 中移除。有关更多详细信息,请参阅 0.4.33 版本说明。

  • 新功能

    • 添加了 jax.extend.ffi.ffi_call()jax.extend.ffi.ffi_lowering(),以支持使用新的 Foreign function interface (FFI) 从 JAX 与自定义 C++ 和 CUDA 代码进行交互。

  • 更改

    • 标志 jax_enable_memories 默认设置为 True

    • jax.numpy 现在支持 Python Array API 标准的 v2023.12 版本。有关更多信息,请参阅 Python Array API standard

    • CPU 后端上的计算现在可以在更多情况下异步分派。以前,非并行计算始终是同步分派的。您可以通过设置 jax.config.update('jax_cpu_enable_async_dispatch', False) 来恢复旧行为。

    • 添加了新的 jax.process_indices() 函数,以替换 JAX v0.2.13 中已弃用的 jax.host_ids() 函数。

    • 为了与 numpy.fabs 的行为保持一致,jax.numpy.fabs 已修改为不再支持 complex dtypes

    • jax.tree_util.register_dataclass 现在会检查 data_fieldsmeta_fields 是否包含所有 init=True 的 dataclass 字段,并且仅包含这些字段,如果 nodetype 是 dataclass。

    • 一些 jax.numpy 函数现在具有完整的 ufunc 接口,包括 addmultiplybitwise_andbitwise_orbitwise_xorlogical_andlogical_andlogical_and。在未来的版本中,我们计划将这些扩展到其他 ufunc。

    • 添加了 jax.lax.optimization_barrier(),它允许用户阻止编译器优化,例如公共子表达式消除和控制调度。

  • 重大更改

    • MHLO MLIR 方言 (jax.extend.mlir.mhlo) 已移除。请改用 stablehlo 方言。

  • 弃用

    • 不再允许对 jax.numpy.clip()jax.numpy.hypot() 输入复数,之前在 JAX v0.4.27 版本中已弃用。

    • 弃用了以下 API

      • jax.lib.xla_bridge.xla_client:请直接使用 jax.lib.xla_client

      • jax.lib.xla_bridge.get_backend:请使用 jax.extend.backend.get_backend()

      • jax.lib.xla_bridge.default_backend:请使用 jax.extend.backend.default_backend()

    • jax.experimental.array_api 模块已弃用,不再需要导入它即可使用 Array API。jax.numpy 直接支持 array API;有关更多信息,请参阅 Python Array API standard

    • 内部实用工具 jax.core.check_eqnjax.core.check_typejax.core.check_valid_jaxtype 现在已弃用,并将移除。

    • jax.numpy.round_ 已弃用,遵循 NumPy 2.0 中移除相应 API 的情况。请改用 jax.numpy.round()

    • 将 DLPack capsule 传递给 jax.dlpack.from_dlpack() 已弃用。jax.dlpack.from_dlpack() 的参数应该是实现了 __dlpack__ 协议的另一个框架的数组。

jaxlib 0.4.32 (2024 年 9 月 11 日)#

注意:此版本因 TPU 数据损坏 bug 已从 PyPi 中移除。有关更多详细信息,请参阅 0.4.33 版本说明。

  • 重大更改

    • 此 jaxlib 版本切换到了新的 CPU 后端版本,该版本编译速度更快,并且能更好地利用并行性。如果您遇到由于此更改导致的问题,可以通过设置环境变量 XLA_FLAGS=--xla_cpu_use_thunk_runtime=false 来暂时启用旧的 CPU 后端。如果您需要这样做,请提交一个 JAX bug 并提供重现说明。

    • 已添加 hermetic CUDA 支持。Hermetic CUDA 使用特定的可下载 CUDA 版本,而不是用户本地安装的 CUDA。Bazel 将下载 CUDA、CUDNN 和 NCCL 发行版,然后将 CUDA 库和工具用作各种 Bazel 目标中的依赖项。这使得 JAX 及其支持的 CUDA 版本的构建更加可复现。

  • 更改

    • 已添加 SparseCore 性能分析。

      • JAX 现在支持在 TPUv5p 芯片上对 SparseCore 进行性能分析。这些跟踪将在 Tensorboard Profiler 的 TraceViewer 中可见。

jax 0.4.31 (2024 年 7 月 29 日)#

  • 删除

    • xmap 已删除。请使用 shard_map() 作为替代。

  • 更改

    • 最低 CuDNN 版本为 v9.1。这在之前的版本中也适用,但我们现在正式声明此版本约束。

    • 最低 Python 版本现在是 3.10。3.10 将在 2025 年 7 月之前保持最低支持版本。

    • 最低 NumPy 版本现在是 1.24。NumPy 1.24 将在 2024 年 12 月之前保持最低支持版本。

    • 最低 SciPy 版本现在是 1.10。SciPy 1.10 将在 2025 年 1 月之前保持最低支持版本。

    • jax.numpy.ceil()jax.numpy.floor()jax.numpy.trunc() 现在返回与输入相同 dtype 的输出,即不再将整数或布尔输入提升为浮点数。

    • libdevice.10.bc 不再与 CUDA wheels 一起捆绑。它必须作为本地 CUDA 安装的一部分,或通过 NVIDIA 的 CUDA pip wheels 安装。

    • jax.experimental.pallas.BlockSpec 现在期望 block_shapeindex_map 之前传递。旧的参数顺序已弃用,将在未来的版本中移除。

    • 更新了 GPU 设备的 repr,使其与 TPU/CPU 更一致。例如,cuda(id=0) 现在将是 CudaDevice(id=0)

    • jax.Array 中添加了 device 属性和 to_device 方法,这是 JAX 的 Array API 支持的一部分。

  • 弃用

    • 移除了许多先前已弃用的与多态形状相关的内部 API。从 jax.core:移除了 canonicalize_shapedimension_as_valuedefinitely_equalsymbolic_equal_dim

    • HLO 降低规则不应再将单例 ir.Values 包装在元组中。相反,返回未包装的单例 ir.Values。对包装值的支持将在未来的 JAX 版本中移除。

    • 带有 native_serialization=Falseenable_xla=Falsejax.experimental.jax2tf.convert() 已弃用,并且此支持将在未来的版本中移除。自 JAX 0.4.16(2023 年 9 月)以来,原生序列化一直是默认值。

    • 已移除先前已弃用的函数 jax.random.shuffle;请改用 jax.random.permutation 并设置 independent=True

jaxlib 0.4.31 (2024 年 7 月 29 日)#

  • 错误修复

    • 修复了一个 bug,该 bug 导致 jit 的快速路径错误处理了 jit 的负 static_argnums。

    • 修复了一个 bug,该 bug 导致奇异矩阵批次的三角求解产生无意义的有限值,而不是 inf 或 nan (#3589, #15429)。

jax 0.4.30 (2024 年 6 月 18 日)#

  • 更改

    • JAX 支持 ml_dtypes >= 0.2。在 0.4.29 版本中,ml_dtypes 版本已升至 0.4.0,但在此版本中已回滚,以便同时使用 TensorFlow 和 JAX 的用户有更多时间迁移到更新的 TensorFlow 版本。

    • jax.experimental.mesh_utils 现在可以为 TPU v5e 创建一个高效的 mesh。

    • jax 现在直接依赖于 jaxlib。此更改由 CUDA 插件切换启用:不再存在多个 jaxlib 变体。您可以安装仅 CPU 的 jax,方法是 pip install jax,无需额外选项。

    • 添加了一个用于导出和序列化 JAX 函数的 API。这以前存在于 jax.experimental.export(目前正在弃用),现在将位于 jax.export 中。请参阅 文档

  • 弃用

    • 内部 pretty-printing 工具 jax.core.pp_* 已弃用,将在未来的版本中移除。

    • tracer 的哈希处理已弃用,并且将在未来的 JAX 版本中导致 TypeError。以前是这种情况,但在最近几个 JAX 版本中出现了意外的回归。

    • jax.experimental.export 已弃用。请改用 jax.export。请参阅 迁移指南

    • 在大多数情况下,将数组作为 dtype 传递现在已弃用;例如,对于数组 xyx.astype(y) 将引发警告。要消除警告,请使用 x.astype(y.dtype)

    • jax.xla_computation 已弃用,将在未来的版本中移除。请使用 AOT API 来获取与 jax.xla_computation 相同的功能。

      • jax.xla_computation(fn)(*args, **kwargs) 可以被替换为 jax.jit(fn).lower(*args, **kwargs).compiler_ir('hlo')

      • 您还可以使用 jax.stages.Lowered.out_info 属性来获取输出信息(如树结构、形状和 dtype)。

      • 对于跨后端降低,您可以将 jax.xla_computation(fn, backend='tpu')(*args, **kwargs) 替换为 jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo')

jaxlib 0.4.30 (2024 年 6 月 18 日)#

  • 已放弃对单体 CUDA jaxlibs 的支持。您必须使用基于插件的安装(pip install jax[cuda12]pip install jax[cuda12_local])。

jax 0.4.29 (2024 年 6 月 10 日)#

  • 更改

    • 我们预计这将是最后一个支持单体 CUDA jaxlib 的 JAX 和 jaxlib 版本。未来的版本将使用 CUDA 插件 jaxlib(例如 pip install jax[cuda12])。

    • JAX 现在需要 ml_dtypes 版本 0.4.0 或更高版本。

    • 移除了对 jax.experimental.export API 旧用法的向后兼容支持。现在无法使用 from jax.experimental.export import export,而是应使用 from jax.experimental import export。移除的功能自 0.4.24 版本起已弃用。

    • jax.tree.all() & jax.tree_util.tree_all() 中添加了 is_leaf 参数。

  • 弃用

    • jax.sharding.XLACompatibleSharding 已弃用。请使用 jax.sharding.Sharding

    • jax.experimental.Exported.in_shardings 已重命名为 jax.experimental.Exported.in_shardings_hlo。对于 out_shardings 也是如此。旧名称将在 3 个月后移除。

    • 移除了许多先前已弃用的 API

      • 来自 jax.corenon_negative_dimDimSizeShape

      • 来自 jax.laxtie_in

      • 来自 jax.nnnormalize

      • 来自 jax.interpreters.xlabackend_specific_translationstranslationsregister_translationxla_destructureTranslationRuleTranslationContextXlaOp

    • jax.numpy.linalg.matrix_rank()tol 参数正在弃用,并将很快移除。请改用 rtol

    • jax.numpy.linalg.pinv()rcond 参数正在弃用,并将很快移除。请改用 rtol

    • 已移除已弃用的 jax.config 子模块。要配置 JAX,请使用 import jax,然后通过 jax.config 引用配置对象。

    • jax.random API 不再接受批次密钥,以前有些是意外接受的。今后,我们建议在这种情况下显式使用 jax.vmap()

    • jax.scipy.special.beta() 中,xy 参数已重命名为 ab,以与其他 beta API 保持一致。

  • 新功能

    • 添加了 jax.experimental.Exported.in_shardings_jax(),用于构造可与 JAX API 一起使用的分片,这些分片来自存储在 Exported 对象中的 HloShardings。

jaxlib 0.4.29 (2024 年 6 月 10 日)#

  • 错误修复

    • 修复了一个 bug,该 bug 导致 XLA 错误地对某些连接操作进行分片,表现为累积约简的输出不正确 (#21403)。

    • 修复了一个 bug,该 bug 导致 XLA:CPU 错误地编译了某些 matmul fusion(https://github.com/openxla/xla/pull/13301)。

    • 修复了 GPU 上的编译器崩溃问题(https://github.com/jax-ml/jax/issues/21396)。

  • 弃用

    • jax.tree.map(f, None, non-None) 现在会发出 DeprecationWarning,并在未来的 jax 版本中引发错误。None 仅是其自身的树前缀。要保留当前行为,您可以要求 jax.tree.mapNone 视为叶子值,方法是编写:jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None)

jax 0.4.28 (2024 年 5 月 9 日)#

  • 错误修复

    • 恢复了 make_jaxpr 的一个更改,该更改曾破坏 Equinox (#21116)。

  • 弃用和移除

  • 更改

    • 此版本最低 jaxlib 版本为 0.4.27。

jaxlib 0.4.28 (2024 年 5 月 9 日)#

  • 错误修复

    • 修复了 Python 3.10 或更早版本中 Array 和 JIT Python 对象类型名称中的内存损坏 bug。

    • 修复了 CUDA 12.4 下的警告 '+ptx84' is not a recognized feature for this target

    • 修复了 CPU 上的慢速编译问题。

  • 更改

    • Windows 构建现在使用 Clang 而不是 MSVC 构建。

jax 0.4.27 (2024 年 5 月 7 日)#

  • 新功能

    • 添加了 jax.numpy.unstack()jax.numpy.cumulative_sum(),继它们在 array API 2023 标准中添加之后,该标准将很快被 NumPy 采纳。

    • 添加了一个新的配置选项 jax_cpu_collectives_implementation,用于选择 CPU 后端使用的跨进程集合操作的实现。可用的选项有 'none'(默认)、'gloo''mpi'(需要 jaxlib 0.4.26)。如果设置为 'none',则禁用跨进程集合操作。

  • 更改

    • jax.pure_callback()jax.experimental.io_callback()jax.debug.callback() 现在使用 jax.Array 而不是 np.ndarray。您可以通过在将参数传递给回调之前使用 jax.tree.map(np.asarray, args) 来恢复旧行为。

    • complex_arr.astype(bool) 现在遵循与 NumPy 相同的语义,对于 complex_arr 等于 0 + 0j 的情况返回 False,否则返回 True。

    • core.Token 现在是一个非平凡类,它封装了一个 jax.Array。它可以被创建并穿梭于计算之间以建立依赖关系。单例对象 core.token 已被移除,用户现在应该创建并使用新的 core.Token 对象。

    • 在 GPU 上,Threefry PRNG 实现默认不再降低到内核调用。这个选择可以提高运行时内存使用率,但会增加编译时间。可以恢复到旧行为(生成内核调用)的选项是 jax.config.update('jax_threefry_gpu_kernel_lowering', True)。如果新默认设置导致问题,请提交 bug。否则,我们打算在未来的版本中移除此标志。

  • 弃用和移除

    • Pallas 现在仅使用 XLA 来编译 GPU 上的 Triton 内核。旧的通过 Triton Python API 进行降低的通道已移除,并且 JAX_TRITON_COMPILE_VIA_XLA 环境变量不再有任何作用。

    • jax.numpy.clip() 的参数签名已更改:aa_mina_max 已弃用,取而代之的是 x(仅位置参数)、minmax(#20550)。

    • JAX 数组的 device() 方法已被移除,自 JAX v0.4.21 起已弃用。请改用 arr.devices()

    • jax.nn.softmax()jax.nn.log_softmax()initial 参数已弃用;空输入现在无需设置此参数即可使用 softmax。

    • jax.jit() 中,传递无效的 static_argnumsstatic_argnames 现在会导致错误而不是警告。

    • 最低 jaxlib 版本现在是 0.4.23。

    • jax.numpy.hypot() 函数现在在将复数值输入传递给它时会发出弃用警告。当弃用完成后,这将引发一个错误。

    • jax.numpy.nonzero()jax.numpy.where() 和相关函数的标量参数现在会引发错误,这遵循了 NumPy 中类似的更改。

    • 配置选项 jax_cpu_enable_gloo_collectives 已弃用。请改用 jax.config.update('jax_cpu_collectives_implementation', 'gloo')

    • jax.Array.device_bufferjax.Array.device_buffers 方法已被移除,之前在 JAX v0.4.22 中已弃用。请改用 jax.Array.addressable_shardsjax.Array.addressable_data()

    • jax.numpy.whereconditionxy 参数现在是仅位置参数,遵循 JAX v0.4.21 中对关键字的弃用。

    • jax.lax.linalg 模块中的函数现在必须通过关键字指定非类数组参数。以前,这会引发 DeprecationWarning。

    • 多个 jax.numpy() API 现在需要类数组参数,包括 apply_along_axis()apply_over_axes()inner()outer()cross()kron()lexsort()

  • 错误修复

    • jax.numpy.astype() 现在当 copy=True 时将始终返回副本。以前,当输出数组的 dtype 与输入数组相同的情况下,不会进行复制。这可能会导致内存使用量增加。默认值为 copy=False,以保持向后兼容。

jaxlib 0.4.27 (2024 年 5 月 7 日)#

jax 0.4.26 (2024 年 4 月 3 日)#

  • 新功能

  • 更改

    • 复数值 jax.numpy.geomspace() 现在选择与 NumPy 2.0 一致的对数螺旋分支。

    • lax.rng_bit_generator 的行为,以及 'rbg''unsafe_rbg' PRNG 实现的行为,在 jax.vmap 下发生了变化,导致映射密钥只从批次中的第一个密钥生成随机数(https://github.com/jax-ml/jax/issues/19085)。

    • 文档现在使用 jax.random.key 来构建 PRNG 密钥数组,而不是 jax.random.PRNGKey

  • 弃用和移除

    • jax.tree_map() 已弃用;请改用 jax.tree.map,或者为了与旧版 JAX 兼容,请使用 jax.tree_util.tree_map()

    • jax.clear_backends() 已弃用,因为它不一定与其名称所暗示的一样,并且可能导致意外后果,例如,它不会销毁现有后端并释放相应的拥有资源。如果您只想清理编译缓存,请使用 jax.clear_caches()。为了向后兼容或您确实需要切换/重新初始化默认后端,请使用 jax.extend.backend.clear_backends()

    • jax.experimental.maps 模块和 jax.experimental.maps.xmap 已弃用。请使用 jax.experimental.shard_mapjax.vmapspmd_axis_name 参数来表达 SPMD 设备并行计算。

    • jax.experimental.host_callback 模块已弃用。请改用 新的 JAX 外部回调。已添加 JAX_HOST_CALLBACK_LEGACY 标志以协助迁移到新的回调。有关讨论,请参阅 #20385

    • 将无法转换为 JAX 数组的参数传递给 jax.numpy.array_equal()jax.numpy.array_equiv() 现在会导致异常。

    • 已移除已弃用的标志 jax_parallel_functions_output_gda。此标志已长期弃用且不起作用;其使用是无操作。

    • 已移除先前已弃用的导入 jax.interpreters.ad.configjax.interpreters.ad.source_info_util。请改用 jax.configjax.extend.source_info_util

    • JAX 不再支持旧的序列化版本。版本 9 自 2023 年 10 月 27 日起受支持,并自 2024 年 2 月 1 日起成为默认版本。请参阅 版本描述。此更改可能会破坏设置了低于 9 的特定 JAX 序列化版本的客户端。

jaxlib 0.4.26 (2024 年 4 月 3 日)#

  • 更改

    • JAX 现在仅支持 CUDA 12.1 或更高版本。已停止支持 CUDA 11.8。

    • JAX 现在支持 NumPy 2.0。

jax 0.4.25 (2024 年 2 月 26 日)#

  • 新功能

  • 更改

    • Pallas 现在使用 XLA 而不是 Triton Python API 来编译 Triton 内核。您可以通过将 JAX_TRITON_COMPILE_VIA_XLA 环境变量设置为 "0" 来恢复旧行为。

    • jax.interpreters.xla 中,一些已弃用的 API(已在 v0.4.24 中移除)已在 v0.4.25 中重新添加,包括 backend_specific_translationstranslationsregister_translationxla_destructureTranslationRuleTranslationContextXLAOp。这些 API 仍被视为已弃用,将在未来提供更好的替代品时再次移除。有关讨论,请参阅 #19816

  • 弃用和移除

    • jax.numpy.linalg.solve() 现在对批次 1D 解(b.ndim > 1)显示弃用警告。将来,这些将被视为批次 2D 解。

    • 将非标量数组转换为 Python 标量现在会引发错误,无论数组大小如何。以前,对于大小为 1 的非标量数组会引发弃用警告。这遵循了 NumPy 中类似的弃用。

    • 已移除先前已弃用的配置 API,遵循标准的 3 个月弃用周期(请参阅 API compatibility)。这些包括

      • 对象 jax.config.config

      • define_*_stateDEFINE_* 方法 jax.config

    • 通过 import jax.config 导入 jax.config 子模块已弃用。要配置 JAX,请使用 import jax,然后通过 jax.config 引用配置对象。

    • 最低 jaxlib 版本现在是 0.4.20。

jaxlib 0.4.25 (2024 年 2 月 26 日)#

jax 0.4.24 (2024 年 2 月 6 日)#

  • 更改

    • JAX 降低到 StableHLO 不再依赖于物理设备。如果您的原语在降低规则中包装了 custom_partitioning 或 JAX 回调(即传递给 rule 参数的函数,mlir.register_lowering),则将您的原语添加到 jax._src.dispatch.prim_requires_devices_during_lowering 集中。这是必需的,因为 custom_partitioning 和 JAX 回调在降低期间需要物理设备来创建 Shardings。这是一个临时状态,直到我们可以创建不带物理设备的 Shardings。

    • jax.numpy.argsort()jax.numpy.sort() 现在支持 stabledescending 参数。

    • 形状多态处理的几项更改(用于 jax.experimental.jax2tfjax.experimental.export

      • 更清晰的符号表达式 pretty-printing(#19227)

      • 添加了指定维度变量符号约束的能力。这使得形状多态更具表现力,并提供了一种规避不等式推理限制的方法。请参阅 https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints。

      • 通过添加符号约束(#19235),我们现在认为来自不同作用域的维度变量是不同的,即使它们具有相同的名称。来自不同作用域的符号表达式不能交互,例如,在算术运算中。作用域由 jax.experimental.jax2tf.convert()jax.experimental.export.symbolic_shape()jax.experimental.export.symbolic_args_specs() 引入。作用域 e 的符号表达式可以用 e.scope 读取,并传递给上述函数以指示它们在给定作用域中构建符号表达式。请参阅 https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints。

      • 简化且更快的等式比较,我们认为两个符号维度相等,如果它们的差值经规范化后还原为 0(#19231;请注意,这可能会导致用户可见的行为更改)

      • 改进了不确定不等式比较的错误消息(#19235)。

      • 已弃用 core.non_negative_dim API(最近引入),并引入了 core.max_dimcore.min_dim(#18953)来表示符号维度的 maxmin。您可以使用 core.max_dim(d, 0) 代替 core.non_negative_dim(d)

      • shape_poly.is_poly_dim 已弃用,改为使用 export.is_symbolic_dim(#19282)。

      • export.args_specs 已弃用,改为使用 export.symbolic_args_specs({jax-issue}#19283)。

      • shape_poly.PolyShapejax2tf.PolyShape 已弃用,请使用字符串指定多态形状(#19284)。

      • JAX 的默认原生序列化版本现为 9。这与 jax.experimental.jax2tfjax.experimental.export 相关。请参阅 版本号描述

    • 重构了 jax.experimental.export 的 API。现在应使用 from jax.experimental import export 而不是 from jax.experimental.export import export。旧的导入方式将在 3 个月的弃用期内继续有效。

    • 添加了 jax.scipy.stats.sem()

    • jax.numpy.unique() 带有 return_inverse = True 时,返回的逆序索引会重塑为输入的维度,这与 NumPy 2.0 中的 numpy.unique() 的更改类似。

    • jax.numpy.sign() 现在对于非零复数输入返回 x / abs(x)。这与 NumPy 2.0 版本中 numpy.sign() 的行为一致。

    • jax.scipy.special.logsumexp() 带有 return_sign=True 时,现在使用 NumPy 2.0 的复数符号约定 x / abs(x)。这与 SciPy v1.13 中 scipy.special.logsumexp() 的行为一致。

    • JAX 现在支持 DLPack 的 bool 类型进行导入和导出。以前 bool 值无法导入,并被导出为整数。

  • 弃用和移除

    • 许多先前已弃用的函数已被移除,遵循标准的 3 个月弃用周期(请参阅 API compatibility)。这包括

      • 来自 jax.coreTracerArrayConversionErrorTracerIntegerConversionErrorUnexpectedTracerErroras_hashable_functioncollectionsdtypeslumapnamedtuplepartialpprefsafe_zipsafe_mapsource_info_utiltotal_orderingtraceback_utiltuple_deletetuple_insertzip

      • 来自 jax.laxdtypesitertoolsnaryopnaryop_dtype_rulestandard_abstract_evalstandard_naryopstandard_primitivestandard_unopunopunop_dtype_rule

      • 子模块 jax.linear_util 及其所有内容。

      • 子模块 jax.prng 及其所有内容。

      • 来自 jax.randomPRNGKeyArrayKeyArraydefault_prng_implthreefry_2x32threefry2x32_keythreefry2x32_prbg_keyunsafe_rbg_key

      • 来自 jax.tree_utilregister_keypathsAttributeKeyPathEntryGetItemKeyPathEntry

      • 来自 jax.interpreters.xlabackend_specific_translations, translations, register_translation, xla_destructure, TranslationRule, TranslationContext, axis_groups, ShapedArray, ConcreteArray, AxisEnv, backend_compile,以及 XLAOp

      • 来自 jax.numpyNINF, NZERO, PZERO, row_stack, issubsctype, trapz,以及 in1d

      • 来自 jax.scipy.linalgtriltriu

    • 先前已弃用的方法 PRNGKeyArray.unsafe_raw_array 已被移除。请改用 jax.random.key_data()

    • bool(empty_array) 现在会引发错误而不是返回 False。这先前会引发弃用警告,并且与 NumPy 中的类似更改一致。

    • 对 mhlo MLIR 方言的支持已弃用。JAX 不再使用 mhlo 方言,而是使用 stablehlo。引用“mhlo”的 API 将在未来被移除。请改用“stablehlo”方言。

    • jax.random:直接将批处理密钥传递给随机数生成函数,例如 bits(), gamma() 等,已被弃用并将发出 FutureWarning。请使用 jax.vmap 进行显式批处理。

    • jax.lax.tie_in() 已弃用:自 JAX v0.2.0 起它一直是无操作。

jaxlib 0.4.24 (2024年2月6日)#

  • 更改

    • JAX 现在支持 CUDA 12.3 和 CUDA 11.8。已停止支持 CUDA 12.2。

    • cost_analysis 现在可以与交叉编译的 Compiled 对象一起工作(即在使用 .lower().compile() 并带有一个拓扑对象时,例如,在非 TPU 计算机上为 Cloud TPU 编译)。

    • 添加了 CUDA 数组接口 导入支持(需要 jax 0.4.25)。

jax 0.4.23 (2023年12月13日)#

jaxlib 0.4.23 (2023年12月13日)#

  • 修复了一个导致 GPU 编译器在编译期间产生详细日志的错误。

jax 0.4.22 (2023年12月13日)#

  • 弃用

    • JAX 数组的 device_bufferdevice_buffers 属性已弃用。显式缓冲区已被更灵活的数组分片接口取代,但仍可以通过这种方式恢复以前的输出。

      • arr.device_buffer 变为 arr.addressable_data(0)

      • arr.device_buffers 变为 [x.data for x in arr.addressable_shards]

jaxlib 0.4.22 (2023年12月13日)#

jax 0.4.21 (2023年12月4日)#

  • 新功能

  • 更改

    • 最低 jaxlib 版本为 0.4.19。

    • 发布的 wheel 现在使用 clang 而非 gcc 构建。

    • 强制要求在调用 jax.distributed.initialize() 之前设备后端尚未初始化。

    • 在 Cloud TPU 环境中自动配置 jax.distributed.initialize() 的参数。

  • 弃用

    • jax.scipy.linalg.solve() 中移除了先前已弃用的 sym_pos 参数。请改用 assume_a='pos'

    • None 传递给 jax.array()jax.asarray()(直接或在列表或元组中)已被弃用,现在会引发 FutureWarning。目前它会被转换为 NaN,将来会引发 TypeError

    • 为了与 numpy.where 保持一致,将 conditionxy 参数以关键字参数形式传递给 jax.numpy.where 已被弃用。

    • 将无法转换为 JAX 数组的参数传递给 jax.numpy.array_equal()jax.numpy.array_equiv() 已被弃用,现在会引发 DeprecationWaning。目前这些函数返回 False,将来会引发异常。

    • JAX 数组的 device() 方法已弃用。根据上下文,它可能被以下任一方法替换:

      • jax.Array.devices() 返回数组使用的所有设备集合。

      • jax.Array.sharding 提供数组使用的分片配置。

jaxlib 0.4.21 (2023年12月4日)#

  • 更改

    • 为了增加对分布式 CPU 的支持,JAX 现在将 CPU 设备与 GPU 和 TPU 设备同等对待,即:

      • jax.devices() 包括分布式作业中存在的所有设备,即使不是当前进程本地的设备。jax.local_devices() 仍然只包括当前进程本地的设备,因此如果 jax.devices() 的更改对您造成了影响,您最可能想要使用 jax.local_devices()

      • CPU 设备现在会在分布式作业中获得一个全局唯一的 ID 号;先前 CPU 设备会获得一个进程本地的 ID 号。

      • 每个 CPU 设备的 process_index 现在将与同一进程中的任何 GPU 或 TPU 设备匹配;先前 CPU 设备的 process_index 始终是 0。

    • 在 NVIDIA GPU 上,JAX 现在优先为最大 1024x1024 的矩阵选择 Jacobi SVD 求解器。Jacobi 求解器似乎比非 Jacobi 版本更快。

  • 错误修复

    • 修复了当非有限值的数组传递给非对称特征值分解时发生的错误/挂起(#18226)。具有非有限值的数组现在会产生全为 NaN 的数组作为输出。

jax 0.4.20 (2023年11月2日)#

jaxlib 0.4.20 (2023年11月2日)#

  • 错误修复

    • 修复了 E4M3 和 E5M2 浮点类型之间的某些类型混淆问题。

jax 0.4.19 (2023年10月19日)#

  • 新功能

    • 添加了 jax.typing.DTypeLike,可用于注解可转换为 JAX 数据类型的对象。

    • 添加了 jax.numpy.fill_diagonal

  • 更改

    • JAX 现在要求 SciPy 1.9 或更高版本。

  • 错误修复

    • 在多控制器分布式 JAX 程序中,只有 0 号进程会写入持久编译缓存条目。这修复了当缓存位于 GCS 等网络文件系统时发生的写入争用问题。

    • 版本检查 for cusolver 和 cufft 不再考虑补丁版本,以确定安装版本是否不低于 JAX 构建所用的版本。

jaxlib 0.4.19 (2023年10月19日)#

  • 更改

    • 如果已安装 pip 安装的 NVIDIA CUDA 库(nvidia-... 包),jaxlib 将始终优先使用它们,而不是 LD_LIBRARY_PATH 中命名的任何其他 CUDA 安装。如果这引起问题,并且意图是使用系统安装的 CUDA,则解决方案是移除 pip 安装的 CUDA 库包。

jax 0.4.18 (2023年10月6日)#

jaxlib 0.4.18 (2023年10月6日)#

  • 更改

    • CUDA jaxlibs 现在要求用户安装兼容的 NCCL 版本。如果使用推荐的 cuda12_pip 安装,NCCL 应该会自动安装。目前需要 NCCL 2.16 或更高版本。

    • 我们现在提供 Linux aarch64 wheel,包括带 NVIDIA GPU 支持和不带 NVIDIA GPU 支持的。

    • jax.Array.item() 现在支持可选的索引参数。

  • 弃用

    • jax.lax 中的一些内部实用程序和意外导出的内容已弃用,将在未来版本中移除。

      • jax.lax.dtypes:请改用 jax.dtypes

      • jax.lax.itertools:请改用 itertools

      • naryop, naryop_dtype_rule, standard_abstract_eval, standard_naryop, standard_primitive, standard_unop, unop,以及 unop_dtype_rule 是内部实用程序,现已弃用且无替代。

  • 错误修复

    • 修复了 Cloud TPU 回归问题,该问题导致编译因 smem 而 OOM。

jax 0.4.17 (2023年10月3日)#

  • 新功能

  • 弃用

    • 移除了已弃用的模块 jax.abstract_arrays 及其所有内容。

    • jax.random 中的命名密钥构造函数已弃用。请改用向 jax.random.PRNGKey()jax.random.key() 传递 impl 参数。

      • random.threefry2x32_key(seed) 变为 random.PRNGKey(seed, impl='threefry2x32')

      • random.rbg_key(seed) 变为 random.PRNGKey(seed, impl='rbg')

      • random.unsafe_rbg_key(seed) 变为 random.PRNGKey(seed, impl='unsafe_rbg')

  • 更改

    • CUDA:JAX 现在验证找到的 CUDA 库不旧于 JAX 构建所用的 CUDA 库。如果找到较旧的库,JAX 会引发异常,这比神秘的故障和崩溃更好。

    • 移除了“未找到 GPU/TPU”警告。现在,如果在 Linux 上找到 NVIDIA GPU 或 Google TPU 但未使用,并且未指定 --jax_platforms,则会发出警告。

    • jax.scipy.stats.mode() 现在如果是在大小为 0 的轴上取众数,则返回 0 计数,这与 SciPy 1.11 中 scipy.stats.mode 的行为一致。

    • 大多数 jax.numpy 函数和属性现在具有完全定义的类型存根。以前,许多这些都被静态类型检查器(如 mypypytype)视为 Any

jaxlib 0.4.17 (2023年10月3日)#

  • 更改

    • 此版本添加了 Python 3.12 wheel。

    • CUDA 12 wheel 现在需要 CUDA 12.2 或更高版本以及 cuDNN 8.9.4 或更高版本。

  • 错误修复

    • 修复了初始化 JAX CPU 后端时 ABSL 的日志垃圾问题。

jax 0.4.16 (2023年9月18日)#

  • 更改

    • 添加了 jax.numpy.ufunc,以及 jax.numpy.frompyfunc(),可以将任何标量值函数转换为 numpy.ufunc() 风格的对象,具有 outer(), reduce(), accumulate(), at(),和 reduceat() 等方法(#17054)。

    • 添加了 jax.scipy.integrate.trapezoid()

    • 在非 IPython 环境下运行时:当引发异常时,JAX 现在会从堆栈跟踪中过滤掉其内部帧。 (之前出现的“未过滤堆栈跟踪”已消失。)这应该会产生更友好的堆栈跟踪。有关示例,请参阅 此处。此行为可以通过设置 JAX_TRACEBACK_FILTERING=remove_frames(对于两个独立的未过滤/过滤的堆栈跟踪,这是旧的行为)或 JAX_TRACEBACK_FILTERING=off(对于一个未过滤的堆栈跟踪)来更改。

    • jax2tf 默认序列化版本现在是 7,这引入了新的形状 安全断言

    • 传递给 jax.sharding.Mesh 的设备应该是可哈希的。这特别适用于模拟设备或用户创建的设备。jax.devices() 已经是可哈希的。

  • 重大更改

    • jax2tf 现在默认使用原生序列化。有关详细信息和覆盖默认值的机制,请参阅 jax2tf 文档

    • 选项 --jax_coordination_service 已被移除。它现在始终为 True

    • jax.jaxpr_util 已从公共 JAX 命名空间中移除。

    • JAX_USE_PJRT_C_API_ON_TPU 不再生效(即,它始终默认为 true)。

    • 向后兼容标志 --jax_host_callback_ad_transforms(于 2021 年 12 月引入)已移除。

  • 弃用

    • 一些 jax.numpy API 已根据 NumPy NEP-52 进行弃用。

      • jax.numpy.NINF 已弃用。请改用 -jax.numpy.inf

      • jax.numpy.PZERO 已弃用。请改用 0.0

      • jax.numpy.NZERO 已弃用。请改用 -0.0

      • jax.numpy.issubsctype(x, t) 已弃用。请改用 jax.numpy.issubdtype(x.dtype, t)

      • jax.numpy.row_stack 已弃用。请改用 jax.numpy.vstack

      • jax.numpy.in1d 已弃用。请改用 jax.numpy.isin

      • jax.numpy.trapz 已弃用。请改用 jax.scipy.integrate.trapezoid

    • jax.scipy.linalg.triljax.scipy.linalg.triu 已弃用,与 SciPy 一致。请改用 jax.numpy.triljax.numpy.triu

    • jax.lax.prod 已移除,此前在 JAX v0.4.11 中已弃用。请改用内置的 math.prod

    • 来自 jax.interpreters.xla 中与为自定义 JAX 原始类型定义 HLO 降低规则相关的许多导出内容已弃用。自定义原始类型应使用 jax.interpreters.mlir 中的 StableHLO 降低实用程序来定义。

    • 以下先前已弃用的函数已在三个月后移除:

      • jax.abstract_arrays.ShapedArray:请改用 jax.core.ShapedArray

      • jax.abstract_arrays.raise_to_shaped:请改用 jax.core.raise_to_shaped

      • jax.numpy.alltrue:请改用 jax.numpy.all

      • jax.numpy.sometrue:请改用 jax.numpy.any

      • jax.numpy.product:请改用 jax.numpy.prod

      • jax.numpy.cumproduct:请改用 jax.numpy.cumprod

  • 弃用/移除

    • 内部子模块 jax.prng 已弃用。其内容可在 jax.extend.random 中找到。

    • 内部子模块路径 jax.linear_util 已弃用。请改用 jax.extend.linear_util。(属于 jax.extend:一个用于扩展的模块

    • jax.random.PRNGKeyArrayjax.random.KeyArray 已弃用。请使用 jax.Array 进行类型注解,并使用 jax.dtypes.issubdtype(arr.dtype, jax.dtypes.prng_key) 进行运行时检测 prng 密钥类型。

    • 方法 PRNGKeyArray.unsafe_raw_array 已弃用。请改用 jax.random.key_data()

    • jax.experimental.pjit.with_sharding_constraint 已弃用。请改用 jax.lax.with_sharding_constraint

    • 内部实用程序 jax.core.is_opaque_dtypejax.core.has_opaque_dtype 已移除。不透明数据类型已重命名为扩展数据类型;请改用 jnp.issubdtype(dtype, jax.dtypes.extended)(自 jax v0.4.14 起可用)。

    • 实用程序 jax.interpreters.xla.register_collective_primitive 已移除。该实用程序在最近的 JAX 版本中没有实际作用,可以安全地移除对它的调用。

    • 内部子模块路径 jax.linear_util 已弃用。请改用 jax.extend.linear_util。(属于 jax.extend:一个用于扩展的模块

jaxlib 0.4.16 (2023年9月18日)#

  • 更改

    • 通过实验性的 jax sparse API 进行的稀疏 CSR 矩阵乘法在 NVIDIA GPU 上不再使用确定性算法。此更改是为了提高与 CUDA 12.2.1 的兼容性。

  • 错误修复

    • 修复了 Windows 上由于 LLVM 致命错误(与乱序节和 IMAGE_REL_AMD64_ADDR32NB 重定位有关)导致的崩溃(https://github.com/openxla/xla/commit/cb732a921f0c4184995cbed82394931011d12bd4)。

jax 0.4.14 (2023年7月27日)#

  • 更改

    • jax.jit 接受 donate_argnames 作为参数。其语义类似于 static_argnames。如果未提供 donate_argnums 或 donate_argnames,则不捐赠任何参数。如果未提供 donate_argnums 但提供了 donate_argnames,或反之,JAX 将使用 inspect.signature(fun) 来查找与 donate_argnames 对应的任何位置参数(或反之)。如果同时提供了 donate_argnums 和 donate_argnames,则不使用 inspect.signature,并且只捐赠在 donate_argnums 或 donate_argnames 中列出的实际参数。

    • jax.random.gamma() 已重构为更高效的算法,并具有更健壮的端点行为(#16779)。这意味着对于给定的 key,返回值的序列在 JAX v0.4.13 和 v0.4.14 之间对于 gamma 和相关采样器(包括 jax.random.ball()jax.random.beta()jax.random.chisquare()jax.random.dirichlet()jax.random.generalized_normal()jax.random.loggamma()jax.random.t())会有所改变。

  • 删除

    • 自 3 个月前弃用以来,in_axis_resourcesout_axis_resources 已从 pjit 中删除。请使用 in_shardingsout_shardings 作为替代。这是一个安全且简单的名称替换。它不会改变任何当前的 pjit 语义,也不会破坏任何代码。您仍然可以将 PartitionSpecs 传递给 in_shardings 和 out_shardings。

  • 弃用

    • 根据 https://jax.net.cn/en/latest/deprecation.html,已停止支持 Python 3.8。

    • 根据 https://jax.net.cn/en/latest/deprecation.html,JAX 现在要求 NumPy 1.22 或更高版本。

    • jax.numpy.ndarray.at 中按位置传递可选参数不再支持,此前已在 JAX 版本 0.4.7 中弃用。例如,应使用 x.at[i].get(indices_are_sorted=True) 而不是 x.at[i].get(True)

    • 以下 jax.Array 方法已移除,此前已在 JAX v0.4.5 中弃用:

    • 以下 API 在先前弃用后已移除:

      • jax.ad:请改用 jax.interpreters.ad

      • jax.curry:请改用 curry = lambda f: partial(partial, f)

      • jax.partial_eval:请改用 jax.interpreters.partial_eval

      • jax.pxla:请改用 jax.interpreters.pxla

      • jax.xla:请改用 jax.interpreters.xla

      • jax.ShapedArray:请改用 jax.core.ShapedArray

      • jax.interpreters.pxla.device_put:请改用 jax.device_put()

      • jax.interpreters.pxla.make_sharded_device_array:请改用 jax.make_array_from_single_device_arrays()

      • jax.interpreters.pxla.ShardedDeviceArray:请改用 jax.Array

      • jax.numpy.DeviceArray:请改用 jax.Array

      • jax.stages.Compiled.compiler_ir:请改用 jax.stages.Compiled.as_text()

  • 重大更改

    • JAX 现在要求 ml_dtypes 版本 0.2.0 或更高版本。

    • 为了修复一个边缘情况,带有五个参数的 jax.lax.cond() 调用将始终解析为“公共操作数” cond 行为(如文档所述),即使其他操作数也为可调用。请参阅 #16413。

    • 已移除已弃用的配置选项 jax_arrayjax_jit_pjit_api_merge,它们已无效。这些选项在许多版本中已默认为 true。

  • 新功能

    • JAX 现在支持配置标志 –jax_serialization_version 和 JAX_SERIALIZATION_VERSION 环境变量来控制序列化版本(#16746)。

    • 在存在形状多态性的情况下,jax2tf 现在生成检查某些形状约束的代码,前提是序列化版本至少为 7。请参阅 https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism。

jaxlib 0.4.14 (2023年7月27日)#

  • 弃用

    • 根据 https://jax.net.cn/en/latest/deprecation.html,已停止支持 Python 3.8。

jax 0.4.13 (2023年6月22日)#

  • 更改

    • jax.jit 现在允许将 None 传递给 in_shardingsout_shardings。语义如下:

      • 对于 in_shardings,JAX 将将其标记为复制,但此行为将来可能会更改。

      • 对于 out_shardings,我们将依赖 XLA GSPMD 分区器来确定输出分片。

    • jax.experimental.pjit.pjit 也允许将 None 传递给 in_shardingsout_shardings。语义如下:

      • 如果未提供 mesh 上下文管理器,JAX 可以自由选择任何分片。

        • 对于 in_shardings,JAX 将将其标记为复制,但此行为将来可能会更改。

        • 对于 out_shardings,我们将依赖 XLA GSPMD 分区器来确定输出分片。

      • 如果提供了 mesh 上下文管理器,None 将意味着该值将在 mesh 的所有设备上复制。

    • Executable.cost_analysis() 在 Cloud TPU 上可用

    • 添加了一个警告,如果使用了非白名单的 jaxlib 插件。

    • 添加了 jax.tree_util.tree_leaves_with_path

    • None 不能作为 jax.experimental.multihost_utils.host_local_array_to_global_arrayjax.experimental.multihost_utils.global_array_to_host_local_array 的有效输入。如果您想复制输入,请使用 jax.sharding.PartitionSpec()

  • 错误修复

    • 修复了 CUDA 12 版本中的 wheel 名称不正确的问题(#16362);正确的 wheel 名称是 cudnn89 而不是 cudnn88

  • 弃用

    • 对于 jax.experimental.jax2tf.convert(),参数 native_serialization_strict_checks 已弃用,取而代之的是新的 native_serializaation_disabled_checks(#16347)。

jaxlib 0.4.13 (2023年6月22日)#

  • 更改

    • jaxlib Pypi 发布中添加了 Windows CPU-only wheel。

  • 错误修复

    • __cuda_array_interface__ 在之前的 jaxlib 版本中存在问题,现已修复(#16440)。

    • 在 NVIDIA GPU 上,并发 CUDA 内核跟踪现已默认启用。

jax 0.4.12 (2023年6月8日)#

  • 更改

    • 添加了 scipy.spatial.transform.Rotationscipy.spatial.transform.Slerp

  • 弃用

    • jax.abstract_arrays 及其内容已弃用。请参阅 jax.core 中的相关功能。

    • jax.numpy.alltrue:请改用 jax.numpy.all。这遵循 NumPy 版本 1.25.0 中 numpy.alltrue 的弃用。

    • jax.numpy.sometrue:请改用 jax.numpy.any。这遵循 NumPy 版本 1.25.0 中 numpy.sometrue 的弃用。

    • jax.numpy.product:请改用 jax.numpy.prod。这遵循 NumPy 版本 1.25.0 中 numpy.product 的弃用。

    • jax.numpy.cumproduct:请改用 jax.numpy.cumprod。这遵循 NumPy 版本 1.25.0 中 numpy.cumproduct 的弃用。

    • jax.sharding.OpShardingSharding 已移除,因其已弃用 3 个月。

jaxlib 0.4.12 (2023年6月8日)#

  • 更改

    • 包含 Hopper (SM version 9.0+) GPU 的 PTX/SASS。早期版本的 jaxlib 应该可以在 Hopper 上运行,但第一次执行 JAX 操作时会出现长时间的 JIT 编译延迟。

  • 错误修复

    • 修复了 Python 3.11 下 JAX 生成的 Python 堆栈跟踪中的源行信息不正确的问题。

    • 修复了打印 JAX 生成的 Python 堆栈跟踪中的帧局部变量时发生的崩溃问题(#16027)。

jax 0.4.11 (2023年5月31日)#

  • 弃用

    • 根据 API 兼容性 策略,以下 API 在 3 个月后移除,遵循了弃用期:

      • jax.experimental.PartitionSpec:请改用 jax.sharding.PartitionSpec

      • jax.experimental.maps.Mesh:请改用 jax.sharding.Mesh

      • jax.experimental.pjit.NamedSharding:请改用 jax.sharding.NamedSharding

      • jax.experimental.pjit.PartitionSpec:请改用 jax.sharding.PartitionSpec

      • jax.experimental.pjit.FROM_GDA。请改用将分片 jax.Array 对象作为输入,并移除 pjit 的可选 in_shardings 参数。

      • jax.interpreters.pxla.PartitionSpec:请改用 jax.sharding.PartitionSpec

      • jax.interpreters.pxla.Mesh:请改用 jax.sharding.Mesh

      • jax.interpreters.xla.Buffer:请改用 jax.Array

      • jax.interpreters.xla.Device:请改用 jax.Device

      • jax.interpreters.xla.DeviceArray:请改用 jax.Array

      • jax.interpreters.xla.device_put:请改用 jax.device_put

      • jax.interpreters.xla.xla_call_p:请改用 jax.experimental.pjit.pjit_p

      • 已移除 with_sharding_constraintaxis_resources 参数。请改用 shardings

jaxlib 0.4.11 (2023年5月31日)#

  • 更改

    • Devices 添加了 memory_stats() 方法。如果支持,它将返回一个字典,其中包含字符串统计名称和整数值,例如 "bytes_in_use",或者如果平台不支持内存统计信息,则返回 None。确切的统计信息返回值可能因平台而异。目前仅在 Cloud TPU 上实现。

    • 重新添加了对 CPU 设备上的 Python 缓冲区协议(memoryview)的支持。

jax 0.4.10 (2023年5月11日)#

jaxlib 0.4.10 (2023年5月11日)#

  • 更改

    • 修复了导致先前版本无法在 Mac M1 上运行的 'apple-m1' is not a recognized processor for this target (ignoring processor) 问题。

jax 0.4.9 (2023年5月9日)#

  • 更改

    • 实验性标志 experimental_cpp_jit, experimental_cpp_pjit 和 experimental_cpp_pmap 已移除。它们现在始终开启。

    • TPU 上的奇异值分解 (SVD) 的精度已提高(需要 jaxlib 0.4.9)。

  • 弃用

    • jax.experimental.gda_serialization 已弃用,并已重命名为 jax.experimental.array_serialization。请更改您的导入以使用 jax.experimental.array_serialization

    • pjit 的 in_axis_resourcesout_axis_resources 参数已弃用。请分别使用 in_shardingsout_shardings

    • 函数 jax.numpy.msort 已移除。它自 JAX v0.4.1 起已弃用。请改用 jnp.sort(a, axis=0)

    • 因其仅与 sharded_jit 一起使用,in_partsout_parts 参数已从 jax.xla_computation 中移除,而 sharded_jit 已不存在。

    • 因长期未使用,instantiate_const_outputs 参数已从 jax.xla_computation 中移除。

jaxlib 0.4.9 (2023年5月9日)#

jax 0.4.8 (2023年3月29日)#

  • 重大更改

    • Cloud TPU 运行时的一个主要组件已升级。这在 Cloud TPU 上启用了以下新功能:

      使用新运行时组件时,不再支持 jax.experimental.host_callback()。如果新的 jax.debug API 不足以满足您的用例,请在 JAX issue tracker 上提交问题。

      旧运行时组件至少在未来三个月内可用,方法是设置环境变量 JAX_USE_PJRT_C_API_ON_TPU=false。如果您发现出于任何原因需要禁用新运行时,请在 JAX issue tracker 上告知我们。

  • 更改

    • 最低 jaxlib 版本已从 0.4.6 提高到 0.4.7。

  • 弃用

    • 已停止支持 CUDA 11.4。JAX GPU wheel 仅支持 CUDA 11.8 和 CUDA 12。旧版 CUDA 可能可用,前提是 jaxlib 是从源代码构建的。

    • global_arg_shapes 参数的 pmap 仅适用于 sharded_jit,已从 pmap 中移除。请迁移到 pjit 并从 pmap 中移除 global_arg_shapes。

jax 0.4.7 (2023年3月27日)#

  • 更改

    • 根据 https://jax.net.cn/en/latest/jax_array_migration.html#jax-array-migration,jax.config.jax_array 已无法禁用。

    • jax.config.jax_jit_pjit_api_merge 已无法禁用。

    • jax.experimental.jax2tf.convert() 现在支持 native_serialization 参数,以使用 JAX 的原生降低到 StableHLO 来获取整个 JAX 函数的 StableHLO 模块,而不是将每个 JAX 原始类型降低到 TensorFlow op。这简化了内部结构,并增加了对所序列化内容与 JAX 原生语义匹配的信心。请参阅 文档。作为此更改的一部分,配置标志 --jax2tf_default_experimental_native_lowering 已重命名为 --jax2tf_native_serialization

    • JAX 现在依赖 ml_dtypes,其中包含 bfloat16 等 NumPy 类型的定义。这些定义以前是 JAX 的内部内容,但已拆分到一个单独的包中,以便与其他项目共享。

    • JAX 现在要求 NumPy 1.21 或更新版本,以及 SciPy 1.7 或更新版本。

  • 弃用

    • 类型 jax.numpy.DeviceArray 已弃用。请改用 jax.Array,它是其别名。

    • 类型 jax.interpreters.pxla.ShardedDeviceArray 已弃用。请改用 jax.Array

    • jax.numpy.ndarray.at 中按位置传递附加参数已弃用。例如,应使用 x.at[i].get(indices_are_sorted=True) 而不是 x.at[i].get(True)

    • jax.interpreters.xla.device_put 已弃用。请使用 jax.device_put

    • jax.interpreters.pxla.device_put 已弃用。请使用 jax.device_put

    • jax.experimental.pjit.FROM_GDA 已弃用。请将分片 jax.Arrays 作为输入传递,并移除 pjit 的 in_shardings 参数,因为它不是必需的。

jaxlib 0.4.7 (2023年3月27日)#

更改

  • jaxlib 现在依赖 ml_dtypes,其中包含 bfloat16 等 NumPy 类型的定义。这些定义以前是 JAX 的内部内容,但已拆分到一个单独的包中,以便与其他项目共享。

jax 0.4.6 (2023年3月9日)#

  • 更改

    • jax.tree_util 现在包含一组 API,允许用户定义自定义 Pytree 节点的键。这包括:

      • tree_flatten_with_path,它扁平化一个树,并返回每个叶子及其键路径。

      • tree_map_with_path,它可以映射一个接受键路径作为参数的函数。

      • register_pytree_with_keys,用于注册自定义 Pytree 节点的键路径和叶子应如何显示。

      • keystr,用于美观地打印键路径。

    • jax2tf.call_tf() 有一个新参数 output_shape_dtype(默认为 None),可用于声明结果的输出形状和类型。这使得 jax2tf.call_tf() 能够在形状多态性的情况下工作。(#14734)。

  • 弃用

    • jax.tree_util 中的旧键路径 API 已弃用,并将于 2023 年 3 月 10 日起三个月内移除。

jaxlib 0.4.6 (2023年3月9日)#

jax 0.4.5 (2023年3月2日)#

  • 弃用

    • jax.sharding.OpShardingSharding 已重命名为 jax.sharding.GSPMDShardingjax.sharding.OpShardingSharding 将在 2023 年 2 月 17 日起三个月后移除。

    • 以下 jax.Array 方法已弃用,并将于 2023 年 2 月 23 日起三个月后移除:

jax 0.4.4 (2023年2月16日)#

  • 更改

    • jit 和 pjit 的实现已合并。合并 jit 和 pjit 更改了 JAX 的内部结构,而不影响 JAX 的公共 API。以前,jit 是一个最终风格的原始类型。最终风格意味着 jaxpr 的创建被尽可能延迟,转换堆叠在一起。随着 jit-pjit 实现的合并,jit 成为一个初始风格的原始类型,这意味着我们尽早将转换到 jaxpr。有关更多信息,请参阅 autodidax 中的 此部分。切换到初始风格应该会简化 JAX 的内部结构,并使动态形状等功能的开发更加容易。您只能通过环境变量禁用它,即 os.environ['JAX_JIT_PJIT_API_MERGE'] = '0'。由于合并在导入时影响 JAX,因此必须在导入 jax 之前通过环境变量禁用它。

    • 已弃用 with_sharding_constraintaxis_resources 参数。请改用 shardings。如果您将 axis_resources 用作参数,则无需更改。如果您将其用作关键字参数,请改用 shardingsaxis_resources 将在 2023 年 2 月 13 日起三个月后移除。

    • 添加了 jax.typing 模块,其中包含用于 JAX 函数类型注解的工具。

    • 以下名称已弃用:

      • jax.xla.Devicejax.interpreters.xla.Device:请改用 jax.Device

      • jax.experimental.maps.Mesh。请改用 jax.sharding.Mesh

      • jax.experimental.pjit.NamedSharding:请改用 jax.sharding.NamedSharding

      • jax.experimental.pjit.PartitionSpec:请改用 jax.sharding.PartitionSpec

      • jax.interpreters.pxla.Mesh:请改用 jax.sharding.Mesh

      • jax.interpreters.pxla.PartitionSpec:请改用 jax.sharding.PartitionSpec

  • 重大更改

    • jax.numpy.sum() 等约简函数中的 initial 参数现在被要求为标量,这与相应的 NumPy API 一致。先前将非标量 initial 值广播到输出的行为是一个非故意的实现细节(#14446)。

jaxlib 0.4.4 (2023年2月16日)#

  • 重大更改

    • 默认 jaxlib 构建已移除对 NVIDIA Kepler 系列 GPU 的支持。如果需要 Kepler 支持,仍然可以从源代码构建 jaxlib(通过 build.py--cuda_compute_capabilities=sm_35 选项),但请注意,CUDA 12 已完全停止支持 Kepler GPU。

jax 0.4.3 (2023年2月8日)#

jaxlib 0.4.3 (2023年2月8日)#

  • jax.Array 现在具有非阻塞 is_ready() 方法,该方法返回 True 如果数组已准备就绪(另请参见 jax.block_until_ready())。

jax 0.4.2 (2023年1月24日)#

  • 重大更改

    • 删除了 jax.experimental.callback

    • 涉及 jax2tf 形状多态性的维度操作已推广到更多场景,方法是将符号维度转换为 JAX 数组。涉及符号维度和 np.ndarray 的操作现在可能在将结果用作形状值时引发错误(#14106)。

    • jaxpr 对象现在对属性设置引发错误,以避免问题性突变(#14102)。

  • 更改

    • jax2tf.call_tf() 有一个新参数 has_side_effects(默认为 True),可用于声明实例是否可以被 JAX 优化(如死代码消除)移除或复制(#13980)。

    • 为 jax2tf 形状多态性添加了对 floordiv 和 mod 的更多支持。先前,某些除法操作在存在符号维度时会导致错误(#14108)。

jaxlib 0.4.2 (2023年1月24日)#

  • 更改

    • 将 JAX_USE_PJRT_C_API_ON_TPU=1 设置为启用新的 Cloud TPU 运行时,该运行时具有自动设备内存碎片整理功能。

jax 0.4.1 (2022年12月13日)#

  • 更改

    • 根据 JAX 的 Python 和 NumPy 版本支持策略,已停止对 Python 3.7 的支持。

    • 我们引入了 jax.Array,它是一种统一的数组类型,包含了 JAX 中的 DeviceArrayShardedDeviceArrayGlobalDeviceArray 类型。jax.Array 类型有助于使并行化成为 JAX 的核心功能,简化和统一 JAX 内部结构,并允许我们统一 jitpjitjax.Array 在 JAX 0.4 中已默认启用,并对 pjit API 造成了一些破坏性更改。jax.Array 迁移指南 可以帮助您将代码库迁移到 jax.Array。您还可以参考 分布式数组和自动并行化 教程来理解新概念。

    • PartitionSpecMesh 现在已不再是实验性的。新的 API 端点是 jax.sharding.PartitionSpecjax.sharding.Meshjax.experimental.maps.Meshjax.experimental.PartitionSpec 已弃用,将在 3 个月后移除。

    • with_sharding_constraint 的新公共端点是 jax.lax.with_sharding_constraint

    • 如果将 ABSL 标志与 jax.config 一起使用,则在 JAX 配置选项首次从 ABSL 标志加载后,ABSL 标志的值不再读取或写入。此更改提高了读取 jax.config 选项的性能,因为 jax.config 选项在 JAX 中使用非常广泛。

    • jax2tf 现在使用与嵌入式 JAX 计算相同平台的第一个 TF 设备进行 TF 降低。之前,它为 JAX 默认后端使用第 0 个设备。

    • 许多 jax.numpy 函数现在已将其参数标记为仅限位置参数,与 NumPy 一致。

    • jnp.msort 已弃用,遵循 numpy 1.24 中 np.msort 的弃用。它将在未来的版本中移除,遵循 API 兼容性 策略。它可以替换为 jnp.sort(a, axis=0)

jaxlib 0.4.1 (2022年12月13日)#

  • 更改

    • 根据 JAX 的 Python 和 NumPy 版本支持策略,已停止对 Python 3.7 的支持。

    • 更改了 XLA_PYTHON_CLIENT_MEM_FRACTION=.XX 的行为,以分配总 GPU 内存的 XX% 来代替先前使用当前可用 GPU 内存计算预分配的行为。有关详细信息,请参阅 GPU 内存分配

    • 已移除已弃用的方法 .block_host_until_ready()。请改用 .block_until_ready()

jax 0.4.0 (2022年12月12日)#

  • 此版本已被撤回。

jaxlib 0.4.0 (2022年12月12日)#

  • 此版本已被撤回。

jax 0.3.25 (2022年11月15日)#

jaxlib 0.3.25 (2022年11月15日)#

  • 更改

    • 增加了对 CPU 和 GPU 上三对角线约简的支持。

    • 增加了对 CPU 上上三对角线约简的支持。

  • 错误

    • 修复了一个错误,该错误导致 JAX 捕获的堆栈跟踪中的帧在 Python 3.10+ 下与源行映射不正确。

jax 0.3.24 (2022年11月4日)#

  • 更改

    • JAX 的导入速度应该更快了。我们现在惰性导入 scipy,这占了 JAX 导入时间的很大一部分。

    • 设置环境变量 JAX_PERSISTENT_CACHE_MIN_COMPILE_TIME_SECS=$N 可用于限制写入持久缓存的缓存条目数量。默认情况下,编译时间为 1 秒或更长的计算将被缓存。

    • 默认设备顺序由 pmap 在 TPU 上使用(如果未指定顺序)现在匹配单进程作业的 jax.devices()。先前这两个顺序不同,可能导致不必要的复制或内存不足错误。要求顺序一致简化了问题。

  • 重大更改

    • jax.numpy.gradient() 现在与 jax.numpy 中的大多数其他函数行为一致,禁止在需要数组的地方传入列表或元组(#12958)。

    • jax.numpy.linalgjax.numpy.fft 中的函数现在统一要求输入为类数组(array-like):即不能用列表和元组代替数组。这是 #7737 的一部分。

  • 弃用

    • jax.sharding.MeshPspecSharding 已重命名为 jax.sharding.NamedShardingjax.sharding.MeshPspecSharding 名称将在 3 个月后移除。

jaxlib 0.3.24 (2022 年 11 月 4 日)#

  • 更改

    • Buffer donation 现在可以在 CPU 上使用。这可能会破坏已将 buffer 标记为 donation 但依赖于 donation 未实现的代码。

jax 0.3.23 (2022 年 10 月 12 日)#

  • 更改

    • 为新的 jaxlib 发布更新 Colab TPU 驱动程序版本。

jax 0.3.22 (2022 年 10 月 11 日)#

  • 更改

    • 在 TPU 初始化时,默认设置 JAX_PLATFORMS=tpu,cpu,因此如果 TPU 初始化失败,JAX 将引发错误而不是回退到 CPU。设置 JAX_PLATFORMS='' 可覆盖此行为并自动选择可用后端(原始默认值),或设置 JAX_PLATFORMS=cpu 以始终使用 CPU,无论 TPU 是否可用。

  • 弃用

    • JAX v0.3.8 中弃用的几个测试工具已从 jax.test_util 中移除。

jaxlib 0.3.22 (2022 年 10 月 11 日)#

jax 0.3.21 (2022 年 9 月 30 日)#

  • GitHub commits.

  • 更改

    • 持久化编译缓存现在会在出错时发出警告而不是引发异常(#12582),因此如果缓存出现问题,程序可以继续执行。设置 JAX_RAISE_PERSISTENT_CACHE_ERRORS=true 可恢复此行为。

jax 0.3.20 (2022 年 9 月 28 日)#

  • 错误修复

    • 添加了上个版本中缺失的 .pyi 文件(#12536)。

    • 修复了 jax 0.3.19 与其固定的 libtpu 版本之间的不兼容性(#12550)。需要 jaxlib 0.3.20。

    • 修复了 setup.py 注释中错误的 pip URL(#12528)。

jaxlib 0.3.20 (2022 年 9 月 28 日)#

  • GitHub commits.

  • 错误修复

    • 修复了在分布式作业中通过 jax_cuda_visible_devices 限制可见 CUDA 设备的支持。此功能对于 GPU 上的 JAX/SLURM 集成是必需的(#12533)。

jax 0.3.19 (2022 年 9 月 27 日)#

jax 0.3.18 (2022 年 9 月 26 日)#

  • GitHub commits.

  • 更改

    • 提前(Ahead-of-time)降低和编译功能(在 #7733 中跟踪)现在稳定且公开。请参阅 概述jax.stages 的 API 文档。

    • 引入了 jax.Array,用于 isinstance 检查和 JAX 中数组类型的类型注解。请注意,这包括了 jax.numpy.ndarray 对于 jax 内部对象的 `isinstance` 工作方式的一些细微变化,因为 jax.numpy.ndarray 现在是 jax.Array 的简单别名。

  • 重大更改

    • jax._src 不再导入到公共 jax 命名空间。这可能会破坏使用 JAX 内部功能的用户的代码。

    • jax.soft_pmap 已删除。请使用 pjitxmap 代替。 jax.soft_pmap 未被文档化。如果它被文档化,会提供一个弃用期。

jax 0.3.17 (2022 年 8 月 31 日)#

  • GitHub commits.

  • 错误

    • 修复了 lax.pow 指数为零时梯度的边缘情况问题(#12041)。

  • 重大更改

    • jax.checkpoint()(也称为 jax.remat())不再支持 concrete 选项,遵循先前版本的弃用;请参阅 JEP 11830

  • 更改

    • 添加了 jax.pure_callback(),它允许从编译后的函数(例如,用 jax.jitjax.pmap 装饰的函数)调用纯 Python 函数。

  • 弃用

    • 已弃用的 DeviceArray.tile() 方法已被移除。请使用 jax.numpy.tile()#11944)。

    • DeviceArray.to_py() 已被弃用。请使用 np.asarray(x) 代替。

jax 0.3.16#

jax 0.3.15 (2022 年 7 月 22 日)#

jaxlib 0.3.15 (2022 年 7 月 22 日)#

jax 0.3.14 (2022 年 6 月 27 日)#

  • GitHub commits.

  • 重大更改

    • jax.experimental.compilation_cache.initialize_cache() 不再支持 max_cache_size_  bytes,也不会将其作为输入。

    • JAX_PLATFORMS 现在在平台初始化失败时会引发异常。

  • 更改

    • 修复了与 NumPy 1.23 的兼容性问题。

    • jax.numpy.linalg.slogdet() 现在接受一个可选的 method 参数,允许在基于 LU 分解的实现和基于 QR 分解的实现之间进行选择。

    • jax.numpy.linalg.qr() 现在支持 mode="raw"

    • picklecopy.copycopy.deepcopy 现在对 JAX 数组的使用提供了更完整的支持(#10659)。具体来说:

      • pickledeepcopy 之前在使用 DeviceArray 时会返回 np.ndarray 对象;现在会返回 DeviceArray 对象。对于 deepcopy,复制的数组在与原始数组相同的设备上。对于 pickle,反序列化的数组将位于默认设备上。

      • 在函数转换(即跟踪代码)中,deepcopycopy 之前是空操作。现在它们使用与 DeviceArray.copy() 相同的机制。

      • 对跟踪数组调用 pickle 现在会导致显式的 ConcretizationTypeError

    • 在 TPU 上,奇异值分解(SVD)和对称/厄米特征值分解的实现应该会显著加快,特别是对于大于 1000x1000 的矩阵。两者现在都使用谱分治算法进行特征值分解(QDWH-eig)。

    • jax.numpy.ldexp() 不再默默地将所有输入提升为 float64,而是将 int32 或更小大小的整数输入提升为 float32(#10921)。

    • jax.profiler.start_trace()jax.profiler.start_trace() 添加了 create_perfetto_link 选项。使用时,profiler 将生成一个指向 Perfetto UI 的链接以查看跟踪。

    • 更改了 jax.profiler.start_server(...)() 的语义,以将 keepalive 全局存储,而不是要求用户保留其引用。

    • 添加了 jax.random.generalized_normal()

    • 添加了 jax.random.ball()

    • 添加了 jax.default_device()

    • 添加了一个 python -m jax.collect_profile 脚本,用于手动捕获程序跟踪,作为 TensorBoard UI 的替代方法。

    • 添加了一个 jax.named_scope 上下文管理器,用于将 profiler 元数据添加到 Python 程序(类似于 jax.named_call)。

    • 在 scatter-update 操作(即 jax.numpy.ndarray.at)中,不安全的隐式 dtype 转换已被弃用,现在会导致 FutureWarning。在未来的版本中,这将成为一个错误。一个不安全隐式转换的例子是 jnp.zeros(4, dtype=int).at[0].set(1.5),其中 1.5 之前会被静默截断为 1

    • jax.experimental.compilation_cache.initialize_cache() 现在支持 gcs bucket 路径作为输入。

    • 添加了 jax.scipy.stats.gennorm()

    • jax.numpy.roots()strip_zeros=False 且系数前导零时行为更好(#11215)。

jaxlib 0.3.14 (2022 年 6 月 27 日)#

  • GitHub commits.

    • x86-64 Mac wheel 现在要求 Mac OS 10.14 (Mojave) 或更新版本。Mac OS 10.14 发布于 2018 年,因此这应该不是一个很苛刻的要求。

    • 更新了 NCCL 的捆绑版本到 2.12.12,修复了一些死锁问题。

    • Python flatbuffers 包不再是 jaxlib 的依赖项。

jax 0.3.13 (2022 年 5 月 16 日)#

jax 0.3.12 (2022 年 5 月 15 日)#

jax 0.3.11 (2022 年 5 月 15 日)#

  • GitHub commits.

  • 更改

    • jax.lax.eigh() 现在接受一个可选的 sort_eigenvalues 参数,允许用户选择在 TPU 上不进行特征值排序。

  • 弃用

    • jax.lax.linalg 中的函数现在将非数组参数标记为仅关键字参数。作为向后兼容步骤,通过位置传递关键字参数会产生警告,但在未来的 JAX 版本中,通过位置传递关键字参数将失败。然而,大多数用户应该更倾向于使用 jax.numpy.linalg

    • JAX 扩展到 scipy API 的 jax.scipy.linalg.polar_unitary() 已被弃用。请使用 jax.scipy.linalg.polar() 代替。

jax 0.3.10 (2022 年 5 月 3 日)#

jaxlib 0.3.10 (2022 年 5 月 3 日)#

  • GitHub commits.

  • 更改

    • TF commit 修复了 MHLO 规范化器中的一个问题,该问题导致某些程序的常量折叠花费很长时间或崩溃。

jax 0.3.9 (2022 年 5 月 2 日)#

  • GitHub commits.

  • 更改

    • 为 GlobalDeviceArray 添加了对完全异步 checkpointing 的支持。

jax 0.3.8 (2022 年 4 月 29 日)#

  • GitHub commits.

  • 更改

    • TPU 上的 jax.numpy.linalg.svd() 使用 qdwh-svd 求解器。

    • TPU 上的 jax.numpy.linalg.cond() 现在接受复数输入。

    • TPU 上的 jax.numpy.linalg.pinv() 现在接受复数输入。

    • TPU 上的 jax.numpy.linalg.matrix_rank() 现在接受复数输入。

    • 添加了 jax.scipy.cluster.vq.vq()

    • jax.experimental.maps.mesh 已删除。请使用 jax.experimental.maps.Mesh。请参阅 https://jax.net.cn/en/latest/_autosummary/jax.experimental.maps.Mesh.html#jax.experimental.maps.Mesh 获取更多信息。

    • jax.scipy.linalg.qr() 现在当 mode='r' 时返回一个长度为 1 的元组而不是原始数组,以匹配 scipy.linalg.qr 的行为(#10452)。

    • jax.numpy.take_along_axis() 现在接受一个可选的 mode 参数,该参数指定了越界索引的行为。默认情况下,对于越界索引将返回无效值(例如,NaN)。在早期版本的 JAX 中,无效索引会被裁剪到范围内。通过传递 mode="clip" 可以恢复之前的行为。

    • jax.numpy.take() 现在默认使用 mode="fill",它为越界索引返回无效值(例如,NaN)。

    • Scatter 操作,如 x.at[...].set(...),现在具有 "drop" 语义。这对 scatter 操作本身没有影响,但这意味着当进行微分时,scatter 的梯度将为越界索引返回零余切。之前越界索引会被裁剪到范围内以进行梯度计算,这在数学上是不正确的。

    • jax.numpy.take_along_axis() 现在如果其索引不是整数类型,则会引发 TypeError,这与 numpy.take_along_axis() 的行为一致。之前非整数索引会被静默转换为整数。

    • jax.numpy.ravel_multi_index() 现在如果其 dims 参数不是整数类型,则会引发 TypeError,这与 numpy.ravel_multi_index() 的行为一致。之前非整数 dims 会被静默转换为整数。

    • jax.numpy.split() 现在如果其 axis 参数不是整数类型,则会引发 TypeError,这与 numpy.split() 的行为一致。之前非整数 axis 会被静默转换为整数。

    • jax.numpy.indices() 现在如果其维度不是整数类型,则会引发 TypeError,这与 numpy.indices() 的行为一致。之前非整数维度会被静默转换为整数。

    • jax.numpy.diag() 现在如果其 k 参数不是整数类型,则会引发 TypeError,这与 numpy.diag() 的行为一致。之前非整数 k 会被静默转换为整数。

    • 添加了 jax.random.orthogonal()

  • 弃用

    • jax.test_util 中可用的许多函数和对象现在已被弃用,并在导入时发出警告。这包括 cases_from_listcheck_closecheck_eqdevice_under_testformat_shape_dtype_stringrand_uniformskip_on_deviceswith_configxla_bridge_default_tolerance#10389)。这些类,连同先前已弃用的 JaxTestCaseJaxTestLoaderBufferDonationTestCase,将在未来的 JAX 版本中移除。这些实用程序中的大多数可以被标准 Python 和 NumPy 测试实用程序替换,例如 unittestabsl.testingnumpy.testing 等。 JAX 特定的功能,如设备检查,可以通过使用公共 API(如 jax.devices())来替换。许多已弃用的实用程序仍将存在于 jax._src.test_util 中,但这些不是公共 API,因此在未来的版本中可能会被更改或移除,恕不另行通知。

jax 0.3.7 (2022 年 4 月 15 日)#

jaxlib 0.3.7 (2022 年 4 月 15 日)#

  • 更改

    • Linux wheels 现在按照 manylinux2014 标准构建,而不是 manylinux2010

jax 0.3.6 (2022 年 4 月 12 日)#

  • GitHub commits.

  • 更改

    • 升级了 libtpu wheel 的版本,修复了初始化 TPU pod 时出现的挂起问题。修复了 #10218

  • 弃用

    • jax.experimental.loops 已被弃用。请参阅 #10278 获取替代 API。

jax 0.3.5 (2022 年 4 月 7 日)#

jaxlib 0.3.5 (2022 年 4 月 7 日)#

  • 错误修复

    • 修复了双精度复数到实数 IRFFT 会在 GPU 上修改其输入 buffer 的错误(#9946)。

    • 修复了复数 scatter 的错误常量折叠(#10159)。

jax 0.3.4 (2022 年 3 月 18 日)#

jax 0.3.3 (2022 年 3 月 17 日)#

jax 0.3.2 (2022 年 3 月 16 日)#

  • GitHub commits.

  • 更改

    • 已弃用的函数 jax.ops.index_updatejax.ops.index_add(在 0.2.22 中弃用)已被移除。请使用 JAX 数组上的 .at 属性 代替,例如 x.at[idx].set(y)

    • jax.experimental.ann.approx_*_k 移至 jax.lax。这些函数是 jax.lax.top_k 的优化替代品。

    • jax.numpy.broadcast_arrays()jax.numpy.broadcast_to() 现在要求标量或类数组输入,并且在传入列表时会失败(#7737 的一部分)。

    • 标准的 jax[tpu] 安装现在可以与 Cloud TPU v4 VM 一起使用。

    • pjit 现在可以在 CPU 上工作(除了之前的 TPU 和 GPU 支持)。

jaxlib 0.3.2 (2022 年 3 月 16 日)#

  • 更改

    • XlaComputation.as_hlo_text() 现在可以通过传递布尔标志 print_large_constants=True 来支持打印大型常量。

  • 弃用

    • JAX 数组上的 .block_host_until_ready() 方法已被弃用。请使用 .block_until_ready() 代替。

jax 0.3.1 (2022 年 2 月 18 日)#

jax 0.3.0 (2022 年 2 月 10 日)#

jaxlib 0.3.0 (2022 年 2 月 10 日)#

  • 更改

    • 构建 jaxlib 现在需要 Bazel 5.0.0。

    • jaxlib 版本已升级到 0.3.0。请参阅 设计文档 了解解释。

jax 0.2.28 (2022 年 2 月 1 日)#

  • GitHub commits.

    • 默认情况下,jax.jit(f).lower(...).compiler_ir() 的输出为 MHLO 方言,如果未指定 dialect=

    • 现在,jax.jit(f).lower(...).compiler_ir(dialect='mhlo') 返回的是 MLIR ir.Module 对象,而不是其字符串表示。

jaxlib 0.1.76 (2022 年 1 月 27 日)#

  • 新功能

    • 包含 NVIDIA 计算能力 8.0 GPU(例如 A100)的预编译 SASS。移除了计算能力 6.1 的预编译 SASS,以免增加计算能力的数量:计算能力 6.1 的 GPU 可以使用 6.0 SASS。

    • 使用 jaxlib 0.1.76 时,JAX 默认使用 MHLO MLIR 方言作为其主要的后端编译器 IR。

  • 重大更改

    • 根据 弃用策略,已停止支持 NumPy 1.18。请升级到支持的 NumPy 版本。

  • 错误修复

    • 修复了由不同路径构建的、看似相同的 pytreedef 对象不相等的 bug (#9066)。

    • JAX jit 缓存需要两个静态参数具有相同的类型才能实现缓存命中 (#9311)。

jax 0.2.27 (2022 年 1 月 18 日)#

  • GitHub commits.

  • 重大更改

    • 根据 弃用策略,已停止支持 NumPy 1.18。请升级到支持的 NumPy 版本。

    • host_callback 原语已简化,移除了对 hcb.id_tap 和 id_print 的特殊 autodiff 处理。从现在起,只 tap primal。旧行为(在有限时间内)可以通过设置 JAX_HOST_CALLBACK_AD_TRANSFORMS 环境变量或 --jax_host_callback_ad_transforms 标志获得。此外,还添加了关于如何使用 JAX 自定义 AD API 实现旧行为的文档(#8678)。

    • 排序现在与 NumPy 在 0.0NaN 上的行为一致,无论比特表示如何。特别是,0.0-0.0 现在被视为等效,而之前 -0.0 被视为小于 0.0。此外,所有 NaN 表示现在都被视为等效并排序到数组的末尾。之前负的 NaN 值排序到数组的前面,并且具有不同内部比特表示的 NaN 值不被视为等效,并根据这些比特模式进行排序(#9178)。

    • jax.numpy.unique() 现在像 NumPy 1.21 及更新版本中的 np.unique 一样处理 NaN 值:去重输出中最多会出现一个 NaN 值(#9184)。

  • 错误修复

    • host_callback 现在支持 ad_checkpoint.checkpoint(#8907)。

  • 新功能

    • 添加 jax.block_until_ready ({jax-issue}`#8941)

    • 添加了一个新的调试标志/环境变量 JAX_DUMP_IR_TO=/path。如果设置,JAX 将为每个计算生成的 MHLO/HLO IR 转储到给定路径下的一个文件。

    • jax.ensure_compile_time_eval 添加到公共 API(#7987)。

    • jax2tf 现在支持 jax2tf_associative_scan_reductions 标志,以更改对关联约简(例如,jnp.cumsum)的降低,使其行为类似于 CPU 和 GPU 上的 JAX(使用关联扫描)。有关详细信息,请参阅 jax2tf README(#9189)。

jaxlib 0.1.75 (2021 年 12 月 8 日)#

  • 新功能

    • 支持 python 3.10。

jax 0.2.26 (2021 年 12 月 8 日)#

  • GitHub commits.

  • 错误修复

    • jax.ops.segment_sum 的越界索引将使用 FILL_OR_DROP 语义进行处理,如文档所述。这主要影响反向模式导数,其中对应于越界索引的梯度现在将返回 0。(#8634)。

    • jax2tf 将强制转换后的代码在 jax.jit 的代码片段下使用 XLA,例如,大多数 jax.numpy 函数(#7839)。

jaxlib 0.1.74 (2021 年 11 月 17 日)#

  • 启用了 GPU 之间的点对点(peer-to-peer)复制。之前,GPU 复制会通过主机进行中转,这通常较慢。

  • 添加了供 JAX 使用的实验性 MLIR Python 绑定。

jax 0.2.25 (2021 年 11 月 10 日)#

  • GitHub commits.

  • 新功能

    • (实验性)jax.distributed.initialize 暴露了多主机 GPU 后端。

    • jax.random.permutation 支持新的 independent 关键字参数(#8430)。

  • 重大更改

    • jax.experimental.stax 移动到 jax.example_libraries.stax

    • jax.experimental.optimizers 移动到 jax.example_libraries.optimizers

  • 新功能

    • 添加了 jax.lax.linalg.qdwh

jax 0.2.24 (2021 年 10 月 19 日)#

  • GitHub commits.

  • 新功能

    • jax.random.choicejax.random.permutation 现在支持多维数组和可选的 axis 参数(#8158)。

  • 重大更改

    • jax.numpy.takejax.numpy.take_along_axis 现在要求类数组输入(参见 #7737)。

jaxlib 0.1.73 (2021 年 10 月 18 日)#

  • Jaxlib GPU cuda11 wheel 现在支持多个 cuDNN 版本。

    • cuDNN 8.2 或更新版本。建议使用 cuDNN 8.2 wheel(如果您的 cuDNN 安装版本足够新),因为它支持其他功能。

    • cuDNN 8.0.5 或更新版本。

  • 重大更改

    • GPU jaxlib 的安装命令如下:

      pip install --upgrade pip
      
      # Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer.
      pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
      
      # Installs the wheel compatible with Cuda 11 and cudnn 8.2 or newer.
      pip install jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_releases.html
      
      # Installs the wheel compatible with Cuda 11 and cudnn 8.0.5 or newer.
      pip install jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_releases.html
      

jax 0.2.22 (2021 年 10 月 12 日)#

  • GitHub commits.

  • 重大更改

    • jax.pmap 的静态参数现在必须是可哈希的。

      不可哈希的静态参数长期以来在 jax.jit 中被禁止,但在 jax.pmap 中仍然允许;jax.pmap 使用对象标识符比较不可哈希的静态参数。

      此行为是一个陷阱,因为使用对象标识符比较参数会导致每次对象标识符更改时都重新编译。相反,我们现在禁止不可哈希的参数:如果 jax.pmap 的用户想要通过对象标识符比较静态参数,他们可以为其对象定义具有该行为的 __hash____eq__ 方法,或者将他们的对象包装在一个具有对象标识符语义的运算的对象中。另一个选择是使用 functools.partial 将不可哈希的静态参数封装到函数对象中。

    • jax.util.partial 是一个意外的导出,现已移除。请使用 Python 标准库中的 functools.partial 代替。

  • 弃用

    • 函数 jax.ops.index_updatejax.ops.index_add 等已被弃用,将在未来的 JAX 版本中移除。请使用 JAX 数组上的 .at 属性 代替,例如 x.at[idx].set(y)。目前,这些函数会产生一个 DeprecationWarning

  • 新功能

    • 使用 jaxlib 0.1.72 或更新版本时,优化的 C++ 代码路径可以提高 pmap 的调度时间。该功能可以通过 --experimental_cpp_pmap 标志(或 JAX_CPP_PMAP 环境变量)禁用。

    • jax.numpy.unique 现在支持可选的 fill_value 参数(#8121)。

jaxlib 0.1.72 (2021 年 10 月 12 日)#

  • 重大更改

    • 已停止支持 CUDA 10.2 和 CUDA 10.1。Jaxlib 现在支持 CUDA 11.1+。

  • 错误修复

    • 修复了 https://github.com/jax-ml/jax/issues/7461 中的 bug,该 bug 由于 XLA 编译器内部的缓冲区别名错误导致在所有平台上产生错误输出。

jax 0.2.21 (2021 年 9 月 23 日)#

  • GitHub commits.

  • 重大更改

    • jax.api 已移除。原本可作为 jax.api.* 访问的函数是 jax.* 中函数的别名;请使用 jax.* 中的函数。

    • jax.partialjax.lax.partial 是意外的导出,现已移除。请使用 Python 标准库中的 functools.partial 代替。

    • 布尔标量索引现在会引发 TypeError;之前这会静默返回错误的结果(#7925)。

    • 许多 jax.numpy 函数现在要求类数组输入,并且在传入列表时会报错(#7747 #7802 #7907)。请参阅 #7737 讨论此更改背后的原因。

    • 在转换(如 jax.jit)内部时,jax.numpy.array 始终会将它生成的数组暂存到跟踪的计算中。之前,即使在 jax.jit 装饰器下,jax.numpy.array 有时也会生成一个设备上的数组。此更改可能会破坏使用 JAX 数组进行形状或索引计算的代码,这些计算必须是静态已知的;解决方法是改用经典的 NumPy 数组执行此类计算。

    • jnp.ndarray 现在是 JAX 数组的真实基类。特别是,这意味着对于标准的 NumPy 数组 xisinstance(x, jnp.ndarray) 现在将返回 False#7927)。

  • 新功能

jax 0.2.20 (2021 年 9 月 2 日)#

  • GitHub commits.

  • 重大更改

    • jnp.poly* 函数现在要求类数组输入(#7732)。

    • jnp.unique 和其他类集合操作现在要求类数组输入(#7662)。

jaxlib 0.1.71 (2021 年 9 月 1 日)#

  • 重大更改

    • 已停止支持 CUDA 11.0 和 CUDA 10.1。Jaxlib 现在支持 CUDA 10.2 和 CUDA 11.1+。

jax 0.2.19 (2021 年 8 月 12 日)#

  • GitHub commits.

  • 重大更改

    • 根据 弃用策略,已停止支持 NumPy 1.17。请升级到支持的 NumPy 版本。

    • 在 JAX 数组的许多运算符的实现周围添加了 jit 装饰器。这加快了 + 等常用运算符的调度时间。

      此更改对大多数用户来说应该基本透明。但是,有一个已知的行为变化,即大型整数常量现在可能在直接传递给 JAX 运算符时(例如,x + 2**40)引发错误。解决方法是将常量强制转换为显式类型(例如,np.float64(2**40))。

  • 新功能

    • 改进了 jax2tf 对形状多态的支持,用于需要使用尺寸大小的数组计算的操作,例如 jnp.mean。(#7317)。

  • 错误修复

    • 上一个版本中一些泄漏的跟踪错误(#7613)。

jaxlib 0.1.70 (2021 年 8 月 9 日)#

  • 重大更改

    • 根据 弃用策略,已停止支持 Python 3.6。请升级到支持的 Python 版本。

    • 根据 弃用策略,已停止支持 NumPy 1.17。请升级到支持的 NumPy 版本。

    • host_callback 机制现在使用每个本地设备的单个线程来调用 Python 回调。之前只有一个线程用于所有设备。这意味着回调现在可能会交错调用。来自一个设备的回调仍然会按顺序调用。

jax 0.2.18 (2021 年 7 月 21 日)#

  • GitHub commits.

  • 重大更改

    • 根据 弃用策略,已停止支持 Python 3.6。请升级到支持的 Python 版本。

    • 最低 jaxlib 版本现在是 0.1.69。

    • 已移除 jax.dlpack.from_dlpack()backend 参数。

  • 新功能

  • 错误修复

    • 收紧了对 lax.argmin 和 lax.argmax 的检查,以确保它们不会与无效的 axis 值或空的归约维度一起使用。(#7196)。

jaxlib 0.1.69 (2021 年 7 月 9 日)#

  • 修复了 TFRT CPU 后端导致结果错误的 bug。

jax 0.2.17 (2021 年 7 月 9 日)#

  • GitHub commits.

  • 错误修复

    • 默认使用旧的“stream_executor”CPU 运行时,以使 jaxlib <= 0.1.68 正常工作,以解决 #7229,该问题导致 CPU 上由于并发问题而产生错误输出。

  • 新功能

jax 0.2.16 (2021 年 6 月 23 日)#

jax 0.2.15 (2021 年 6 月 23 日)#

  • GitHub commits.

  • 新功能

    • #7042 开启了 TFRT CPU 后端,在 CPU 上实现了显著的调度性能改进。

    • jax2tf.convert() 支持布尔值的不等式和 min/max(#6956)。

    • 新的 SciPy 函数 jax.scipy.special.lpmn_values()

  • 重大更改

    • 根据 弃用策略,已停止支持 NumPy 1.16。请升级到支持的 NumPy 版本。

  • 错误修复

    • 修复了阻止 JAX 到 TF 再回到 JAX 的往返的 bug:jax2tf.call_tf(jax2tf.convert)#6947)。

jaxlib 0.1.68 (2021 年 6 月 23 日)#

  • 错误修复

    • 修复了 TFRT CPU 后端中将 TPU buffer 传输到 CPU 时导致 NaN 的 bug。

jax 0.2.14 (2021 年 6 月 10 日)#

  • GitHub commits.

  • 新功能

    • jax2tf.convert() 现在支持 pjitsharded_jit

    • 一个新的配置选项 JAX_TRACEBACK_FILTERING 控制 JAX 如何过滤回溯。

    • 在足够新的 IPython 版本中,默认启用了使用 __tracebackhide__ 的新回溯过滤模式。

    • jax2tf.convert() 支持形状多态,即使未知维度用于算术运算(例如,jnp.reshape(-1))(#6827)。

    • jax2tf.convert() 在 TF op 中生成具有位置信息的自定义属性。jax2tf 转换后 XLA 生成的代码与 JAX/XLA 具有相同的位置信息。

    • 新的 SciPy 函数 jax.scipy.special.lpmn()

  • 错误修复

    • jax2tf.convert() 现在确保它对 Python 标量和选择 32 位 vs 64 位计算使用与 JAX 相同的类型规则(#6883)。

    • jax2tf.convert() 现在正确地将 enable_xla 转换参数限定在 just-in-time 转换期间生效(#6720)。

    • jax2tf.convert() 现在使用 XlaDot TensorFlow op 来转换 lax.dot_general,以提高与 JAX 数值精度的保真度(#6717)。

    • jax2tf.convert() 现在支持复数的不等比较和 min/max(#6892)。

jaxlib 0.1.67 (2021 年 5 月 17 日)#

jaxlib 0.1.66 (2021 年 5 月 11 日)#

  • 新功能

    • 现在所有 CUDA 11 版本 11.1 或更高版本都支持 CUDA 11.1 wheel。

      Nvidia 现在承诺从 CUDA 11.1 开始,CUDA 次要版本之间兼容。这意味着 JAX 可以发布一个 CUDA 11.1 wheel,兼容 CUDA 11.2 和 11.3。

      不再为 CUDA 11.2(或更高版本)单独发布 jaxlib;请为这些版本使用 CUDA 11.1 wheel (cuda111)。

    • Jaxlib 现在在 CUDA wheel 中捆绑了 libdevice.10.bc。不再需要将 JAX 指向 CUDA 安装来查找此文件。

    • jit() 实现中添加了对静态关键字参数的自动支持。

    • 添加了对预转换异常跟踪的支持。

    • 初步支持从 jit() 转换后的计算中修剪未使用的参数。修剪仍在进行中。

    • 改进了 PyTreeDef 对象的字符串表示。

    • 添加了对 XLA 的可变 ReduceWindow 的支持。

  • 错误修复

    • 修复了当大量参数传递给计算时,远程云 TPU 支持中的一个 bug。

    • 修复了一个 bug,该 bug 导致 jit() 转换的函数不会触发 JAX 的垃圾回收。

jax 0.2.13 (2021 年 5 月 3 日)#

  • GitHub commits.

  • 新功能

    • 与 jaxlib 0.1.66 结合使用时,jax.jit() 现在支持静态关键字参数。添加了一个新的 static_argnames 选项来指定关键字参数为静态。

    • jax.nonzero() 有一个新的可选 size 参数,允许它在 jit 中使用(#6501

    • jax.numpy.unique() 现在支持 axis 参数(#6532)。

    • jax.experimental.host_callback.call() 现在支持 pjit.pjit#6569)。

    • 添加了 jax.scipy.linalg.eigh_tridiagonal(),它计算三对角矩阵的特征值。目前仅支持特征值。

    • 异常中过滤和未过滤的堆栈跟踪的顺序已更改。从 JAX 转换的代码中抛出的异常附加的跟踪现在是过滤后的,一个 UnfilteredStackTrace 异常作为过滤后异常的 __cause__ 包含原始跟踪。过滤后的堆栈跟踪现在也支持 Python 3.6。

    • 如果异常是由反向模式自动微分转换的代码引起的,JAX 现在会尝试将一个 JaxStackTraceBeforeTransformation 对象作为异常的 __cause__ 附加,该对象包含创建前向传递中原始操作的堆栈跟踪。需要 jaxlib 0.1.66。

  • 重大更改

    • 以下函数名已更改。仍保留别名,因此不应破坏现有代码,但别名最终将被删除,请更改您的代码。

    • 类似地,local_devices() 的参数已从 host_id 重命名为 process_index

    • jax.jit() 的参数(除了函数本身)现在被标记为仅关键字参数。此更改是为了防止在向 jit 添加参数时发生意外中断。

  • 错误修复

    • jax2tf.convert() 现在可以在存在整数输入函数的梯度时工作(#6360)。

    • 修复了在与捕获的 tf.Variable 一起使用 jax2tf.call_tf() 时发生的断言失败(#6572)。

jaxlib 0.1.65 (2021 年 4 月 7 日)#

jax 0.2.12 (2021 年 4 月 1 日)#

  • GitHub commits.

  • 新功能

  • 重大更改

    • 最低 jaxlib 版本现为 0.1.64。

    • 一些分析 API 的名称已更改。仍保留别名,因此不应破坏现有代码,但别名最终将被删除,请更改您的代码。

    • Omnistaging 已无法禁用。有关更多信息,请参阅 omnistaging

    • 大于最大 int64 值的 Python 整数现在将在所有情况下导致溢出,而不是在某些情况下被静默转换为 uint64#6047)。

    • 在 X64 模式之外,超出 int32 可表示范围的 Python 整数现在将导致 OverflowError,而不是静默截断其值。

  • 错误修复

    • host_callback 现在支持参数和结果中的空数组(#6262)。

    • jax.random.randint() 现在会裁剪而不是包装越界限制,并且现在可以生成指定 dtype 完整范围内的整数(#5868

jax 0.2.11 (2021 年 3 月 23 日)#

  • GitHub commits.

  • 新功能

    • jax.enable_checksjax.check_tracer_leaksjax.debug_nansjax.debug_infsjax.log_compiles 添加了上下文管理器(#6112)。

    • 添加了 jnp.delete#6085

  • 错误修复

    • jax.flatten_util.ravel_pytree 已泛化为处理整数 dtype(#6136)。

    • 修复了处理某些常量(如 enum.IntEnums)的 bug(#6129

    • 修复了不完整 beta 函数的批处理问题(#6145

    • 修复了跟踪期间的 H2D 传输(#6014

    • 在将一些大的 Python 整数转换为浮点数时,避免了 OverflowErrors(#6165

  • 重大更改

    • 最低 jaxlib 版本现为 0.1.62。

jaxlib 0.1.64 (2021 年 3 月 18 日)#

jaxlib 0.1.63 (2021 年 3 月 17 日)#

jax 0.2.10 (2021 年 3 月 5 日)#

  • GitHub commits.

  • 新功能

    • jax.scipy.stats.chi2() 现已作为具有 logpdf 和 pdf 方法的分布可用。

    • jax.scipy.stats.betabinom() 现已作为具有 logpmf 和 pmf 方法的分布可用。

    • 添加了 jax.experimental.jax2tf.call_tf() 以从 JAX 调用 TensorFlow 函数(#5627)和 README)。

    • 扩展了 lax.pad 的批处理规则以支持 padding 值的批处理。

  • 错误修复

    • jax.numpy.take() 正确处理负索引(#5768

  • 重大更改

    • JAX 的提升规则已调整,以使提升更加一致且对 JIT 不变。特别是,二元运算现在可以在适当的情况下产生弱类型值。此更改的主要用户可见效果是某些运算产生的输出精度与之前不同;例如,表达式 jnp.bfloat16(1) + 0.1 * jnp.arange(10) 以前返回一个 float64 数组,现在返回一个 bfloat16 数组。JAX 的类型提升行为在 类型提升语义 中进行描述。

    • jax.numpy.linspace() 现在计算整数值的 floor(即向 -inf 舍入而不是向 0 舍入)。此更改是为了匹配 NumPy 1.20.0。

    • jax.numpy.i0() 不再接受复数。以前,该函数计算复数参数的绝对值。此更改是为了匹配 NumPy 1.20.0 的语义。

    • 几个 jax.numpy 函数不再接受元组或列表作为数组参数:jax.numpy.pad(), :funcjax.numpy.ravel, jax.numpy.repeat(), jax.numpy.reshape()。总的来说,jax.numpy 函数应该与标量或数组参数一起使用。

jaxlib 0.1.62 (2021 年 3 月 9 日)#

  • 新功能

    • jaxlib wheel 现在默认构建为需要 x86-64 机器上的 AVX 指令。如果您想在不支持 AVX 的机器上使用 JAX,可以使用 build.py--target_cpu_features 标志从源构建 jaxlib。--target_cpu_features 也取代了 --enable_march_native

jaxlib 0.1.61 (2021 年 2 月 12 日)#

jaxlib 0.1.60 (2021 年 2 月 3 日)#

  • 错误修复

    • 修复了将 CPU DeviceArrays 转换为 NumPy 数组时的内存泄漏。该内存泄漏存在于 jaxlib 版本 0.1.58 和 0.1.59 中。

    • boolint8uint8 现在被认为是安全地转换为 bfloat16 NumPy 扩展类型的。

jax 0.2.9 (2021 年 1 月 26 日)#

  • GitHub commits.

  • 新功能

  • 重大更改

    • jax.ops.segment_sum() 现在会丢弃超出范围的 segment IDs,而不是将它们包装到 segment ID 空间中。这是出于性能原因。

jaxlib 0.1.59 (2021 年 1 月 15 日)#

jax 0.2.8 (2021 年 1 月 12 日)#

  • GitHub commits.

  • 新功能

    • 为与高阶自定义导数函数一起使用添加了 jax.closure_convert()。(#5244

    • 添加了 jax.experimental.host_callback.call() 以调用自定义 Python 函数并在主机上将结果返回到设备计算。(#5243

  • 错误修复

    • jax.numpy.arccosh 现在返回与复数输入的 numpy.arccosh 相同的分支(#5156

    • host_callback.id_tap 现在也可以用于 jax.pmapid_tapid_print 有一个可选参数,用于请求将值被 tap 的设备作为关键字参数传递给 tap 函数(#5182)。

  • 重大更改

    • jax.numpy.pad 现在接受关键字参数。已删除位置参数 constant_values。此外,传递不支持的关键字参数会引发错误。

    • jax.experimental.host_callback.id_tap() 的更改(#5243

      • 删除了对 jax.experimental.host_callback.id_tap()kwargs 的支持。(此支持已弃用几个月。)

      • 更改了 jax.experimental.host_callback.id_print() 的元组打印方式,使用 ‘(‘ 而不是 ‘[‘。

      • 在 JVP 存在的情况下,更改了 jax.experimental.host_callback.id_print() 以打印原始值和切线值对。以前,原始值和切线值有两个单独的打印操作。

      • host_callback.outfeed_receiver 已删除(它不是必需的,并且几个月前已被弃用)。

  • 新功能

    • 用于调试 inf 的新标志,类似于 NaN 的标志(#5224)。

jax 0.2.7 (2020 年 12 月 4 日)#

  • GitHub commits.

  • 新功能

    • 添加 jax.device_put_replicated

    • jax.experimental.sharded_jit 添加了多主机支持

    • 添加了对 jax.numpy.linalg.eig 计算的特征值进行微分的支持

    • 添加了对 Windows 平台构建的支持

    • jax.pmap 添加了 in_axesout_axes 的通用支持

    • jax.numpy.linalg.slogdet 添加了复数支持

  • 错误修复

    • 修复了在零点处 jax.numpy.sinc 的高阶(高于二阶)导数

    • 修复了转置规则中与符号零相关的难以触及的一些 bug

  • 重大更改

    • jax.experimental.optix 已删除,取而代之的是独立的 optax Python 包。

    • 使用非元组序列对 JAX 数组进行索引现在会引发 TypeError。此类型的索引在 Numpy 中自 v1.16 起已弃用,在 JAX 中自 v0.2.4 起已弃用。请参阅 #4564

jax 0.2.6 (2020 年 11 月 18 日)#

  • GitHub commits.

  • 新功能

    • 为 jax.experimental.jax2tf 转换器添加了对形状多态跟踪的支持。请参阅 README.md

  • 破坏性更改清理

    • 对 jax.jit 和 xla_computation 的不可哈希静态参数引发错误。请参阅 cb48f42

    • 提高类型提升行为的一致性(#4744

      • 将复杂的 Python 标量添加到 JAX 浮点数会尊重 JAX 浮点数的精度。例如,jnp.float32(1) + 1j 现在返回 complex64,而以前返回 complex128

      • 涉及 uint64、有符号整数和第三种类型的 3 个或更多项的类型提升结果现在与参数顺序无关。例如:jnp.result_type(jnp.uint64, jnp.int64, jnp.float16)jnp.result_type(jnp.float16, jnp.uint64, jnp.int64) 都返回 float16,而以前第一个返回 float64,第二个返回 float16

    • (未记录的)jax.lax_linalg 线性代数模块的内容现在公开为 jax.lax.linalg

    • jax.random.PRNGKey 在 JIT 编译内外现在产生相同的结果(#4877)。这需要更改在某些特定情况下的结果

      • 使用 jax_enable_x64=False 时,负数种子(作为 Python 整数传递)在 JIT 模式外现在返回不同的结果。例如,jax.random.PRNGKey(-1) 以前返回 [4294967295, 4294967295],现在返回 [0, 4294967295]。这与 JIT 中的行为匹配。

      • 在 JIT 外,超出 int64 可表示范围的种子现在会导致 OverflowError 而不是 TypeError。这与 JIT 中的行为匹配。

      要恢复 JIT 外 jax_enable_x64=False 时以前为负整数返回的键,可以使用

      key = random.PRNGKey(-1).at[0].set(0xFFFFFFFF)
      
    • DeviceArray 现在在尝试访问已被删除的值时引发 RuntimeError 而不是 ValueError

jaxlib 0.1.58 (2021 年 1 月 12 日左右)#

  • 修复了一个 bug,该 bug 导致 JAX 有时会返回特定于平台的类型(例如 np.cint)而不是标准类型(例如 np.int32)。(#4903)

  • 修复了在常量折叠某些 int16 操作时崩溃的问题。(#4971)

  • pytree.flatten() 中添加了一个 is_leaf predicate。

jaxlib 0.1.57 (2020 年 11 月 12 日)#

  • 修复了 GPU wheel 中的 manylinux2010 合规性问题。

  • 将 CPU FFT 实现从 Eigen 切换到 PocketFFT。

  • 修复了一个 bug,该 bug 导致 bfloat16 值的哈希未正确初始化并可能改变(#4651)。

  • 添加了在将数组传递给 DLPack 时保留所有权的支持(#4636)。

  • 修复了大小大于 128 但不是 128 倍数的批处理三角解的 bug。

  • 修复了在多个 GPU 上并发执行 FFT 时出现的 bug(#3518)。

  • 修复了分析器中工具缺失的 bug(#4427)。

  • 放弃了对 CUDA 10.0 的支持。

jax 0.2.5 (2020 年 10 月 27 日)#

jax 0.2.4 (2020 年 10 月 19 日)#

  • GitHub commits.

  • 改进

    • 为 jax.experimental.host_callback 添加了对 remat 的支持。请参阅 #4608

  • 弃用

    • 遵循 Numpy 中的类似弃用,使用非元组序列进行索引现在已被弃用。在未来的版本中,这将导致 TypeError。请参阅 #4564

jaxlib 0.1.56 (2020 年 10 月 14 日)#

jax 0.2.3 (2020 年 10 月 14 日)#

  • GitHub commits.

  • 之所以这么快发布另一个版本,是因为我们需要暂时回滚一个新的 jit 快速路径,同时我们正在调查性能下降问题。

jax 0.2.2 (2020 年 10 月 13 日)#

jax 0.2.1 (2020 年 10 月 6 日)#

  • GitHub commits.

  • 改进

    • 作为 omnistaging 的一个好处,即使 jax.experimental.host_callback.id_print()/jax.experimental.host_callback.id_tap() 的结果未在计算中使用,host_callback 函数也会(按程序顺序)执行。

jax (0.2.0) (2020 年 9 月 23 日)#

jax (0.1.77) (2020 年 9 月 15 日)#

  • 重大更改

    • jax.experimental.host_callback.id_tap() 的新简化接口(#4101)

jaxlib 0.1.55 (2020 年 9 月 8 日)#

  • 更新 XLA

    • 修复了 DLPackManagedTensorToBuffer 中的 bug(#4196)

jax 0.1.76 (2020 年 9 月 8 日)#

jax 0.1.75 (2020 年 7 月 30 日)#

  • GitHub commits.

  • Bug 修复

    • 使 jnp.abs() 可用于无符号输入(#3914)

  • 改进

    • “Omnistaging”行为通过标志启用,默认禁用(#3370)

jax 0.1.74 (2020 年 7 月 29 日)#

  • GitHub commits.

  • 新功能

    • BFGS(#3101)

    • TPU 支持半精度算术(#3878)

  • Bug 修复

    • 防止一些意外的 dtype 警告(#3874)

    • 修复了自定义导数中的多线程 bug(#3845,#3869)

  • 改进

    • 更快的 searchsorted 实现(#3873)

    • 改进了 jax.numpy 排序算法的测试覆盖率(#3836)

jaxlib 0.1.52 (2020 年 7 月 22 日)#

  • 更新 XLA。

jax 0.1.73 (2020 年 7 月 22 日)#

  • GitHub commits.

  • 最低 jaxlib 版本现为 0.1.51。

  • 新功能

    • jax.image.resize。(#3703)

    • hfft 和 ihfft(#3664)

    • jax.numpy.intersect1d(#3726)

    • jax.numpy.lexsort(#3812)

    • lax.scanscan 原语支持 unroll 参数,用于在降低到 XLA 时进行循环展开(#3738)。

  • Bug 修复

    • 修复了 reduction 重复轴的错误(#3618)

    • 修复了 lax.pad 对于尺寸为 0 的输入维度的形状规则。(#3608)

    • 使 psum 转置处理零余切(#3653)

    • 修复了在对大小为 0 的轴进行 reduce-prod 的 JVP 时出现的形状错误。(#3729)

    • 支持对 jax.lax.all_to_all 进行微分(#3733)

    • 解决了 jax.scipy.special.zeta 中的 nan 问题(#3777)

  • 改进

    • 对 jax2tf 进行大量改进

    • 使用单个传递的可变规约重新实现 argmin/argmax。(#3611)

    • 默认启用 XLA SPMD 分区。(#3151)

    • 添加了对 0d 转置卷积的支持(#3643)

    • 使 LU 梯度适用于低秩矩阵(#3610)

    • 支持 multiple_results 和 jet 中的自定义 JVP(#3657)

    • 将 reduce-window 填充泛化为支持 (lo, hi) 对。(#3728)

    • 在 CPU 和 GPU 上实现复数卷积。(#3735)

    • 使 jnp.take 在处理空数组的空切片时工作。(#3751)

    • 放宽了 dot_general 的维度顺序规则。(#3778)

    • 为 GPU 启用缓冲区捐赠。(#3800)

    • 将基值膨胀和窗口膨胀添加到 reduce window op…(#3803)

jaxlib 0.1.51 (2020 年 7 月 2 日)#

  • 更新 XLA。

  • 为 host_callback 添加了新的运行时支持。

jax 0.1.72 (2020 年 6 月 28 日)#

  • GitHub commits.

  • 错误修复

    • 修复了上一版本中引入的 odeint bug,请参阅 #3587

jax 0.1.71 (2020 年 6 月 25 日)#

  • GitHub commits.

  • 最低 jaxlib 版本现为 0.1.48。

  • 错误修复

    • 允许 jax.experimental.ode.odeint 动态函数在区分关于时间时关闭具有时间依赖性的值 #3562

jaxlib 0.1.50 (2020 年 6 月 25 日)#

  • 添加了对 CUDA 11.0 的支持。

  • 放弃了对 CUDA 9.2 的支持(我们只维护对最后四个 CUDA 版本的支持)。

  • 更新 XLA。

jaxlib 0.1.49 (2020 年 6 月 19 日)#

jaxlib 0.1.48 (2020 年 6 月 12 日)#

  • 新功能

    • 添加了对快速堆栈跟踪收集的支持。

    • 添加了对设备上堆分析的初步支持。

    • bfloat16 类型实现了 np.nextafter

    • CPU 和 GPU 上 FFT 的 Complex128 支持。

  • 错误修复

    • 提高了 GPU 上 float64 tanh 的精度。

    • GPU 上的 float64 散点图速度更快。

    • CPU 上的复数矩阵乘法速度应该快得多。

    • CPU 上的稳定排序现在应该是真正稳定的。

    • CPU 后端的并发 bug 修复。

jax 0.1.70 (2020 年 6 月 8 日)#

  • GitHub commits.

  • 新功能

    • lax.switch 引入了带有多个分支的索引条件,并对 cond 原语进行了泛化 #3318

jax 0.1.69 (2020 年 6 月 3 日)#

jax 0.1.68 (2020 年 5 月 21 日)#

  • GitHub commits.

  • 新功能

    • lax.cond() 支持单操作数形式,该形式被用作两个分支的参数 #2993

  • 值得注意的更改

    • jax.experimental.host_callback.id_tap() 原语的 transforms 关键字的格式已更改 #3132

jax 0.1.67 (2020 年 5 月 12 日)#

  • GitHub commits.

  • 新功能

    • 使用 axis_index_groups 对 pmapped 轴的子集进行规约的支持 #2382

    • 实验性支持从编译代码打印和调用主机端 Python 函数。请参阅 id_print 和 id_tap#3006)。

  • 值得注意的更改

    • jax.numpy 导出的名称的可访问性已得到加强。这可能会破坏使用了之前意外导出的名称的代码。

jaxlib 0.1.47 (2020 年 5 月 8 日)#

  • 修复了 outfeed 崩溃。

jax 0.1.66 (2020 年 5 月 5 日)#

jaxlib 0.1.46 (2020 年 5 月 5 日)#

  • 修复了在存在不同型号的多个 GPU 时,JAX 只编译适用于第一个 GPU 的程序的崩溃问题(#432)。

  • 修复了在使用操作系统或虚拟机管理程序禁用了 AVX512 指令时,由使用 AVX512 指令引起的非法指令崩溃。(#2906)

jax 0.1.65 (2020 年 4 月 30 日)#

  • GitHub commits.

  • 新功能

    • 奇数矩阵行列式的微分 #2809

  • 错误修复

    • 修复了具有时间依赖性动态的 ODEs 的 odeint() 微分,以及添加了 ODE CI 测试 #2817

    • 修复了 lax_linalg.qr() 的微分 #2867

jaxlib 0.1.45 (2020 年 4 月 21 日)#

  • 修复了段错误:#2755

  • is_stable 选项从 Sort HLO 传递到 Python。

jax 0.1.64 (2020 年 4 月 21 日)#

jaxlib 0.1.44 (2020 年 4 月 16 日)#

  • 修复了一个 bug,该 bug 导致如果存在多个不同型号的 GPU,JAX 只会编译适用于第一个 GPU 的程序。

  • batch_group_count 卷积的 bug 修复。

  • 为更多 GPU 版本添加了预编译的 SASS,以避免启动 PTX 编译挂起。

jax 0.1.63 (2020 年 4 月 12 日)#

  • GitHub commits.

  • 添加了 jax.custom_jvpjax.custom_vjp#2026),请参阅 教程 notebook。弃用了 jax.custom_transforms 并从文档中删除了它(尽管它仍然有效)。

  • 添加了 scipy.sparse.linalg.cg #2566

  • 更改了 Tracers 的打印方式,以显示更有用的调试信息 #2591

  • 使 jax.numpy.isclose 正确处理 naninf #2501

  • jax.experimental.jet 添加了几个新的规则 #2537

  • 修复了 jax.experimental.stax.BatchNorm 在未提供 scale/center 时的问题。

  • 修复了 jax.numpy.einsum 中一些缺失的广播情况 #2512

  • 使用并行前缀扫描实现了 jax.numpy.cumsumjax.numpy.cumprod #2596,并使 reduce_prod 可任意阶微分 #2597

  • batch_group_count 添加到 conv_general_dilated #2635

  • test_util.check_grads 添加了文档字符串 #2656

  • 添加了 callback_transform #2665

  • 实现了 rollaxisconvolve/correlate 1d & 2d、copysigntruncroots 以及 quantile/percentile 插值选项。

jaxlib 0.1.43 (2020 年 3 月 31 日)#

  • 修复了 GPU 上 Resnet-50 的性能回归。

jax 0.1.62 (2020 年 3 月 21 日)#

  • GitHub commits.

  • JAX 已放弃对 Python 3.5 的支持。请升级到 Python 3.6 或更高版本。

  • 删除了内部函数 lax._safe_mul,该函数实现了 0. * nan == 0. 的约定。此更改意味着某些程序在微分时会产生 nan,而以前会产生正确的值,但它确保了其他程序不会产生静默不正确的结果,而是会产生 nan。有关详细信息,请参阅 #2447 和 #1052。

  • 添加了一个 all_gather 并行便利函数。

  • 核心代码中的更多类型注解。

jaxlib 0.1.42 (2020 年 3 月 19 日)#

  • jaxlib 0.1.41 由于 API 不兼容而破坏了云 TPU 支持。此版本已重新修复。

  • JAX 已放弃对 Python 3.5 的支持。请升级到 Python 3.6 或更高版本。

jax 0.1.61 (2020 年 3 月 17 日)#

  • GitHub commits.

  • 修复了 Python 3.5 支持。这将是最后一个支持 Python 3.5 的 JAX 或 jaxlib 版本。

jax 0.1.60 (2020 年 3 月 17 日)#

  • GitHub commits.

  • 新功能

    • jax.pmap() 具有 static_broadcast_argnums 参数,该参数允许用户指定应视为编译时常量并应广播到所有设备的参数。它的工作方式类似于 jax.jit() 中的 static_argnums

    • 改进了当 Tracers 错误地保存在全局状态时出现的错误消息。

    • 添加了 jax.nn.one_hot() 实用函数。

    • 添加了 jax.experimental.jet 以实现指数级更快的二阶及更高阶自动微分。

    • jax.lax.broadcast_in_dim() 的参数添加了更多正确性检查。

  • 最低 jaxlib 版本现为 0.1.41。

jaxlib 0.1.40 (2020 年 3 月 4 日)#

  • 在 Jaxlib 中添加了对 TensorFlow 分析器的实验性支持,该支持允许从 TensorBoard 跟踪 CPU 和 GPU 计算。

  • 包括了多主机 GPU 计算(通过 NCCL 通信)的原型支持。

  • 提高了 GPU 上 NCCL 集合通信的性能。

  • 添加了 TopK、CustomCallWithoutLayout、CustomCallWithLayout、IGammaGradA 和 RandomGamma 实现。

  • 支持在 XLA 编译时已知的设备分配。

jax 0.1.59 (2020 年 2 月 11 日)#

  • GitHub commits.

  • 重大更改

    • 最低 jaxlib 版本现为 0.1.38。

    • 简化了 Jaxpr,移除了 Jaxpr.freevarsJaxpr.bound_subjaxprs。调用原语(xla_callxla_pmapsharded_callremat_call)获得了一个新参数 call_jaxpr,其中包含一个完全闭合的(无 constvars)jaxpr。此外,在原语中添加了一个新字段 call_primitive

  • 新功能

    • 反向模式自动微分(例如 grad)对 lax.cond 的微分,使其现在可以双向微分(#2091

    • JAX 现在支持 DLPack,这允许与 PyTorch 等其他库以零拷贝方式共享 CPU 和 GPU 数组。

    • JAX GPU DeviceArrays 现在支持 __cuda_array_interface__,这是另一种与 CuPy 和 Numba 等其他库共享 GPU 数组的零拷贝协议。

    • JAX CPU 设备缓冲区现在实现了 Python 缓冲区协议,这允许 JAX 和 NumPy 之间进行零拷贝缓冲区共享。

    • 添加了 JAX_SKIP_SLOW_TESTS 环境变量以跳过已知较慢的测试。

jaxlib 0.1.39 (2020 年 2 月 11 日)#

  • 更新 XLA。

jaxlib 0.1.38 (2020 年 1 月 29 日)#

  • 不再支持 CUDA 9.0。

  • 默认情况下现在构建 CUDA 10.2 wheel。

jax 0.1.58 (2020 年 1 月 28 日)#

值得注意的 bug 修复#

  • 通过升级到 Python 3,JAX 不再依赖于 fastcache,这应该有助于安装。