变更日志#
最好 在此处 查看。 有关特定于实验性 Pallas API 的更改,请参阅 Pallas 变更日志。
JAX 遵循基于努力的版本控制; 有关此内容和 JAX 的 API 兼容性策略的讨论,请参阅 API 兼容性。 有关 Python 和 NumPy 版本支持策略,请参阅 Python 和 NumPy 版本支持策略。
未发布#
新功能
添加了
jax.lax.axis_size()
,它返回给定名称的映射轴的大小。
JAX 0.6.0 (2025 年 4 月 16 日)#
重大更改
jax.numpy.array()
不再接受None
。 此行为自 2023 年 11 月起已弃用,现已删除。删除了
config.jax_data_dependent_tracing_fallback
配置选项,该选项在 v0.4.36 中临时添加,以允许用户选择退出新的“无堆栈”跟踪机制。删除了
config.jax_eager_pmap
配置选项。禁止在应用后续包装器后,在
jax.jit
的结果上调用lower
和trace
AOT API。 以前这可以工作,但会默默地忽略包装器。 解决方法是在包装器中最后应用jax.jit
,对于jax.pmap
也是如此。 请参阅 #27873。jax
的cuda12_pip
extra 已删除; 请改用pip install jax[cuda12]
。
更改
最低 CuDNN 版本为 v9.8。
JAX 现在使用 CUDA 12.8 构建。 所有 CUDA 12.1 或更新版本仍然受支持。
JAX 包 extras 现在已更新为使用短划线而不是下划线,以与 PEP 685 对齐。 例如,如果您之前使用
pip install jax[cuda12_local]
安装 JAX,请运行pip install jax[cuda12-local]
代替。jax.jit()
现在要求通过位置传递fun
,并通过关键字传递其他参数。 否则,将在 v0.6.X 中导致 DeprecationWarning,并在 v0.7.X 中开始导致错误。
弃用
为 CPU 和 GPU 设备实现了使用 XLA 的 FFI 的主机回调处理程序,并删除了使用 XLA 自定义调用的现有 CPU/GPU 处理程序。
jax.lib.xla_extension
中的所有 API 现在都已弃用。jax.interpreters.mlir.hlo
和jax.interpreters.mlir.func_dialect
是意外导出,已被删除。 如果需要,它们可以从jax.extend.mlir
中获得。jax.interpreters.mlir.custom_call
已弃用。 应该使用jax.ffi
提供的 API。不再支持弃用的将
jax.ffi.ffi_call()
与内联参数一起使用。ffi_call()
现在无条件返回一个可调用对象。jax.lib.xla_client
中的以下导出已弃用:get_topology_for_devices
、heap_profile
、mlir_api_version
、Client
、CompileOptions
、DeviceAssignment
、Frame
、HloSharding
、OpSharding
、Traceback
。jax.util
中的以下内部 API 已弃用:HashableFunction
、as_hashable_function
、cache
、safe_map
、safe_zip
、split_dict
、split_list
、split_list_checked
、split_merge
、subvals
、toposort
、unzip2
、wrap_name
和wraps
。jax.dlpack.to_dlpack
已弃用。 您通常可以将 JAXArray
直接传递到另一个框架的from_dlpack
函数。 如果您需要to_dlpack
的功能,请使用数组的__dlpack__
属性。jax.lax.infeed
、jax.lax.infeed_p
、jax.lax.outfeed
和jax.lax.outfeed_p
已弃用,将在 JAX v0.7.0 中删除。几个先前已弃用的 API 已被删除,包括
来自
jax.lib.xla_client
:ArrayImpl
、FftType
、PaddingType
、PrimitiveType
、XlaBuilder
、dtype_to_etype
、ops
、register_custom_call_target
、shape_from_pyval
、Shape
、XlaComputation
。来自
jax.lib.xla_extension
:ArrayImpl
、XlaRuntimeError
。来自
jax
:jax.treedef_is_leaf
、jax.tree_flatten
、jax.tree_map
、jax.tree_leaves
、jax.tree_structure
、jax.tree_transpose
和jax.tree_unflatten
。 可以在jax.tree
或jax.tree_util
中找到替代项。来自
jax.core
:AxisSize
、ClosedJaxpr
、EvalTrace
、InDBIdx
、InputType
、Jaxpr
、JaxprEqn
、Literal
、MapPrimitive
、OpaqueTraceState
、OutDBIdx
、Primitive
、Token
、TRACER_LEAK_DEBUGGER_WARNING
、Var
、concrete_aval
、dedup_referents
、escaped_tracer_error
、extend_axis_env_nd
、full_lower
、get_referent
、jaxpr_as_fun
、join_effects
、lattice_join
、leaked_tracer_error
、maybe_find_leaked_tracers
、raise_to_shaped
、raise_to_shaped_mappings
、reset_trace_state
、str_eqn_compact
、substitute_vars_in_output_ty
、typecompat
和used_axis_names_jaxpr
。 大多数都没有公开的替代品,但少数可在jax.extend.core
中找到。pure_callback()
和ffi_call()
的vectorized
参数。 请改用vmap_method
参数。
jax 0.5.3 (2025 年 3 月 19 日)#
新功能
为
jax.lax.dynamic_slice()
、jax.lax.dynamic_update_slice()
和相关函数添加了allow_negative_indices
选项。 默认值为 true,与当前行为匹配。 如果设置为 false,则 JAX 不需要发出钳制负索引的代码,从而减小代码大小。为
jax.random.categorical()
添加了replace
选项,以启用无放回抽样。
jax 0.5.2 (2025 年 3 月 4 日)#
0.5.1 的补丁版本
错误修复
修复了 TPU 指标日志记录和
tpu-info
,这些在 0.5.1 中已损坏
jax 0.5.1 (2025 年 2 月 24 日)#
新功能
添加了一个实验性的
jax.experimental.custom_dce.custom_dce()
装饰器,以支持在 JAX 级别的死代码消除 (DCE) 下自定义不透明函数的行为。 有关更多详细信息,请参阅 #25956。在
jax.lax
中添加了底层归约 API:jax.lax.reduce_sum()
、jax.lax.reduce_prod()
、jax.lax.reduce_max()
、jax.lax.reduce_min()
、jax.lax.reduce_and()
、jax.lax.reduce_or()
和jax.lax.reduce_xor()
。jax.lax.linalg.qr()
和jax.scipy.linalg.qr()
现在支持在 CPU 和 GPU 上进行列主元 QR 分解。 有关更多信息,请参阅 #20282 和添加了
jax.random.multinomial()
。 有关更多详细信息,请参阅 #25955。
更改
JAX_CPU_COLLECTIVES_IMPLEMENTATION
和JAX_NUM_CPU_DEVICES
现在可以作为环境变量使用。 以前它们只能通过 jax.config 或标志指定。JAX_CPU_COLLECTIVES_IMPLEMENTATION
现在默认为'gloo'
,这意味着多进程 CPU 通信可以开箱即用。jax[tpu]
TPU 扩展不再依赖libtpu-nightly
包。 如果您的机器上存在此包,可以安全地将其删除;JAX 现在改用libtpu
。
弃用
内部函数
linear_util.wrap_init
和构造函数core.Jaxpr
现在必须接受非空的core.DebugInfo
kwarg。 在有限的时间内,如果使用不带调试信息的jax.extend.linear_util.wrap_init
,则会打印DeprecationWarning
。 此更改的下游影响是其他几个内部函数需要调试信息。 此更改不影响公共 API。 有关更多详细信息,请参阅 https://github.com/jax-ml/jax/issues/26480。在
jax.numpy.ndim()
、jax.numpy.shape()
和jax.numpy.size()
中,非类数组输入(例如列表、元组等)现在已弃用。
错误修复
TPU 运行时启动和关闭时间在 TPU v5e 及更高版本上应得到显着改善(从大约 17 秒到大约 8 秒)。 如果尚未设置,您可能需要在 VM 映像中启用透明大页 (
sudo sh -c 'echo always > /sys/kernel/mm/transparent_hugepage/enabled'
)。 我们希望在未来的版本中进一步改进这一点。如果未设置或设置为 -1 (即如果未启用 LRU 驱逐策略),则持久编译缓存不再写入访问时间文件。 这应该可以提高在使用具有大规模网络存储的缓存时的性能。
jax 0.5.0 (2025 年 1 月 17 日)#
自此版本起,JAX 现在使用 基于努力的版本控制。 由于此版本对 PRNG 密钥语义进行了重大更改,可能需要用户更新其代码,因此我们提高了 JAX 的“meso”版本以表示这一点。
重大更改
默认启用
jax_threefry_partitionable
(请参阅更新说明)。此版本放弃了对 Mac x86 轮子的支持。 Mac ARM 当然仍然受支持。 有关最近的讨论,请参阅 https://github.com/jax-ml/jax/discussions/22936。
促成此决定的两个关键因素
Mac x86 版本(仅限)有许多测试失败和崩溃。 我们宁愿不发布版本,也不愿发布损坏的版本。
Mac x86 硬件已停产,并且目前开发人员无法轻易获得。 因此,即使我们想这样做,也很难解决此类问题。
如果社区愿意帮助支持该平台,我们愿意重新添加对 Mac x86 的支持:特别是,我们需要 JAX 测试套件在 Mac x86 上干净地通过,然后我们才能再次发布版本。
更改
最低 NumPy 版本现在为 1.25。 NumPy 1.25 将保持最低支持版本,直到 2025 年 6 月。
最低 SciPy 版本现在为 1.11。 SciPy 1.11 将保持最低支持版本,直到 2025 年 6 月。
jax.numpy.einsum()
现在默认为optimize='auto'
而不是optimize='optimal'
。 这避免了在许多参数的情况下指数级扩展的跟踪时间 (#25214)。jax.numpy.linalg.solve()
不再支持右手侧的批量 1D 参数。 要在这些情况下恢复以前的行为,请使用solve(a, b[..., None]).squeeze(-1)
。
新功能
jax.numpy.fft.fftn()
、jax.numpy.fft.rfftn()
、jax.numpy.fft.ifftn()
和jax.numpy.fft.irfftn()
现在支持超过 3 个维度的变换,这以前是限制。 有关更多详细信息,请参阅 #25606。通过新的
jax.ffi.register_ffi_type_id()
函数,为 FFI 中的用户定义状态添加了支持。AOT 降低
.as_text()
方法现在支持debug_info
选项,以在输出中包含调试信息,例如源位置。
弃用
来自
jax.interpreters.xla
,abstractify
和pytype_aval_mappings
现在已弃用,已被jax.core
中同名的符号替换。jax.scipy.special.lpmn()
和jax.scipy.special.lpmn_values()
已弃用,原因是它们在 SciPy v1.15.0 中已弃用。 没有计划用新的 API 替换这些已弃用的函数。jax.extend.ffi
子模块已移动到jax.ffi
,以前的导入路径已弃用。
删除
jax_enable_memories
标志已删除,该标志的行为默认处于启用状态。来自
jax.lib.xla_client
,以前已弃用的Device
和XlaRuntimeError
符号已被删除;请改用jax.Device
和jax.errors.JaxRuntimeError
。在 JAX v0.4.32 中弃用后,
jax.experimental.array_api
模块已被删除。 自该版本以来,jax.numpy
直接支持 array API。
jax 0.4.38 (2024 年 12 月 17 日)#
重大更改
XlaExecutable.cost_analysis
现在返回dict[str, float]
(而不是单元素list[dict[str, float]]
)。
更改
添加了
jax.tree.flatten_with_path
和jax.tree.map_with_path
作为相应tree_util
函数的快捷方式。
弃用
内部
jax.core
命名空间中的许多 API 已被弃用。 大多数都是空操作,很少使用,或者可以用jax.extend.core
中同名的 API 替换;有关这些半公开扩展的兼容性保证的信息,请参阅jax.extend
的文档。几个先前已弃用的 API 已被删除,包括
来自
jax.core
:check_eqn
、check_type
、check_valid_jaxtype
和non_negative_dim
。来自
jax.lib.xla_bridge
:xla_client
和default_backend
。来自
jax.lib.xla_client
:_xla
和bfloat16
。来自
jax.numpy
:round_
。
新功能
jax.export.export()
可用于设备多态导出,分片使用jax.sharding.AbstractMesh()
构建。 请参阅 jax.export 文档。添加了
jax.lax.split()
。 这是jax.numpy.split()
的原始版本,添加的原因是它在自动微分期间产生更紧凑的转置。
jax 0.4.37 (2024 年 12 月 9 日)#
这是 jax 0.4.36 的补丁版本。 此版本仅发布了 “jax”。
错误修复
修复了如果参数名为
f
(#25329),jit
会出错的错误。修复了一个错误,如果用户为 flatten 和 flatten_with_path 注册了具有不同辅助数据的 pytree 节点类,则会在
jax.lax.while_loop()
中抛出index out of range
错误。固定了一个新的 libtpu 版本 (0.0.6),该版本修复了 TPU v6e 上的编译器错误。
jax 0.4.36 (2024 年 12 月 5 日)#
重大更改
此版本引入了 “stackless”,这是 JAX 跟踪机制的内部更改。 我们使跟踪调度纯粹成为上下文的函数,而不是上下文和数据的函数。 这使我们删除了大量用于管理数据相关跟踪的机制:级别、子级别、
post_process_call
、new_base_main
、custom_bind
等。 此更改应该只影响使用 JAX 内部构件的用户。如果您确实使用了 JAX 内部构件,则可能需要更新代码(有关如何执行此操作的线索,请参阅 https://github.com/jax-ml/jax/commit/c36e1f7c1ad4782060cbc8e8c596d85dfb83986f)。 使用 JAX 库执行此操作时,也可能存在版本偏差问题。 如果您发现此更改破坏了未使用 JAX 内部构件的代码,请尝试使用
config.jax_data_dependent_tracing_fallback
标志作为解决方法,如果您需要帮助更新代码,请提交错误报告。自 2024 年 7 月起,
jax.experimental.jax2tf.convert()
与native_serialization=False
或enable_xla=False
已弃用,JAX 版本为 0.4.31。 现在我们取消了对这些用例的支持。 仍然支持使用本机序列化的jax2tf
。在
jax.interpreters.xla
中,在 JAX v0.4.31 中弃用后,xb
、xc
和xe
符号已被删除。 请改用xb = jax.lib.xla_bridge
、xc = jax.lib.xla_client
和xe = jax.lib.xla_extension
。已删除已弃用的模块
jax.experimental.export
。 它已在 JAX v0.4.30 中被jax.export
替换。 有关迁移到新 API 的信息,请参阅 迁移指南。在 v0.4.27 中弃用后,
jax.nn.softmax()
和jax.nn.log_softmax()
的initial
参数已被删除。现在,在类型化的 PRNG 密钥(即 :func:
jax.random.key
生成的密钥)上调用np.asarray
会引发错误。 以前,这会返回标量对象数组。jax.export
中以下已弃用的方法和函数已被删除jax.export.DisabledSafetyCheck.shape_assertions
:它已经没有效果。jax.export.Exported.lowering_platforms
:使用platforms
。jax.export.Exported.mlir_module_serialization_version
:使用calling_convention_version
。jax.export.Exported.uses_shape_polymorphism
:使用uses_global_constants
。jax.export.export()
的lowering_platforms
kwarg:请改用platforms
。
已删除
jax.export.symbolic_args_specs()
中的 kwargssymbolic_scope
和symbolic_constraints
。 它们已在 2024 年 6 月弃用。 请改用scope
和constraints
。自 0.4.30 版本以来已弃用的 tracer 哈希处理现在会导致
TypeError
。重构:JAX 构建 CLI (build/build.py) 现在使用子命令结构并替换以前的 build.py 用法。 运行
python build/build.py --help
了解更多详细信息。 新子命令选项的简要概述build
:构建 JAX 轮子包。 例如,python build/build.py build --wheels=jaxlib,jax-cuda-pjrt
requirements_update
:更新 requirements_lock.txt 文件。
jax.scipy.linalg.toeplitz()
现在对多维输入进行隐式批处理。 要恢复以前的行为,您可以在函数输入上调用jax.numpy.ravel()
。jax.scipy.special.gamma()
和jax.scipy.special.gammasgn()
现在为负整数输入返回 NaN,以匹配来自 https://github.com/scipy/scipy/pull/21827 的 SciPy 的行为。在 v0.4.26 中弃用后,
jax.clear_backends
已删除。我们从我们保证导出稳定性的自定义调用列表中删除了自定义调用 “__gpu$xla.gpu.triton”。 这是因为此自定义调用依赖于 Triton IR,而 Triton IR 不保证稳定。 如果您需要导出使用此自定义调用的代码,可以使用
disabled_checks
参数。 有关更多详细信息,请参阅文档。
新功能
jax.jit()
获取了一个新的compiler_options: dict[str, Any]
参数,用于将编译选项传递给 XLA。 目前,它尚未记录在案,并且可能处于变动之中。jax.tree_util.register_dataclass()
现在允许通过dataclasses.field()
内联声明元数据字段。 有关示例,请参阅函数文档。GPU 现在支持
jax.lax.linalg.eig()
和相关的jax.numpy
函数(jax.numpy.linalg.eig()
和jax.numpy.linalg.eigvals()
)。 有关更多详细信息,请参阅 #24663。添加了两个新的配置标志
jax_exec_time_optimization_effort
和jax_memory_fitting_effort
,以控制编译器在最大限度地减少执行时间和内存使用方面所花费的精力。 有效值介于 -1.0 和 1.0 之间,默认值为 0.0。
错误修复
修复了 LU 和 QR 分解的 GPU 实现会导致接近 int32 最大值的批量大小的索引溢出的错误。 有关更多详细信息,请参阅 #24843。
弃用
jax.lib.xla_extension.ArrayImpl
和jax.lib.xla_client.ArrayImpl
已弃用;请改用jax.Array
。jax.lib.xla_extension.XlaRuntimeError
已弃用;请改用jax.errors.JaxRuntimeError
。
jax 0.4.35 (2024 年 10 月 22 日)#
重大更改
现在,对于任何零维类数组对象,
jax.numpy.isscalar()
都返回 True。 以前,它仅对具有弱 dtype 的零维类数组对象返回 True。自 2024 年 3 月起,
jax.experimental.host_callback
已弃用,JAX 版本为 0.4.26。 现在我们已将其删除。 有关替代方案的讨论,请参阅 #20385。
更改
jax.lax.FftType
已作为 FFT 操作枚举的公共名称引入。 半公开 APIjax.lib.xla_client.FftType
已弃用。TPU:JAX 现在从
libtpu
包而不是libtpu-nightly
安装 TPU 支持。 在接下来的几个版本中,JAX 将固定libtpu-nightly
的空版本以及libtpu
,以简化过渡;该依赖项将在 2025 年第一季度删除。
弃用
半公开 API
jax.lib.xla_client.PaddingType
已弃用。 没有 JAX API 使用此类型,因此没有替代品。vmap
下jax.pure_callback()
和jax.extend.ffi.ffi_call()
的默认行为已弃用,这些函数的vectorized
参数也已弃用。 应改用vmap_method
参数以获得更好的定义行为。 有关更多详细信息,请参阅 #23881 中的讨论。半公开 API
jax.lib.xla_client.register_custom_call_target
已弃用。 请改用 JAX FFI。半公开 API
jax.lib.xla_client.dtype_to_etype
、jax.lib.xla_client.ops
、jax.lib.xla_client.shape_from_pyval
、jax.lib.xla_client.PrimitiveType
、jax.lib.xla_client.Shape
、jax.lib.xla_client.XlaBuilder
和jax.lib.xla_client.XlaComputation
已弃用。 请改用 StableHLO。
jax 0.4.34 (2024 年 10 月 4 日)#
新功能
此版本包括适用于 Python 3.13 的轮子。 尚不支持自由线程模式。
已添加
jax.errors.JaxRuntimeError
作为以前私有的XlaRuntimeError
类型的公共别名。
重大更改
jax_pmap_no_rank_reduction
标志默认设置为True
。pmap 结果上的 array[0] 现在引入了 reshape(请改用 array[0:1])。
每个分片形状(可通过 jax_array.addressable_shards 或 jax_array.addressable_data(0) 访问)现在具有前导 (1, …)。 相应地更新直接访问分片的代码。 每个分片形状的秩现在与全局形状的秩匹配,这与 jit 的行为相同。 这避免了将 pmap 的结果传递到 jit 时产生高昂的 reshape 开销。
jax.experimental.host_callback
自 2024 年 3 月起已弃用,始于 JAX 版本 0.4.26。现在,我们将--jax_host_callback_legacy
配置值的默认值设置为True
,这意味着如果您的代码使用jax.experimental.host_callback
API,这些 API 调用将通过新的jax.experimental.io_callback
API 实现。如果这破坏了您的代码,在非常有限的时间内,您可以将--jax_host_callback_legacy
设置为True
。我们很快将移除该配置选项,因此您应该转而使用新的 JAX 回调 API。请参阅 #20385 以获取更多讨论。
弃用
在
jax.numpy.trim_zeros()
中,非类数组参数或ndim != 1
的类数组参数现已弃用,将来会报错。内部美化打印工具
jax.core.pp_*
已被移除,此前已在 JAX v0.4.30 中弃用。jax.lib.xla_client.Device
已弃用;请改用jax.Device
。jax.lib.xla_client.XlaRuntimeError
已弃用。请改用jax.errors.JaxRuntimeError
。
删除
jax.xla_computation
已删除。自 0.4.30 JAX 版本中弃用以来已过去 3 个月。请使用 AOT API 以获得与jax.xla_computation
相同的功能。jax.xla_computation(fn)(*args, **kwargs)
可以替换为jax.jit(fn).lower(*args, **kwargs).compiler_ir('hlo')
。您还可以使用
jax.stages.Lowered
的.out_info
属性来获取输出信息(如树结构、形状和 dtype)。对于跨后端 lowering,您可以将
jax.xla_computation(fn, backend='tpu')(*args, **kwargs)
替换为jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo')
。
jax.ShapeDtypeStruct
不再接受named_shape
参数。该参数仅供在 0.4.31 中移除的xmap
使用。jax.tree.map(f, None, non-None)
之前会发出DeprecationWarning
,现在在未来版本的 jax 中会引发错误。None
只是其自身的树前缀。要保留当前行为,您可以要求jax.tree.map
将None
视为叶值,方法是编写:jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None)
。jax.sharding.XLACompatibleSharding
已移除。请使用jax.sharding.Sharding
。
错误修复
修复了当提供非布尔输入并指定
dtype=bool
时,jax.numpy.cumsum()
会产生不正确输出的错误。编辑
jax.numpy.ldexp()
的实现以获得正确的梯度。
jax 0.4.33(2024 年 9 月 16 日)#
这是 jax 0.4.32 之上的一个补丁版本,修复了该版本中发现的两个错误。
在 JAX 0.4.32 固定的 libtpu 版本中发现了一个仅限 TPU 的数据损坏错误,该错误仅在同一作业中存在多个 TPU 切片时才会显现,例如,在多个 v5e 切片上进行训练时。此版本通过固定 libtpu
的固定版本修复了该问题。
此版本修复了 CPU 上 F64 tanh 的不准确结果 (#23590)。
jax 0.4.32(2024 年 9 月 11 日)#
注意:此版本已从 PyPi 中撤回,原因是 TPU 上存在数据损坏错误。有关更多详细信息,请参阅 0.4.33 发行说明。
新功能
添加了
jax.extend.ffi.ffi_call()
和jax.extend.ffi.ffi_lowering()
,以支持使用新的 外部函数接口 (FFI),以便从 JAX 与自定义 C++ 和 CUDA 代码进行交互。
更改
jax_enable_memories
标志默认设置为True
。jax.numpy
现在支持 Python Array API Standard 的 v2023.12 版本。有关更多信息,请参阅 Python Array API 标准。现在,CPU 后端上的计算可以在更多情况下异步分派。以前,非并行计算始终是同步分派的。您可以通过设置
jax.config.update('jax_cpu_enable_async_dispatch', False)
来恢复旧的行为。添加了新的
jax.process_indices()
函数,以替换在 JAX v0.2.13 中已弃用的jax.host_ids()
函数。为了与
numpy.fabs
的行为保持一致,jax.numpy.fabs
已被修改为不再支持complex dtypes
。jax.tree_util.register_dataclass
现在检查data_fields
和meta_fields
是否包含所有init=True
的 dataclass 字段,并且仅包含这些字段(如果nodetype
是 dataclass)。多个
jax.numpy
函数现在具有完整的ufunc
接口,包括add
、multiply
、bitwise_and
、bitwise_or
、bitwise_xor
、logical_and
、logical_and
和logical_and
。在未来的版本中,我们计划将这些扩展到其他 ufunc。添加了
jax.lax.optimization_barrier()
,它允许用户阻止编译器优化(如公共子表达式消除)并控制调度。
重大更改
MHLO MLIR 方言 (
jax.extend.mlir.mhlo
) 已移除。请改用stablehlo
方言。
弃用
不再允许将复数输入传递给
jax.numpy.clip()
和jax.numpy.hypot()
,自 JAX v0.4.27 起已弃用。已弃用以下 API
jax.lib.xla_bridge.xla_client
:请直接使用jax.lib.xla_client
。jax.lib.xla_bridge.get_backend
:请使用jax.extend.backend.get_backend()
。jax.lib.xla_bridge.default_backend
:请使用jax.extend.backend.default_backend()
。
已弃用
jax.experimental.array_api
模块,不再需要导入它即可使用 Array API。jax.numpy
直接支持 array API;有关更多信息,请参阅 Python Array API 标准。内部实用程序
jax.core.check_eqn
、jax.core.check_type
和jax.core.check_valid_jaxtype
现已弃用,将来将被移除。已弃用
jax.numpy.round_
,原因是在 NumPy 2.0 中移除了相应的 API。请改用jax.numpy.round()
。不推荐将 DLPack capsule 传递给
jax.dlpack.from_dlpack()
。jax.dlpack.from_dlpack()
的参数应为来自另一个框架的数组,该框架实现了__dlpack__
协议。
jaxlib 0.4.32(2024 年 9 月 11 日)#
注意:此版本已从 PyPi 中撤回,原因是 TPU 上存在数据损坏错误。有关更多详细信息,请参阅 0.4.33 发行说明。
重大更改
此版本的 jaxlib 切换到了 CPU 后端的新版本,该版本应编译速度更快,并能更好地利用并行性。如果您因这一更改而遇到任何问题,可以暂时通过设置环境变量
XLA_FLAGS=--xla_cpu_use_thunk_runtime=false
来启用旧的 CPU 后端。如果您需要这样做,请提交一个 JAX 错误报告,并附上重现说明。添加了 Hermetic CUDA 支持。Hermetic CUDA 使用特定的可下载 CUDA 版本,而不是用户本地安装的 CUDA。Bazel 将下载 CUDA、CUDNN 和 NCCL 发行版,然后在各种 Bazel 目标中将 CUDA 库和工具用作依赖项。这为 JAX 及其支持的 CUDA 版本实现了更可重现的构建。
更改
添加了 SparseCore 性能分析。
JAX 现在支持在 TPUv5p 芯片上分析 SparseCore 的性能。这些跟踪信息将在 Tensorboard Profiler 的 TraceViewer 中查看。
jax 0.4.31(2024 年 7 月 29 日)#
删除
xmap 已删除。请使用
shard_map()
作为替代。
更改
最低 CuDNN 版本为 v9.1。之前的版本也是如此,但我们现在正式声明此版本约束。
最低 Python 版本现在为 3.10。3.10 将在 2025 年 7 月之前保持最低支持版本。
最低 NumPy 版本现在为 1.24。NumPy 1.24 将在 2024 年 12 月之前保持最低支持版本。
最低 SciPy 版本现在为 1.10。SciPy 1.10 将在 2025 年 1 月之前保持最低支持版本。
jax.numpy.ceil()
、jax.numpy.floor()
和jax.numpy.trunc()
现在返回与输入相同 dtype 的输出,即不再将整数或布尔输入向上转换为浮点数。libdevice.10.bc
不再与 CUDA wheel 捆绑在一起。它必须作为本地 CUDA 安装的一部分安装,或者通过 NVIDIA 的 CUDA pip wheel 安装。jax.experimental.pallas.BlockSpec
现在期望在index_map
之前 传递block_shape
。旧的参数顺序已弃用,将在未来的版本中移除。更新了 gpu 设备的 repr,使其与 TPU/CPU 更一致。例如,
cuda(id=0)
现在将变为CudaDevice(id=0)
。为
jax.Array
添加了device
属性和to_device
方法,作为 JAX Array API 支持的一部分。
弃用
移除了一些先前已弃用的与多态形状相关的内部 API。从
jax.core
中:移除了canonicalize_shape
、dimension_as_value
、definitely_equal
和symbolic_equal_dim
。HLO lowering 规则不应再将单例 ir.Values 包装在元组中。而是返回未包装的单例 ir.Values。对包装值的支持将在未来版本的 JAX 中移除。
jax.experimental.jax2tf.convert()
与native_serialization=False
或enable_xla=False
现已弃用,此支持将在未来版本中移除。自 JAX 0.4.16(2023 年 9 月)以来,原生序列化一直是默认设置。先前已弃用的函数
jax.random.shuffle
已移除;请改用jax.random.permutation
和independent=True
。
jaxlib 0.4.31(2024 年 7 月 29 日)#
错误修复
修复了一个错误,该错误导致 jit 调度快速路径错误处理了 jit 的负 static_argnums。
修复了一个错误,该错误导致奇异矩阵批次的三角解产生无意义的有限值,而不是 inf 或 nan (#3589, #15429)。
jax 0.4.30(2024 年 6 月 18 日)#
更改
JAX 支持 ml_dtypes >= 0.2。在 0.4.29 版本中,ml_dtypes 版本已提升至 0.4.0,但在此版本中已回滚,以便为 TensorFlow 和 JAX 用户提供更多时间来迁移到较新的 TensorFlow 版本。
jax.experimental.mesh_utils
现在可以为 TPU v5e 创建高效网格。jax 现在直接依赖于 jaxlib。此更改由 CUDA 插件切换启用:不再有多个 jaxlib 变体。您可以使用
pip install jax
安装仅 CPU 的 jax,无需任何额外功能。添加了用于导出和序列化 JAX 函数的 API。这以前存在于
jax.experimental.export
(即将弃用)中,现在将位于jax.export
中。请参阅文档。
弃用
内部美化打印工具
jax.core.pp_*
已弃用,将在未来的版本中移除。tracer 的哈希已弃用,将在未来的 JAX 版本中导致
TypeError
。以前是这种情况,但在最近的几个 JAX 版本中出现了意外的回归。jax.experimental.export
已弃用。请改用jax.export
。请参阅迁移指南。在大多数情况下,现在不推荐使用数组代替 dtype;例如,对于数组
x
和y
,x.astype(y)
将引发警告。要消除警告,请使用x.astype(y.dtype)
。jax.xla_computation
已弃用,将在未来的版本中移除。请使用 AOT API 以获得与jax.xla_computation
相同的功能。jax.xla_computation(fn)(*args, **kwargs)
可以替换为jax.jit(fn).lower(*args, **kwargs).compiler_ir('hlo')
。您还可以使用
jax.stages.Lowered
的.out_info
属性来获取输出信息(如树结构、形状和 dtype)。对于跨后端 lowering,您可以将
jax.xla_computation(fn, backend='tpu')(*args, **kwargs)
替换为jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo')
。
jaxlib 0.4.30(2024 年 6 月 18 日)#
已停止支持单体 CUDA jaxlib。您必须使用基于插件的安装(
pip install jax[cuda12]
或pip install jax[cuda12_local]
)。
jax 0.4.29(2024 年 6 月 10 日)#
更改
我们预计这将是 JAX 和 jaxlib 支持单体 CUDA jaxlib 的最后一个版本。未来的版本将使用 CUDA 插件 jaxlib(例如
pip install jax[cuda12]
)。JAX 现在需要 ml_dtypes 版本 0.4.0 或更高版本。
移除了对
jax.experimental.export
API 旧用法的向后兼容性支持。不再可能使用from jax.experimental.export import export
,而应使用from jax.experimental import export
。移除的功能自 0.4.24 起已弃用。为
jax.tree.all()
&jax.tree_util.tree_all()
添加了is_leaf
参数。
弃用
jax.sharding.XLACompatibleSharding
已弃用。请使用jax.sharding.Sharding
。jax.experimental.Exported.in_shardings
已重命名为jax.experimental.Exported.in_shardings_hlo
。out_shardings
也是如此。旧名称将在 3 个月后移除。移除了一些先前已弃用的 API
jax.numpy.linalg.matrix_rank()
的tol
参数已被弃用,即将移除。请改用rtol
。jax.numpy.linalg.pinv()
的rcond
参数已被弃用,即将移除。请改用rtol
。已移除已弃用的
jax.config
子模块。要配置 JAX,请使用import jax
,然后通过jax.config
引用 config 对象。jax.random
API 不再接受批量键,而以前某些 API 意外地接受批量键。展望未来,我们建议在这种情况下显式使用jax.vmap()
。在
jax.scipy.special.beta()
中,x
和y
参数已重命名为a
和b
,以便与其他beta
API 保持一致。
新功能
添加了
jax.experimental.Exported.in_shardings_jax()
,用于从Exported
对象中存储的 HloSharding 构建可用于 JAX API 的分片。
jaxlib 0.4.29(2024 年 6 月 10 日)#
错误修复
修复了 XLA 错误地对某些串联操作进行分片的错误,该错误表现为累积归约的不正确输出 (#21403)。
修复了 XLA:CPU 错误编译某些 matmul 融合的错误 (https://github.com/openxla/xla/pull/13301)。
修复了 GPU 上的编译器崩溃问题 (https://github.com/jax-ml/jax/issues/21396)。
弃用
jax.tree.map(f, None, non-None)
现在发出DeprecationWarning
,并且在未来版本的 jax 中会引发错误。None
只是其自身的树前缀。要保留当前行为,您可以要求jax.tree.map
将None
视为叶值,方法是编写:jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None)
。
jax 0.4.28(2024 年 5 月 9 日)#
错误修复
恢复了对
make_jaxpr
的更改,该更改破坏了 Equinox (#21116)。
弃用和移除
现在已移除
jax.numpy.sort()
和jax.numpy.argsort()
的kind
参数。请改用stable=True
或stable=False
。从
jax.experimental.pallas.gpu
模块中移除了get_compute_capability
。请改用 GPU 设备的compute_capability
属性,该属性由jax.devices()
或jax.local_devices()
返回。jax.numpy.reshape()
的newshape
参数已被弃用,即将移除。请改用shape
。
更改
此版本的最低 jaxlib 版本为 0.4.27。
jaxlib 0.4.28(2024 年 5 月 9 日)#
错误修复
修复了 Python 3.10 或更早版本中 Array 和 JIT Python 对象类型名称中的内存损坏错误。
修复了 CUDA 12.4 下的警告
'+ptx84' is not a recognized feature for this target
。修复了 CPU 上编译速度慢的问题。
更改
Windows 版本现在使用 Clang 而不是 MSVC 构建。
jax 0.4.27(2024 年 5 月 7 日)#
新功能
添加了
jax.numpy.unstack()
和jax.numpy.cumulative_sum()
,遵循它们在 array API 2023 标准中的添加,该标准即将被 NumPy 采用。添加了一个新的配置选项
jax_cpu_collectives_implementation
,用于选择 CPU 后端使用的跨进程集合操作的实现。可用选项包括'none'
(默认)、'gloo'
和'mpi'
(需要 jaxlib 0.4.26)。如果设置为'none'
,则禁用跨进程集合操作。
更改
jax.pure_callback()
、jax.experimental.io_callback()
和jax.debug.callback()
现在使用jax.Array
而不是np.ndarray
。您可以通过在将参数传递给回调之前通过jax.tree.map(np.asarray, args)
转换参数来恢复旧的行为。complex_arr.astype(bool)
现在遵循与 NumPy 相同的语义,当complex_arr
等于0 + 0j
时返回 False,否则返回 True。core.Token
现在是一个重要的类,它封装了一个jax.Array
。它可以被创建并在计算中传入和传出,以建立依赖关系。单例对象core.token
已被移除,用户现在应该创建和使用新的core.Token
对象来代替。在 GPU 上,Threefry PRNG 的实现默认不再降低为内核调用。此选择可以在编译时降低运行时内存使用量。之前的行为(产生内核调用)可以使用
jax.config.update('jax_threefry_gpu_kernel_lowering', True)
恢复。如果新的默认行为导致问题,请提交 bug。否则,我们计划在未来的版本中移除此标志。
弃用和移除
Pallas 现在完全使用 XLA 在 GPU 上编译内核。通过 Triton Python API 的旧的 lowering 过程已被移除,并且
JAX_TRITON_COMPILE_VIA_XLA
环境变量不再有任何效果。jax.numpy.clip()
有了一个新的参数签名:a
、a_min
和a_max
已被弃用,取而代之的是x
(仅限位置参数)、min
和max
(#20550)。JAX 数组的
device()
方法已被移除,该方法自 JAX v0.4.21 版本起已被弃用。请改用arr.devices()
。jax.nn.softmax()
和jax.nn.log_softmax()
的initial
参数已被弃用;现在支持 softmax 的空输入,无需设置此参数。在
jax.jit()
中,传递无效的static_argnums
或static_argnames
现在会导致错误,而不是警告。最低 jaxlib 版本现在是 0.4.23。
当向
jax.numpy.hypot()
函数传递复数值输入时,现在会发出弃用警告。当弃用完成后,这将引发错误。按照 NumPy 中的类似更改,
jax.numpy.nonzero()
、jax.numpy.where()
和相关函数的标量参数现在会引发错误。配置选项
jax_cpu_enable_gloo_collectives
已被弃用。请改用jax.config.update('jax_cpu_collectives_implementation', 'gloo')
。在 JAX v0.4.22 版本中弃用后,
jax.Array.device_buffer
和jax.Array.device_buffers
方法已被移除。请改用jax.Array.addressable_shards
和jax.Array.addressable_data()
。在 JAX v0.4.21 版本中弃用关键字后,
jax.numpy.where
的condition
、x
和y
参数现在仅限位置参数。在
jax.lax.linalg
中的函数中,非数组参数现在必须通过关键字指定。以前,这会引发 DeprecationWarning。现在在几个 :func:
jax.numpy
API 中需要类数组参数,包括apply_along_axis()
、apply_over_axes()
、inner()
、outer()
、cross()
、kron()
和lexsort()
。
错误修复
当
copy=True
时,jax.numpy.astype()
现在始终返回副本。以前,当输出数组与输入数组具有相同 dtype 时,不会创建副本。这可能会导致一些内存使用量增加。默认值设置为copy=False
以保持向后兼容性。
jaxlib 0.4.27(2024 年 5 月 7 日)#
jax 0.4.26(2024 年 4 月 3 日)#
新功能
添加了
jax.numpy.trapezoid()
,与 NumPy 2.0 中添加的此函数一致。
更改
复数值
jax.numpy.geomspace()
现在选择与 NumPy 2.0 一致的对数螺旋分支。lax.rng_bit_generator
的行为,以及反过来'rbg'
和'unsafe_rbg'
PRNG 实现,在jax.vmap
下 已更改,以便在键上进行映射时,仅从批次中的第一个键生成随机数。文档现在使用
jax.random.key
来构造 PRNG 键数组,而不是jax.random.PRNGKey
。
弃用和移除
jax.tree_map()
已被弃用;请改用jax.tree.map
,或者为了与旧版本的 JAX 向后兼容,请使用jax.tree_util.tree_map()
。jax.clear_backends()
已被弃用,因为它不一定能实现其名称所暗示的功能,并且可能导致意外的后果,例如,它不会销毁现有的后端并释放相应的拥有资源。如果您只想清理编译缓存,请使用jax.clear_caches()
。为了向后兼容,或者您确实需要切换/重新初始化默认后端,请使用jax.extend.backend.clear_backends()
。jax.experimental.maps
模块和jax.experimental.maps.xmap
已被弃用。请使用jax.experimental.shard_map
或jax.vmap
以及spmd_axis_name
参数来表达 SPMD 设备并行计算。jax.experimental.host_callback
模块已被弃用。请改用 新的 JAX 外部回调。添加了JAX_HOST_CALLBACK_LEGACY
标志,以帮助过渡到新的回调。有关讨论,请参见 #20385。传递无法转换为 JAX 数组的参数给
jax.numpy.array_equal()
和jax.numpy.array_equiv()
现在会导致异常。已移除已弃用的标志
jax_parallel_functions_output_gda
。此标志早已被弃用,并且没有任何作用;其使用是空操作。先前已弃用的导入
jax.interpreters.ad.config
和jax.interpreters.ad.source_info_util
现已被移除。请改用jax.config
和jax.extend.source_info_util
。JAX export 不再支持旧的序列化版本。版本 9 自 2023 年 10 月 27 日起已受支持,并自 2024 年 2 月 1 日起成为默认版本。请参阅 版本描述。此更改可能会破坏设置了低于 9 的特定 JAX 序列化版本的客户端。
jaxlib 0.4.26(2024 年 4 月 3 日)#
更改
JAX 现在仅支持 CUDA 12.1 或更高版本。已停止支持 CUDA 11.8。
JAX 现在支持 NumPy 2.0。
jax 0.4.25(2024 年 2 月 26 日)#
新功能
添加了 CUDA 数组接口 导入支持(需要 jaxlib 0.4.24)。
JAX 数组现在支持 NumPy 样式的标量布尔索引,例如
x[True]
或x[False]
。添加了
jax.tree
模块,该模块为引用jax.tree_util
中的函数提供了更方便的接口。jax.tree.transpose()
(即jax.tree_util.tree_transpose()
) 现在接受inner_treedef=None
,在这种情况下,将自动推断内部 treedef。
更改
Pallas 现在使用 XLA 而不是 Triton Python API 来编译 Triton 内核。您可以通过将
JAX_TRITON_COMPILE_VIA_XLA
环境变量设置为"0"
来恢复旧的行为。v0.4.24 版本中移除的
jax.interpreters.xla
中的几个已弃用的 API 已在 v0.4.25 版本中重新添加,包括backend_specific_translations
、translations
、register_translation
、xla_destructure
、TranslationRule
、TranslationContext
和XLAOp
。这些仍然被认为是已弃用的,并且将在未来提供更好的替代方案时再次移除。有关讨论,请参阅 #19816。
弃用和移除
jax.numpy.linalg.solve()
现在针对b.ndim > 1
的批量 1D 求解显示弃用警告。将来,这些将被视为批量 2D 求解。无论数组的大小如何,将非标量数组转换为 Python 标量现在都会引发错误。以前,对于大小为 1 的非标量数组,会引发弃用警告。这遵循了 NumPy 中的类似弃用。
先前已弃用的配置 API 已按照标准的 3 个月弃用周期移除(请参阅 API 兼容性)。这些包括
jax.config.config
对象和jax.config
的define_*_state
和DEFINE_*
方法。
通过
import jax.config
导入jax.config
子模块已被弃用。要配置 JAX,请使用import jax
,然后通过jax.config
引用配置对象。最低 jaxlib 版本现在是 0.4.20。
jaxlib 0.4.25(2024 年 2 月 26 日)#
jax 0.4.24(2024 年 2 月 6 日)#
更改
JAX lowering 到 StableHLO 不再依赖于物理设备。如果您的 primitive 在 lowering 规则(即传递给
mlir.register_lowering
的rule
参数的函数)中包装了 custom_partitioning 或 JAX 回调,则将您的 primitive 添加到jax._src.dispatch.prim_requires_devices_during_lowering
集合中。这是必需的,因为 custom_partitioning 和 JAX 回调需要物理设备才能在 lowering 期间创建Sharding
。在我们可以在没有物理设备的情况下创建Sharding
之前,这是一个临时状态。jax.numpy.argsort()
和jax.numpy.sort()
现在支持stable
和descending
参数。形状多态性处理的几个更改(在
jax.experimental.jax2tf
和jax.experimental.export
中使用)更清晰的符号表达式漂亮打印 (#19227)
添加了指定维度变量的符号约束的能力。这使得形状多态性更具表现力,并提供了一种解决不等式推理中限制的方法。请参阅 https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints。
随着符号约束的添加 (#19235),我们现在认为来自不同作用域的维度变量是不同的,即使它们具有相同的名称。来自不同作用域的符号表达式不能交互,例如,在算术运算中。作用域由
jax.experimental.jax2tf.convert()
、jax.experimental.export.symbolic_shape()
、jax.experimental.export.symbolic_args_specs()
引入。符号表达式e
的作用域可以使用e.scope
读取,并传递到上述函数中,以指示它们在给定作用域中构造符号表达式。请参阅 https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints。简化和更快的相等性比较,如果两个符号维度的差的归一化形式简化为 0,则我们认为它们相等 (#19231;请注意,这可能会导致用户可见的行为更改)
改进了不确定不等式比较的错误消息 (#19235)。
已弃用
core.non_negative_dim
API(最近引入),并引入了core.max_dim
和core.min_dim
(#18953),以表达符号维度的max
和min
。您可以使用core.max_dim(d, 0)
代替core.non_negative_dim(d)
。已弃用
shape_poly.is_poly_dim
,赞成使用export.is_symbolic_dim
(#19282)。已弃用
export.args_specs
,赞成使用export.symbolic_args_specs ({jax-issue}
#19283`)。已弃用
shape_poly.PolyShape
和jax2tf.PolyShape
,对于多态形状规范,请使用字符串 (#19284)。JAX 默认原生序列化版本现在是 9。这与
jax.experimental.jax2tf
和jax.experimental.export
相关。请参阅 版本号描述。
重构了
jax.experimental.export
的 API。您现在应该使用from jax.experimental import export
,而不是from jax.experimental.export import export
。旧的导入方式将在 3 个月的弃用期内继续有效。具有
return_inverse = True
的jax.numpy.unique()
返回重塑为输入维度的逆索引,这与 NumPy 2.0 中numpy.unique()
的类似更改一致。jax.numpy.sign()
现在对于非零复数输入返回x / abs(x)
。这与 NumPy 2.0 版本中numpy.sign()
的行为一致。具有
return_sign=True
的jax.scipy.special.logsumexp()
现在对复数符号使用 NumPy 2.0 约定,x / abs(x)
。这与 SciPy v1.13 中scipy.special.logsumexp()
的行为一致。JAX 现在支持导入和导出布尔 DLPack 类型。以前,布尔值无法导入,并且导出为整数。
弃用和移除
许多先前已弃用的函数已被移除,遵循标准的 3+ 个月弃用周期(请参阅 API 兼容性)。这包括
来自
jax.core
:TracerArrayConversionError
、TracerIntegerConversionError
、UnexpectedTracerError
、as_hashable_function
、collections
、dtypes
、lu
、map
、namedtuple
、partial
、pp
、ref
、safe_zip
、safe_map
、source_info_util
、total_ordering
、traceback_util
、tuple_delete
、tuple_insert
和zip
。来自
jax.lax
:dtypes
、itertools
、naryop
、naryop_dtype_rule
、standard_abstract_eval
、standard_naryop
、standard_primitive
、standard_unop
、unop
和unop_dtype_rule
。jax.linear_util
子模块及其所有内容。jax.prng
子模块及其所有内容。来自
jax.random
:PRNGKeyArray
、KeyArray
、default_prng_impl
、threefry_2x32
、threefry2x32_key
、threefry2x32_p
、rbg_key
和unsafe_rbg_key
。来自
jax.tree_util
:register_keypaths
、AttributeKeyPathEntry
和GetItemKeyPathEntry
。来自
jax.interpreters.xla
:backend_specific_translations
、translations
、register_translation
、xla_destructure
、TranslationRule
、TranslationContext
、axis_groups
、ShapedArray
、ConcreteArray
、AxisEnv
、backend_compile
和XLAOp
。来自
jax.numpy
:NINF
、NZERO
、PZERO
、row_stack
、issubsctype
、trapz
和in1d
。来自
jax.scipy.linalg
:tril
和triu
。
先前已弃用的方法
PRNGKeyArray.unsafe_raw_array
已被移除。请改用jax.random.key_data()
。bool(empty_array)
现在会引发错误,而不是返回False
。这之前会引发弃用警告,并且遵循了 NumPy 中的类似更改。对 mhlo MLIR 方言的支持已被弃用。JAX 不再使用 mhlo 方言,而是使用 stablehlo。将来会移除引用 “mhlo” 的 API。请改用 “stablehlo” 方言。
jax.random
:直接将批量键传递给随机数生成函数(例如bits()
、gamma()
等)已被弃用,并将发出FutureWarning
。请使用jax.vmap
进行显式批处理。jax.lax.tie_in()
已被弃用:自 JAX v0.2.0 版本以来,它一直是一个空操作。
jaxlib 0.4.24(2024 年 2 月 6 日)#
更改
JAX 现在支持 CUDA 12.3 和 CUDA 11.8。已停止支持 CUDA 12.2。
cost_analysis
现在可以与交叉编译的Compiled
对象一起使用(即当使用带有拓扑对象的.lower().compile()
时,例如,从非 TPU 计算机编译用于 Cloud TPU)。添加了 CUDA 数组接口 导入支持(需要 jax 0.4.25)。
jax 0.4.23(2023 年 12 月 13 日)#
jaxlib 0.4.23(2023 年 12 月 13 日)#
修复了一个导致 GPU 编译器在编译期间产生冗长日志记录的错误。
jax 0.4.22(2023 年 12 月 13 日)#
弃用
JAX 数组的
device_buffer
和device_buffers
属性已被弃用。显式缓冲区已被更灵活的数组分片接口取代,但可以通过以下方式恢复以前的输出arr.device_buffer
变为arr.addressable_data(0)
arr.device_buffers
变为[x.data for x in arr.addressable_shards]
jaxlib 0.4.22(2023 年 12 月 13 日)#
jax 0.4.21(2023 年 12 月 4 日)#
新功能
添加了
jax.nn.squareplus
。
更改
最低 jaxlib 版本现在是 0.4.19。
发布的 wheels 现在使用 clang 而不是 gcc 构建。
强制在调用
jax.distributed.initialize()
之前设备后端尚未初始化。在云 TPU 环境中自动化
jax.distributed.initialize()
的参数。
弃用
先前已弃用的
sym_pos
参数已从jax.scipy.linalg.solve()
中移除。请改用assume_a='pos'
。传递
None
给jax.array()
或jax.asarray()
,无论是直接传递还是在列表或元组中传递,都已被弃用,现在会引发FutureWarning
。目前它会被转换为 NaN,但在未来将引发TypeError
。通过关键字参数将
condition
、x
和y
参数传递给jax.numpy.where
已被弃用,以匹配numpy.where
。将无法转换为 JAX 数组的参数传递给
jax.numpy.array_equal()
和jax.numpy.array_equiv()
已被弃用,现在会引发DeprecationWaning
。目前,这些函数返回 False,未来将引发异常。JAX 数组的
device()
方法已被弃用。根据上下文,它可能被以下方法之一取代:jax.Array.devices()
返回数组使用的所有设备的集合。jax.Array.sharding
给出数组使用的分片配置。
jaxlib 0.4.21 (2023 年 12 月 4 日)#
更改
为了准备添加分布式 CPU 支持,JAX 现在将 CPU 设备与 GPU 和 TPU 设备同等对待,也就是说:
jax.devices()
包括分布式作业中存在的所有设备,即使是当前进程本地之外的设备。jax.local_devices()
仍然只包括当前进程本地的设备,因此如果对jax.devices()
的更改破坏了您的代码,您很可能想要改用jax.local_devices()
。CPU 设备现在在分布式作业中接收全局唯一的 ID 号;以前,CPU 设备会接收进程本地的 ID 号。
每个 CPU 设备的
process_index
现在将与同一进程中的任何 GPU 或 TPU 设备匹配;以前,CPU 设备的process_index
始终为 0。
在 NVIDIA GPU 上,JAX 现在优先为高达 1024x1024 的矩阵使用 Jacobi SVD 求解器。Jacobi 求解器似乎比非 Jacobi 版本更快。
错误修复
修复了当具有非有限值的数组传递给非对称特征值分解 (#18226) 时发生的错误/挂起。具有非有限值的数组现在会生成充满 NaN 的数组作为输出。
jax 0.4.20 (2023 年 11 月 2 日)#
jaxlib 0.4.20 (2023 年 11 月 2 日)#
错误修复
修复了 E4M3 和 E5M2 float8 类型之间的一些类型混淆。
jax 0.4.19 (2023 年 10 月 19 日)#
新功能
添加了
jax.typing.DTypeLike
,可用于注解可转换为 JAX dtypes 的对象。添加了
jax.numpy.fill_diagonal
。
更改
JAX 现在需要 SciPy 1.9 或更高版本。
错误修复
多控制器分布式 JAX 程序中只有进程 0 会写入持久编译缓存条目。这修复了当缓存放置在网络文件系统(如 GCS)上时的写入争用。
cusolver 和 cufft 的版本检查在确定已安装的这些库的版本是否至少与构建 JAX 所针对的版本一样新时,不再考虑补丁版本。
jaxlib 0.4.19 (2023 年 10 月 19 日)#
更改
如果安装了 pip 安装的 NVIDIA CUDA 库 (nvidia-… 包),jaxlib 现在将始终优先选择它们,而不是任何其他 CUDA 安装,包括
LD_LIBRARY_PATH
中命名的安装。如果这导致问题并且目的是使用系统安装的 CUDA,则解决方法是删除 pip 安装的 CUDA 库包。
jax 0.4.18 (2023 年 10 月 6 日)#
jaxlib 0.4.18 (2023 年 10 月 6 日)#
更改
CUDA jaxlibs 现在依赖用户安装兼容的 NCCL 版本。如果使用推荐的
cuda12_pip
安装,则应自动安装 NCCL。目前,需要 NCCL 2.16 或更高版本。我们现在提供 Linux aarch64 wheels,包括带和不带 NVIDIA GPU 支持的版本。
jax.Array.item()
现在支持可选的索引参数。
弃用
jax.lax
中的许多内部实用程序和意外导出已被弃用,并将在未来版本中删除。jax.lax.dtypes
:请改用jax.dtypes
。jax.lax.itertools
:请改用itertools
。naryop
、naryop_dtype_rule
、standard_abstract_eval
、standard_naryop
、standard_primitive
、standard_unop
、unop
和unop_dtype_rule
是内部实用程序,现在已弃用,没有替代品。
错误修复
修复了 Cloud TPU 回归,其中编译会因 smem 而导致 OOM。
jax 0.4.17 (2023 年 10 月 3 日)#
新功能
添加了新的
jax.numpy.bitwise_count()
函数,与最近添加到 NumPy 的类似函数的 API 相匹配。
弃用
删除了已弃用的模块
jax.abstract_arrays
及其所有内容。jax.random
中的命名键构造函数已被弃用。请将impl
参数传递给jax.random.PRNGKey()
或jax.random.key()
random.threefry2x32_key(seed)
变为random.PRNGKey(seed, impl='threefry2x32')
random.rbg_key(seed)
变为random.PRNGKey(seed, impl='rbg')
random.unsafe_rbg_key(seed)
变为random.PRNGKey(seed, impl='unsafe_rbg')
更改
CUDA:JAX 现在验证它找到的 CUDA 库是否至少与构建 JAX 所针对的 CUDA 库一样新。如果找到较旧的库,JAX 会引发异常,因为这比神秘的故障和崩溃更可取。
删除了“未找到 GPU/TPU”警告。相反,如果 Linux 上找到了 NVIDIA GPU 或 Google TPU 但未使用,并且未指定
--jax_platforms
,则发出警告。jax.scipy.stats.mode()
现在在跨尺寸为 0 的轴取模时返回 0 计数,与 SciPy 1.11 中的scipy.stats.mode
的行为相匹配。大多数
jax.numpy
函数和属性现在都具有完全定义的类型存根。以前,这些函数和属性中的许多都被静态类型检查器(如mypy
和pytype
)视为Any
。
jaxlib 0.4.17 (2023 年 10 月 3 日)#
更改
此版本中添加了 Python 3.12 wheels。
CUDA 12 wheels 现在需要 CUDA 12.2 或更高版本以及 cuDNN 8.9.4 或更高版本。
错误修复
修复了初始化 JAX CPU 后端时来自 ABSL 的日志垃圾邮件。
jax 0.4.16 (2023 年 9 月 18 日)#
更改
添加了
jax.numpy.ufunc
,以及jax.numpy.frompyfunc()
,它可以将任何标量值函数转换为类似numpy.ufunc()
的对象,并具有诸如outer()
、reduce()
、accumulate()
、at()
和reduceat()
等方法 (#17054)。当不在 IPython 下运行时:当引发异常时,JAX 现在会从回溯中过滤掉其所有内部帧。(没有之前出现的“未过滤的堆栈跟踪”。)这应该产生更友好的回溯外观。有关示例,请参见 此处。可以通过设置
JAX_TRACEBACK_FILTERING=remove_frames
(用于两个单独的未过滤/过滤的回溯,这是旧的行为)或JAX_TRACEBACK_FILTERING=off
(用于一个未过滤的回溯)来更改此行为。jax2tf 默认序列化版本现在为 7,它引入了新的形状 安全断言。
传递给
jax.sharding.Mesh
的设备应该是可哈希的。这尤其适用于模拟设备或用户创建的设备。jax.devices()
已经是可哈希的。
重大更改
jax2tf 现在默认使用本机序列化。有关详细信息和覆盖默认机制,请参阅 jax2tf 文档。
选项
--jax_coordination_service
已被删除。现在始终为True
。jax.jaxpr_util
已从公共 JAX 命名空间中删除。JAX_USE_PJRT_C_API_ON_TPU
不再有任何影响(即,它始终默认为 true)。2021 年 12 月引入的向后兼容性标志
--jax_host_callback_ad_transforms
已被删除。
弃用
根据 NumPy NEP-52,已弃用多个
jax.numpy
APIjax.numpy.NINF
已被弃用。请改用-jax.numpy.inf
。jax.numpy.PZERO
已被弃用。请改用0.0
。jax.numpy.NZERO
已被弃用。请改用-0.0
。jax.numpy.issubsctype(x, t)
已被弃用。请使用jax.numpy.issubdtype(x.dtype, t)
。jax.numpy.row_stack
已被弃用。请改用jax.numpy.vstack
。jax.numpy.in1d
已被弃用。请改用jax.numpy.isin
。jax.numpy.trapz
已被弃用。请改用jax.scipy.integrate.trapezoid
。
jax.scipy.linalg.tril
和jax.scipy.linalg.triu
已被弃用,遵循 SciPy。请改用jax.numpy.tril
和jax.numpy.triu
。jax.lax.prod
在 JAX v0.4.11 中被弃用后已被删除。请改用内置的math.prod
。与为自定义 JAX 原语定义 HLO 降级规则相关的
jax.interpreters.xla
中的多个导出已被弃用。自定义原语应改用jax.interpreters.mlir
中的 StableHLO 降级实用程序定义。以下先前弃用的函数在三个月的弃用期后已被删除:
jax.abstract_arrays.ShapedArray
:请使用jax.core.ShapedArray
。jax.abstract_arrays.raise_to_shaped
:请使用jax.core.raise_to_shaped
。jax.numpy.alltrue
:请使用jax.numpy.all
。jax.numpy.sometrue
:请使用jax.numpy.any
。jax.numpy.product
:请使用jax.numpy.prod
。jax.numpy.cumproduct
:请使用jax.numpy.cumprod
。
弃用/删除
内部子模块
jax.prng
现已弃用。其内容可在jax.extend.random
中找到。内部子模块路径
jax.linear_util
已被弃用。请改用jax.extend.linear_util
(jax.extend:扩展模块 的一部分)jax.random.PRNGKeyArray
和jax.random.KeyArray
已被弃用。对于类型注解,请使用jax.Array
,对于类型化 prng 键的运行时检测,请使用jax.dtypes.issubdtype(arr.dtype, jax.dtypes.prng_key)
。方法
PRNGKeyArray.unsafe_raw_array
已被弃用。请改用jax.random.key_data()
。jax.experimental.pjit.with_sharding_constraint
已被弃用。请改用jax.lax.with_sharding_constraint
。内部实用程序
jax.core.is_opaque_dtype
和jax.core.has_opaque_dtype
已被删除。不透明的 dtypes 已重命名为扩展的 dtypes;请改用jnp.issubdtype(dtype, jax.dtypes.extended)
(自 jax v0.4.14 起可用)。实用程序
jax.interpreters.xla.register_collective_primitive
已被删除。此实用程序在最近的 JAX 版本中没有执行任何有用的操作,可以安全地删除对其的调用。内部子模块路径
jax.linear_util
已被弃用。请改用jax.extend.linear_util
(jax.extend:扩展模块 的一部分)
jaxlib 0.4.16 (2023 年 9 月 18 日)#
更改
通过实验性 jax sparse API 进行的稀疏 CSR 矩阵乘法不再在 NVIDIA GPU 上使用确定性算法。进行此更改是为了提高与 CUDA 12.2.1 的兼容性。
错误修复
修复了 Windows 上由于与乱序部分和 IMAGE_REL_AMD64_ADDR32NB 重定位相关的致命 LLVM 错误导致的崩溃 (https://github.com/openxla/xla/commit/cb732a921f0c4184995cbed82394931011d12bd4)。
jax 0.4.14 (2023 年 7 月 27 日)#
更改
jax.jit
接受donate_argnames
作为参数。它的语义与static_argnames
类似。如果未提供 donate_argnums 和 donate_argnames,则不会捐赠任何参数。如果未提供 donate_argnums 但提供了 donate_argnames,反之亦然,JAX 会使用inspect.signature(fun)
来查找与 donate_argnames 对应的任何位置参数(反之亦然)。如果同时提供了 donate_argnums 和 donate_argnames,则不使用 inspect.signature,并且只有 donate_argnums 或 donate_argnames 中列出的实际参数才会被捐赠。jax.random.gamma()
已被重构为更高效的算法,具有更强大的端点行为 (#16779)。这意味着对于给定的key
,JAX v0.4.13 和 v0.4.14 之间返回的gamma
和相关采样器(包括jax.random.ball()
、jax.random.beta()
、jax.random.chisquare()
、jax.random.dirichlet()
、jax.random.generalized_normal()
、jax.random.loggamma()
、jax.random.t()
)的值序列将发生变化。
删除
由于
in_axis_resources
和out_axis_resources
自弃用以来已超过 3 个月,因此已从 pjit 中删除。请使用in_shardings
和out_shardings
作为替代。这是一个安全且简单的名称替换。它不会更改当前 pjit 的任何语义,也不会破坏任何代码。您仍然可以将PartitionSpecs
传递给 in_shardings 和 out_shardings。
弃用
根据 https://jax.net.cn/en/latest/deprecation.html,已删除 Python 3.8 支持
根据 https://jax.net.cn/en/latest/deprecation.html,JAX 现在需要 NumPy 1.22 或更高版本
在 JAX 0.4.7 版本中弃用后,不再支持按位置将可选参数传递给
jax.numpy.ndarray.at()
。例如,请使用x.at[i].get(indices_are_sorted=True)
,而不是x.at[i].get(True)
以下
jax.Array
方法在 JAX v0.4.5 中弃用后已被删除:jax.Array.broadcast
:请改用jax.lax.broadcast()
。jax.Array.broadcast_in_dim
:请改用jax.lax.broadcast_in_dim()
。jax.Array.split
:请改用jax.numpy.split()
。
以下 API 在先前弃用后已被删除:
jax.ad
:请使用jax.interpreters.ad
。jax.curry
:请使用curry = lambda f: partial(partial, f)
。jax.partial_eval
:请使用jax.interpreters.partial_eval
。jax.pxla
:请使用jax.interpreters.pxla
。jax.xla
:请使用jax.interpreters.xla
。jax.ShapedArray
:请使用jax.core.ShapedArray
。jax.interpreters.pxla.device_put
:请使用jax.device_put()
。jax.interpreters.pxla.make_sharded_device_array
:请使用jax.make_array_from_single_device_arrays()
。jax.interpreters.pxla.ShardedDeviceArray
:请使用jax.Array
。jax.numpy.DeviceArray
:请使用jax.Array
。jax.stages.Compiled.compiler_ir
:请使用jax.stages.Compiled.as_text()
。
重大更改
JAX 现在需要 ml_dtypes 版本 0.2.0 或更高版本。
为了修复一个极端情况,如果第二个和第三个参数是可调用的,即使其他操作数也是可调用的,则对具有五个参数的
jax.lax.cond()
的调用将始终解析为“公共操作数”cond
行为(如文档所述)。请参阅 #16413。已删除已弃用的配置选项
jax_array
和jax_jit_pjit_api_merge
,它们没有任何作用。这些选项在许多版本中默认都为 true。
新功能
JAX 现在支持配置标志 –jax_serialization_version 和 JAX_SERIALIZATION_VERSION 环境变量来控制序列化版本 (#16746)。
如果序列化版本至少为 7,则在形状多态性存在的情况下,jax2tf 现在生成的代码会检查某些形状约束。请参阅 https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism。
jaxlib 0.4.14 (2023 年 7 月 27 日)#
弃用
根据 https://jax.net.cn/en/latest/deprecation.html,已删除 Python 3.8 支持
jax 0.4.13 (2023 年 6 月 22 日)#
更改
jax.jit
现在允许将None
传递给in_shardings
和out_shardings
。语义如下:对于 in_shardings,JAX 会将其标记为复制,但此行为将来可能会更改。
对于 out_shardings,我们将依赖 XLA GSPMD 分区器来确定输出分片。
jax.experimental.pjit.pjit
也允许将None
传递给in_shardings
和out_shardings
。语义如下:如果未提供 mesh 上下文管理器,则 JAX 可以自由选择它想要的任何分片。
对于 in_shardings,JAX 会将其标记为复制,但此行为将来可能会更改。
对于 out_shardings,我们将依赖 XLA GSPMD 分区器来确定输出分片。
如果提供了 mesh 上下文管理器,则 None 将意味着该值将在 mesh 的所有设备上复制。
Executable.cost_analysis() 在 Cloud TPU 上工作
如果正在使用非白名单的
jaxlib
插件,则添加警告。添加了
jax.tree_util.tree_leaves_with_path
。None
不是jax.experimental.multihost_utils.host_local_array_to_global_array
或jax.experimental.multihost_utils.global_array_to_host_local_array
的有效输入。如果您想复制输入,请使用jax.sharding.PartitionSpec()
。
错误修复
修复了 CUDA 12 版本中不正确的 wheel 名称 (#16362);正确的 wheel 名称为
cudnn89
而不是cudnn88
。
弃用
jax.experimental.jax2tf.convert()
的native_serialization_strict_checks
参数已被弃用,转而使用新的native_serializaation_disabled_checks
(#16347)。
jaxlib 0.4.13 (2023 年 6 月 22 日)#
更改
为
jaxlib
Pypi 发布版本添加了 Windows CPU 专用 wheels。
错误修复
__cuda_array_interface__
在之前的 jaxlib 版本中存在问题,现已修复 (#16440)。在 NVIDIA GPU 上,默认启用并发 CUDA 内核追踪。
jax 0.4.12 (2023 年 6 月 8 日)#
更改
弃用
jax.abstract_arrays
及其内容现已弃用。请参阅 :mod:jax.core
中的相关功能。jax.numpy.alltrue
: 请使用jax.numpy.all
。这是为了跟进 NumPy 1.25.0 版本中对numpy.alltrue
的弃用。jax.numpy.sometrue
: 请使用jax.numpy.any
。这是为了跟进 NumPy 1.25.0 版本中对numpy.sometrue
的弃用。jax.numpy.product
: 请使用jax.numpy.prod
。这是为了跟进 NumPy 1.25.0 版本中对numpy.product
的弃用。jax.numpy.cumproduct
: 请使用jax.numpy.cumprod
。这是为了跟进 NumPy 1.25.0 版本中对numpy.cumproduct
的弃用。jax.sharding.OpShardingSharding
已被移除,因为它已被弃用 3 个月。
jaxlib 0.4.12 (2023 年 6 月 8 日)#
更改
包含用于 Hopper (SM 版本 9.0+) GPU 的 PTX/SASS。之前版本的 jaxlib 应该可以在 Hopper 上工作,但在首次执行 JAX 操作时会有较长的 JIT 编译延迟。
错误修复
修复了 Python 3.11 下 JAX 生成的 Python 回溯中不正确的源代码行信息。
修复了在 JAX 生成的 Python 回溯中打印帧的局部变量时崩溃的问题 (#16027)。
jax 0.4.11 (2023 年 5 月 31 日)#
弃用
以下 API 在 3 个月的弃用期后已被移除,符合 API 兼容性 政策
jax.experimental.PartitionSpec
: 请使用jax.sharding.PartitionSpec
。jax.experimental.maps.Mesh
: 请使用jax.sharding.Mesh
jax.experimental.pjit.NamedSharding
: 请使用jax.sharding.NamedSharding
。jax.experimental.pjit.PartitionSpec
: 请使用jax.sharding.PartitionSpec
。jax.experimental.pjit.FROM_GDA
。请改为传递分片的jax.Array
对象作为输入,并移除pjit
的可选in_shardings
参数。jax.interpreters.pxla.PartitionSpec
: 请使用jax.sharding.PartitionSpec
。jax.interpreters.pxla.Mesh
: 请使用jax.sharding.Mesh
jax.interpreters.xla.Buffer
: 请使用jax.Array
。jax.interpreters.xla.Device
: 请使用jax.Device
。jax.interpreters.xla.DeviceArray
: 请使用jax.Array
。jax.interpreters.xla.device_put
: 请使用jax.device_put
。jax.interpreters.xla.xla_call_p
: 请使用jax.experimental.pjit.pjit_p
。with_sharding_constraint
的axis_resources
参数已被移除。请改用shardings
。
jaxlib 0.4.11 (2023 年 5 月 31 日)#
更改
为
Device
添加了memory_stats()
方法。如果支持,这将返回一个字典,其中包含字符串统计名称和整数值,例如"bytes_in_use"
,如果平台不支持内存统计,则返回 None。返回的具体统计信息可能因平台而异。目前仅在 Cloud TPU 上实现。重新添加了对 CPU 设备上的 Python 缓冲区协议 (
memoryview
) 的支持。
jax 0.4.10 (2023 年 5 月 11 日)#
jaxlib 0.4.10 (2023 年 5 月 11 日)#
更改
修复了
'apple-m1' is not a recognized processor for this target (ignoring processor)
问题,该问题阻止了之前版本在 Mac M1 上运行。
jax 0.4.9 (2023 年 5 月 9 日)#
更改
标志 experimental_cpp_jit、experimental_cpp_pjit 和 experimental_cpp_pmap 已被移除。它们现在始终处于启用状态。
TPU 上奇异值分解 (SVD) 的准确性已得到提高(需要 jaxlib 0.4.9)。
弃用
jax.experimental.gda_serialization
已被弃用,并已重命名为jax.experimental.array_serialization
。请更改您的导入以使用jax.experimental.array_serialization
。pjit 的
in_axis_resources
和out_axis_resources
参数已被弃用。请分别使用in_shardings
和out_shardings
。函数
jax.numpy.msort
已被移除。它自 JAX v0.4.1 以来已被弃用。请改用jnp.sort(a, axis=0)
。in_parts
和out_parts
参数已从jax.xla_computation
中移除,因为它们仅与 sharded_jit 一起使用,而 sharded_jit 早已不再使用。instantiate_const_outputs
参数已从jax.xla_computation
中移除,因为它已经很久未使用了。
jaxlib 0.4.9 (2023 年 5 月 9 日)#
jax 0.4.8 (2023 年 3 月 29 日)#
重大更改
Cloud TPU 运行时的主要组件已升级。这在 Cloud TPU 上启用了以下新功能
jax.debug.print()
,jax.debug.callback()
, 和jax.debug.breakpoint()
现在可以在 Cloud TPU 上工作自动 TPU 内存碎片整理
jax.experimental.host_callback()
在使用新运行时组件的 Cloud TPU 上不再受支持。如果新的jax.debug
API 不足以满足您的用例,请在 JAX 问题跟踪器上提交问题。旧的运行时组件将在至少未来三个月内通过设置环境变量
JAX_USE_PJRT_C_API_ON_TPU=false
提供。如果您发现出于任何原因需要禁用新运行时,请在 JAX 问题跟踪器上告知我们。
更改
最低 jaxlib 版本已从 0.4.6 提升到 0.4.7。
弃用
CUDA 11.4 支持已删除。JAX GPU wheels 仅支持 CUDA 11.8 和 CUDA 12。如果 jaxlib 从源代码构建,则较旧的 CUDA 版本可能可以工作。
pmap 的
global_arg_shapes
参数仅在 sharded_jit 中有效,并已从 pmap 中移除。请迁移到 pjit 并从 pmap 中移除 global_arg_shapes。
jax 0.4.7 (2023 年 3 月 27 日)#
更改
根据 https://jax.net.cn/en/latest/jax_array_migration.html#jax-array-migration
jax.config.jax_array
无法再禁用。jax.config.jax_jit_pjit_api_merge
无法再禁用。jax.experimental.jax2tf.convert()
现在支持native_serialization
参数,以使用 JAX 的原生 lowering 到 StableHLO,从而为整个 JAX 函数获取 StableHLO 模块,而不是将每个 JAX 原语 lowering 到 TensorFlow op。这简化了内部结构,并增加了您序列化的内容与 JAX 原生语义匹配的信心。请参阅 文档。作为此更改的一部分,配置标志--jax2tf_default_experimental_native_lowering
已重命名为--jax2tf_native_serialization
。JAX 现在依赖于
ml_dtypes
,其中包含 NumPy 类型的定义,如 bfloat16。这些定义以前是 JAX 内部的,但已拆分为单独的软件包,以便于与其他项目共享。JAX 现在需要 NumPy 1.21 或更高版本以及 SciPy 1.7 或更高版本。
弃用
类型
jax.numpy.DeviceArray
已弃用。请改用jax.Array
,它是它的别名。类型
jax.interpreters.pxla.ShardedDeviceArray
已弃用。请改用jax.Array
。按位置向
jax.numpy.ndarray.at()
传递额外参数已弃用。例如,请使用x.at[i].get(indices_are_sorted=True)
而不是x.at[i].get(True)
。jax.interpreters.xla.device_put
已弃用。请使用jax.device_put
。jax.interpreters.pxla.device_put
已弃用。请使用jax.device_put
。jax.experimental.pjit.FROM_GDA
已弃用。请传入分片的 jax.Arrays 作为输入,并移除 pjit 的in_shardings
参数,因为它是可选的。
jaxlib 0.4.7 (2023 年 3 月 27 日)#
更改
jaxlib 现在依赖于
ml_dtypes
,其中包含 NumPy 类型的定义,如 bfloat16。这些定义以前是 JAX 内部的,但已拆分为单独的软件包,以便于与其他项目共享。
jax 0.4.6 (2023 年 3 月 9 日)#
更改
jax.tree_util
现在包含一组 API,允许用户为其自定义 pytree 节点定义键。这包括tree_flatten_with_path
,它展平树并不仅返回每个叶子,还返回它们的键路径。tree_map_with_path
,它可以映射一个将键路径作为参数的函数。register_pytree_with_keys
,用于注册自定义 pytree 节点中键路径和叶子的外观。keystr
,用于美观地打印键路径。
jax2tf.call_tf()
有一个新的参数output_shape_dtype
(默认为None
),可用于声明结果的输出形状和类型。这使得jax2tf.call_tf()
能够在形状多态性的情况下工作。 (#14734)。
弃用
jax.tree_util
中的旧键路径 API 已被弃用,并将于 2023 年 3 月 10 日起 3 个月后移除register_keypaths
: 请改用jax.tree_util.register_pytree_with_keys()
。AttributeKeyPathEntry
: 请改用GetAttrKey
。GetitemKeyPathEntry
: 请改用SequenceKey
或DictKey
。
jaxlib 0.4.6 (2023 年 3 月 9 日)#
jax 0.4.5 (2023 年 3 月 2 日)#
弃用
jax.sharding.OpShardingSharding
已重命名为jax.sharding.GSPMDSharding
。jax.sharding.OpShardingSharding
将在 2023 年 2 月 17 日起 3 个月后移除。以下
jax.Array
方法已被弃用,并将于 2023 年 2 月 23 日起 3 个月后移除jax.Array.broadcast
:请改用jax.lax.broadcast()
。jax.Array.broadcast_in_dim
:请改用jax.lax.broadcast_in_dim()
。jax.Array.split
:请改用jax.numpy.split()
。
jax 0.4.4 (2023 年 2 月 16 日)#
更改
jit
和pjit
的实现已合并。合并 jit 和 pjit 更改了 JAX 的内部结构,但不影响 JAX 的公共 API。之前,jit
是最终风格的原语。最终风格意味着 jaxpr 的创建尽可能延迟,并且转换堆叠在彼此之上。通过jit
-pjit
实现合并,jit
变为初始风格的原语,这意味着我们尽早追踪到 jaxpr。有关更多信息,请参阅 autodidax 中的此部分。迁移到初始风格应该简化 JAX 的内部结构,并使动态形状等功能的开发更容易。您只能通过环境变量禁用它,即os.environ['JAX_JIT_PJIT_API_MERGE'] = '0'
。合并必须通过环境变量禁用,因为它在导入时影响 JAX,因此需要在导入 jax 之前禁用它。with_sharding_constraint
的axis_resources
参数已被弃用。请改用shardings
。如果您将axis_resources
用作 arg,则无需更改。如果您将其用作 kwarg,请改用shardings
。axis_resources
将在 2023 年 2 月 13 日起 3 个月后移除。添加了
jax.typing
模块,其中包含用于 JAX 函数类型注释的工具。以下名称已被弃用
jax.xla.Device
和jax.interpreters.xla.Device
: 请使用jax.Device
。jax.experimental.maps.Mesh
。请改用jax.sharding.Mesh
。jax.experimental.pjit.NamedSharding
: 请使用jax.sharding.NamedSharding
。jax.experimental.pjit.PartitionSpec
: 请使用jax.sharding.PartitionSpec
。jax.interpreters.pxla.Mesh
: 请使用jax.sharding.Mesh
。jax.interpreters.pxla.PartitionSpec
: 请使用jax.sharding.PartitionSpec
。
重大更改
现在要求归约函数(如 :func:
jax.numpy.sum
)的initial
参数为标量,与相应的 NumPy API 一致。先前针对非标量initial
值广播输出的行为是一个无意的实现细节 (#14446)。
jaxlib 0.4.4 (2023 年 2 月 16 日)#
重大更改
默认
jaxlib
构建版本中已移除对 NVIDIA Kepler 系列 GPU 的支持。如果需要 Kepler 支持,仍然可以从源代码构建jaxlib
并获得 Kepler 支持(通过build.py
的--cuda_compute_capabilities=sm_35
选项),但请注意 CUDA 12 已完全放弃对 Kepler GPU 的支持。
jax 0.4.3 (2023 年 2 月 8 日)#
重大更改
删除了
jax.scipy.linalg.polar_unitary()
,它是 scipy API 的已弃用 JAX 扩展。请改用jax.scipy.linalg.polar()
。
更改
jaxlib 0.4.3 (2023 年 2 月 8 日)#
jax.Array
现在具有非阻塞is_ready()
方法,如果数组已准备好,则返回True
(另请参见jax.block_until_ready()
)。
jax 0.4.2 (2023 年 1 月 24 日)#
重大更改
更改
jaxlib 0.4.2 (2023 年 1 月 24 日)#
更改
设置 JAX_USE_PJRT_C_API_ON_TPU=1 以启用新的 Cloud TPU 运行时,该运行时具有自动设备内存碎片整理功能。
jax 0.4.1 (2022 年 12 月 13 日)#
更改
根据 JAX 的 Python 和 NumPy 版本支持政策,已删除对 Python 3.7 的支持。
我们引入了
jax.Array
,它是一种统一的数组类型,包含了 JAX 中的DeviceArray
、ShardedDeviceArray
和GlobalDeviceArray
类型。jax.Array
类型有助于使并行性成为 JAX 的核心功能,简化和统一 JAX 内部结构,并允许我们统一jit
和pjit
。jax.Array
在 JAX 0.4 中已默认启用,并对pjit
API 进行了一些重大更改。jax.Array 迁移指南可以帮助您将代码库迁移到jax.Array
。您还可以查看 分布式数组和自动并行化教程,以了解新概念。PartitionSpec
和Mesh
现在已脱离实验阶段。新的 API 端点是jax.sharding.PartitionSpec
和jax.sharding.Mesh
。jax.experimental.maps.Mesh
和jax.experimental.PartitionSpec
已被弃用,并将在 3 个月后移除。with_sharding_constraint
的新公共端点是jax.lax.with_sharding_constraint
。如果将 ABSL 标志与
jax.config
一起使用,则在最初从 ABSL 标志填充 JAX 配置选项后,不再读取或写入 ABSL 标志值。此更改提高了读取jax.config
选项的性能,这些选项在 JAX 中被广泛使用。jax2tf.call_tf 函数现在为 TF lowering 使用与嵌入 JAX 计算所用平台相同的第一个 TF 设备。以前,它使用 JAX 默认后端的第 0 个设备。
许多
jax.numpy
函数现在将其参数标记为仅位置参数,与 NumPy 匹配。jnp.msort
现在已被弃用,跟进了 numpy 1.24 中对np.msort
的弃用。根据 API 兼容性 政策,它将在未来的版本中移除。它可以替换为jnp.sort(a, axis=0)
。
jaxlib 0.4.1 (2022 年 12 月 13 日)#
更改
根据 JAX 的 Python 和 NumPy 版本支持政策,已删除对 Python 3.7 的支持。
XLA_PYTHON_CLIENT_MEM_FRACTION=.XX
的行为已更改为分配 XX% 的总 GPU 内存,而不是之前的行为,即使用当前可用的 GPU 内存来计算预分配。有关更多详细信息,请参阅 GPU 内存分配。已移除已弃用的方法
.block_host_until_ready()
。请改用.block_until_ready()
。
jax 0.4.0 (2022 年 12 月 12 日)#
此版本已被撤回。
jaxlib 0.4.0 (2022 年 12 月 12 日)#
此版本已被撤回。
jax 0.3.25 (2022 年 11 月 15 日)#
更改
jax.numpy.linalg.pinv()
现在支持hermitian
选项。jax.scipy.linalg.hessenberg()
现在仅在 CPU 上受支持。需要 jaxlib > 0.3.24。添加了新函数
jax.lax.linalg.hessenberg()
,jax.lax.linalg.tridiagonal()
, 和jax.lax.linalg.householder_product()
。Householder 归约目前仅限 CPU,而三对角归约仅在 CPU 和 GPU 上受支持。对于非平方矩阵,
svd
和jax.numpy.linalg.pinv
的梯度现在计算得更经济。
重大更改
删除了
jax_experimental_name_stack
配置选项。将字符串
axis_names
参数转换为jax.experimental.maps.Mesh
构造函数,转换为单例元组,而不是将字符串解包为字符轴名称序列。
jaxlib 0.3.25 (2022 年 11 月 15 日)#
更改
添加了对 CPU 和 GPU 上的三对角归约的支持。
添加了对 CPU 上的上 Hessenberg 归约的支持。
错误修复
修复了一个错误,该错误意味着 JAX 捕获的回溯中的帧在 Python 3.10+ 下被错误地映射到源代码行
jax 0.3.24 (2022 年 11 月 4 日)#
更改
JAX 的导入速度应该更快。我们现在延迟导入 scipy,这占 JAX 导入时间的很大一部分。
设置环境变量
JAX_PERSISTENT_CACHE_MIN_COMPILE_TIME_SECS=$N
可用于限制写入持久缓存的缓存条目数量。默认情况下,编译时间为 1 秒或更长的计算将被缓存。如果未指定顺序,TPU 上
pmap
使用的默认设备顺序现在与单进程作业的jax.devices()
匹配。以前,这两个顺序不同,这可能导致不必要的副本或内存不足错误。要求顺序一致简化了问题。
重大更改
jax.numpy.gradient()
现在表现得更像jax.numpy
中的大多数其他函数,并且禁止传递列表或元组来代替数组 (#12958)在
jax.numpy.linalg
和jax.numpy.fft
中的函数现在统一要求输入为类数组(array-like):即列表和元组不能代替数组使用。这是 #7737 的一部分。
弃用
jax.sharding.MeshPspecSharding
已被重命名为jax.sharding.NamedSharding
。jax.sharding.MeshPspecSharding
名称将在 3 个月后移除。
jaxlib 0.3.24 (2022 年 11 月 4 日)#
更改
缓冲区捐赠现在在 CPU 上工作。这可能会破坏在 CPU 上标记缓冲区进行捐赠但依赖于捐赠未实现的的代码。
jax 0.3.23 (2022 年 10 月 12 日)#
更改
更新 Colab TPU 驱动版本以支持新的 jaxlib 版本。
jax 0.3.22 (2022 年 10 月 11 日)#
更改
在 TPU 初始化中添加
JAX_PLATFORMS=tpu,cpu
作为默认设置,以便在 TPU 无法初始化时 JAX 会引发错误,而不是回退到 CPU。设置JAX_PLATFORMS=''
以覆盖此行为并自动选择可用的后端(原始默认行为),或设置JAX_PLATFORMS=cpu
以始终使用 CPU,无论 TPU 是否可用。
弃用
在 JAX v0.3.8 中弃用的几个测试实用程序现在已从
jax.test_util
中移除。
jaxlib 0.3.22 (2022 年 10 月 11 日)#
jax 0.3.21 (2022 年 9 月 30 日)#
jax 0.3.20 (2022 年 9 月 28 日)#
jaxlib 0.3.20 (2022 年 9 月 28 日)#
jax 0.3.19 (2022 年 9 月 27 日)#
修复了所需的 jaxlib 版本。
jax 0.3.18 (2022 年 9 月 26 日)#
更改
提前 (Ahead-of-time) 降低和编译功能(在 #7733 中跟踪)是稳定且公开的。请参阅 概述 和
jax.stages
的 API 文档。引入了
jax.Array
,旨在用于 JAX 中数组类型的isinstance
检查和类型注解。请注意,这包括对jax.numpy.ndarray
的isinstance
工作方式的一些细微更改,因为jax.numpy.ndarray
现在是jax.Array
的简单别名。
重大更改
jax._src
不再导入到公共jax
命名空间中。这可能会破坏正在使用 JAX 内部组件的用户。jax.soft_pmap
已被删除。请改用pjit
或xmap
。jax.soft_pmap
没有文档记录。如果它有文档记录,则会提供弃用期。
jax 0.3.17 (2022 年 8 月 31 日)#
错误修复
修复了指数为零时
lax.pow
梯度的边角情况问题 (#12041)
重大更改
jax.checkpoint()
,也称为jax.remat()
,不再支持concrete
选项,遵循先前版本的弃用;请参阅 JEP 11830。
更改
添加了
jax.pure_callback()
,它允许从编译后的函数(例如,用jax.jit
或jax.pmap
装饰的函数)回调到纯 Python 函数。
弃用
已删除已弃用的
DeviceArray.tile()
方法。使用jax.numpy.tile()
(#11944)。DeviceArray.to_py()
已被弃用。请改用np.asarray(x)
。
jax 0.3.16#
重大更改
根据 弃用策略,已删除对 NumPy 1.19 的支持。请升级到 NumPy 1.20 或更高版本。
更改
添加了
jax.debug
,其中包含用于运行时值调试的实用程序,例如jax.debug.print()
和jax.debug.breakpoint()
。为 运行时值调试 添加了新的文档
弃用
jax.mask()
jax.shapecheck()
API 已被删除。请参阅 #11557。jax.experimental.loops
已被删除。有关替代 API,请参阅 #10278。jax.tree_util.tree_multimap()
已被删除。自 JAX 0.3.5 版本以来,它已被弃用,jax.tree_util.tree_map()
是直接替代品。移除了
jax.experimental.stax
;长期以来,它是jax.example_libraries.stax
的已弃用别名。移除了
jax.experimental.optimizers
;长期以来,它是jax.example_libraries.optimizers
的已弃用别名。jax.checkpoint()
,也称为jax.remat()
,默认情况下切换到新的实现,这意味着旧的实现已被弃用;请参阅 JEP 11830。
jax 0.3.15 (2022 年 7 月 22 日)#
更改
JaxTestCase
和JaxTestLoader
已从jax.test_util
中移除。这些类自 v0.3.1 以来已被弃用 (#11248)。添加了
jax.scipy.gaussian_kde
(#11237)。JAX 数组和内置集合 (
dict
、list
、set
、tuple
) 之间的二元运算现在在所有情况下都会引发TypeError
。以前,某些情况(特别是相等性和不等性)会返回与 NumPy 中类似操作不一致的布尔标量 (#11234)。作为顶级 JAX 包导入访问的几个
jax.tree_util
例程现在已被弃用,并将根据 API 兼容性 策略在未来的 JAX 版本中移除jax.treedef_is_leaf()
已被弃用,推荐使用jax.tree_util.treedef_is_leaf()
jax.tree_flatten()
已被弃用,推荐使用jax.tree_util.tree_flatten()
jax.tree_leaves()
已被弃用,推荐使用jax.tree_util.tree_leaves()
jax.tree_structure()
已被弃用,推荐使用jax.tree_util.tree_structure()
jax.tree_transpose()
已被弃用,推荐使用jax.tree_util.tree_transpose()
jax.tree_unflatten()
已被弃用,推荐使用jax.tree_util.tree_unflatten()
jax.scipy.linalg.solve()
的sym_pos
参数已被弃用,推荐使用assume_a='pos'
,遵循scipy.linalg.solve()
中的类似弃用。
jaxlib 0.3.15 (2022 年 7 月 22 日)#
jax 0.3.14 (2022 年 6 月 27 日)#
重大更改
jax.experimental.compilation_cache.initialize_cache()
不再支持max_cache_size_ bytes
,并且不会将其作为输入。当平台初始化失败时,
JAX_PLATFORMS
现在会引发异常。
更改
修复了与 NumPy 1.23 的兼容性问题。
jax.numpy.linalg.slogdet()
现在接受可选的method
参数,该参数允许在基于 LU 分解的实现和基于 QR 分解的实现之间进行选择。jax.numpy.linalg.qr()
现在支持mode="raw"
。pickle
、copy.copy
和copy.deepcopy
现在在使用 jax 数组时具有更完整的支持 (#10659)。特别是:pickle
和deepcopy
之前在DeviceArray
上使用时返回np.ndarray
对象;现在返回DeviceArray
对象。对于deepcopy
,复制的数组与原始数组位于同一设备上。对于pickle
,反序列化的数组将位于默认设备上。在函数转换(即跟踪代码)中,
deepcopy
和copy
以前是空操作 (no-ops)。现在它们使用与DeviceArray.copy()
相同的机制。在跟踪数组上调用
pickle
现在会导致显式的ConcretizationTypeError
。
奇异值分解 (SVD) 和对称/Hermitian 特征值分解的实现应该在 TPU 上显着加快速度,特别是对于 1000x1000 或更大的矩阵。两者现在都使用谱分治算法进行特征值分解 (QDWH-eig)。
jax.numpy.ldexp()
不再静默地将所有输入提升为 float64,而是对于 int32 或更小的整数输入,它会提升为 float32 (#10921)。为
jax.profiler.start_trace()
和jax.profiler.start_trace()
添加了create_perfetto_link
选项。使用后,分析器将生成 Perfetto UI 的链接以查看跟踪。更改了
jax.profiler.start_server(...)()
的语义,以全局存储 keepalive,而不是要求用户保留对其的引用。添加了
jax.random.ball()
。添加了
jax.default_device()
。添加了
python -m jax.collect_profile
脚本,以手动捕获程序跟踪,作为 TensorBoard UI 的替代方案。添加了
jax.named_scope
上下文管理器,它将分析器元数据添加到 Python 程序(类似于jax.named_call
)。在散点更新操作(即 :attr:
jax.numpy.ndarray.at
)中,不安全的隐式 dtype 转换已被弃用,现在会导致FutureWarning
。在未来的版本中,这将变成错误。不安全的隐式转换的一个示例是jnp.zeros(4, dtype=int).at[0].set(1.5)
,其中1.5
以前会被静默截断为1
。jax.experimental.compilation_cache.initialize_cache()
现在支持 gcs bucket 路径作为输入。jax.numpy.roots()
在系数具有前导零时,当strip_zeros=False
时,现在表现得更好 (#11215)。
jaxlib 0.3.14 (2022 年 6 月 27 日)#
-
x86-64 Mac wheels 现在需要 Mac OS 10.14 (Mojave) 或更高版本。Mac OS 10.14 于 2018 年发布,因此这应该不是一个非常苛刻的要求。
捆绑的 NCCL 版本已更新到 2.12.12,修复了一些死锁。
Python flatbuffers 包不再是 jaxlib 的依赖项。
jax 0.3.13 (2022 年 5 月 16 日)#
jax 0.3.12 (2022 年 5 月 15 日)#
jax 0.3.11 (2022 年 5 月 15 日)#
更改
jax.lax.eigh()
现在接受可选的sort_eigenvalues
参数,该参数允许用户选择退出 TPU 上的特征值排序。
弃用
jax.lax.linalg
中函数的非数组参数现在标记为仅关键字参数。作为向后兼容性步骤,以位置方式传递仅关键字参数会产生警告,但在未来的 JAX 版本中,以位置方式传递仅关键字参数将失败。但是,大多数用户应首选使用jax.numpy.linalg
。jax.scipy.linalg.polar_unitary()
是 scipy API 的 JAX 扩展,已被弃用。请改用jax.scipy.linalg.polar()
。
jax 0.3.10 (2022 年 5 月 3 日)#
jaxlib 0.3.10 (2022 年 5 月 3 日)#
jax 0.3.9 (2022 年 5 月 2 日)#
更改
为 GlobalDeviceArray 添加了对完全异步检查点 (checkpointing) 的支持。
jax 0.3.8 (2022 年 4 月 29 日)#
更改
TPU 上的
jax.numpy.linalg.svd()
使用 qdwh-svd 求解器。TPU 上的
jax.numpy.linalg.cond()
现在接受复数输入。TPU 上的
jax.numpy.linalg.pinv()
现在接受复数输入。TPU 上的
jax.numpy.linalg.matrix_rank()
现在接受复数输入。jax.experimental.maps.mesh
已被删除。请使用jax.experimental.maps.Mesh
。有关更多信息,请参阅 https://jax.net.cn/en/latest/_autosummary/jax.experimental.maps.Mesh.html#jax.experimental.maps.Mesh。jax.scipy.linalg.qr()
当mode='r'
时,现在返回长度为 1 的元组而不是原始数组,以便与scipy.linalg.qr
的行为匹配 (#10452)jax.numpy.take_along_axis()
现在接受可选的mode
参数,该参数指定越界索引的行为。默认情况下,对于越界索引将返回无效值(例如,NaN)。在以前版本的 JAX 中,无效索引被钳制到范围内。可以通过传递mode="clip"
来恢复以前的行为。jax.numpy.take()
现在默认为mode="fill"
,对于越界索引,它返回无效值(例如,NaN)。散点操作,例如
x.at[...].set(...)
,现在具有"drop"
语义。这对散点操作本身没有影响,但这意味着当微分时,散点的梯度将为越界索引产生零余切。以前,越界索引被钳制到范围内以进行梯度计算,这在数学上是不正确的。jax.numpy.take_along_axis()
如果其索引不是整数类型,现在会引发TypeError
,这与numpy.take_along_axis()
的行为匹配。以前,非整数索引会被静默转换为整数。jax.numpy.ravel_multi_index()
如果其dims
参数不是整数类型,现在会引发TypeError
,这与numpy.ravel_multi_index()
的行为匹配。以前,非整数dims
会被静默转换为整数。jax.numpy.split()
现在会在其axis
参数不是整数类型时引发TypeError
,与numpy.split()
的行为保持一致。之前,非整数类型的axis
参数会被静默地转换为整数。jax.numpy.indices()
现在会在其维度不是整数类型时引发TypeError
,与numpy.indices()
的行为保持一致。之前,非整数类型的维度会被静默地转换为整数。jax.numpy.diag()
现在会在其k
参数不是整数类型时引发TypeError
,与numpy.diag()
的行为保持一致。之前,非整数类型的k
参数会被静默地转换为整数。
弃用
jax.test_util
中提供的许多函数和对象现在已被弃用,并在导入时会引发警告。这包括cases_from_list
、check_close
、check_eq
、device_under_test
、format_shape_dtype_string
、rand_uniform
、skip_on_devices
、with_config
、xla_bridge
和_default_tolerance
(#10389)。这些以及之前已弃用的JaxTestCase
、JaxTestLoader
和BufferDonationTestCase
将在未来的 JAX 版本中移除。这些实用程序中的大多数可以被调用标准 Python 和 NumPy 测试实用程序来替代,例如在unittest
、absl.testing
、numpy.testing
等中找到的实用程序。JAX 特定的功能,例如设备检查,可以通过使用公共 API(例如jax.devices()
)来替代。许多已弃用的实用程序仍然存在于jax._src.test_util
中,但这些不是公共 API,因此可能会在未来的版本中更改或移除,恕不另行通知。
jax 0.3.7 (2022年4月15日)#
更改
修复了当传递给
jax.numpy.take_along_axis()
的索引被广播时出现的性能问题 (#10281)。jax.scipy.special.expit()
和jax.scipy.special.logit()
现在要求它们的参数为标量或 JAX 数组。它们现在还会将整数参数提升为浮点数。DeviceArray.tile()
方法已被弃用,因为 NumPy 数组没有tile()
方法。作为替代方案,请使用jax.numpy.tile()
(#10266)。
jaxlib 0.3.7 (2022年4月15日)#
更改
Linux wheels 现在按照
manylinux2014
标准构建,而不是manylinux2010
。
jax 0.3.6 (2022年4月12日)#
jax 0.3.5 (2022年4月7日)#
更改
添加了
jax.random.loggamma()
,并改进了jax.random.beta()
和jax.random.dirichlet()
在小参数值下的行为 (#9906)。私有的
lax_numpy
子模块不再在jax.numpy
命名空间中公开 (#10029)。添加了数组创建例程
jax.numpy.frombuffer()
、jax.numpy.fromfunction()
和jax.numpy.fromstring()
(#10049)。DeviceArray.copy()
现在返回DeviceArray
而不是np.ndarray
(#10069)jax.experimental.sharded_jit
已被弃用,并将很快移除。
弃用
jax.nn.normalize()
正在被弃用。请改用jax.nn.standardize()
(#9899)。jax.tree_util.tree_multimap()
已被弃用。请改用jax.tree_util.tree_map()
(#5746)。jax.experimental.sharded_jit
已被弃用。请改用pjit
。
jaxlib 0.3.5 (2022年4月7日)#
jax 0.3.4 (2022年3月18日)#
jax 0.3.3 (2022年3月17日)#
jax 0.3.2 (2022年3月16日)#
更改
函数
jax.ops.index_update
,jax.ops.index_add
(在 0.2.22 版本中已弃用) 已被移除。请使用 JAX 数组上的.at
属性 来代替,例如x.at[idx].set(y)
。将
jax.experimental.ann.approx_*_k
移动到jax.lax
中。这些函数是jax.lax.top_k
的优化替代方案。jax.numpy.broadcast_arrays()
和jax.numpy.broadcast_to()
现在需要标量或类数组输入,如果传递列表将失败(#7737 的一部分)。标准的 jax[tpu] 安装现在可以与 Cloud TPU v4 VM 一起使用。
pjit
现在可以在 CPU 上工作(除了之前的 TPU 和 GPU 支持)。
jaxlib 0.3.2 (2022年3月16日)#
更改
XlaComputation.as_hlo_text()
现在支持通过传递布尔标志print_large_constants=True
来打印大型常量。
弃用
JAX 数组上的
.block_host_until_ready()
方法已被弃用。请改用.block_until_ready()
。
jax 0.3.1 (2022年2月18日)#
更改
jax.test_util.JaxTestCase
和jax.test_util.JaxTestLoader
现在已被弃用。建议的替代方案是直接使用parametrized.TestCase
。对于依赖于自定义断言(例如JaxTestCase.assertAllClose()
)的测试,建议的替代方案是使用标准的 NumPy 测试实用程序(例如numpy.testing.assert_allclose()
),它们可以直接与 JAX 数组一起使用 (#9620)。jax.test_util.JaxTestCase
现在默认设置jax_numpy_rank_promotion='raise'
(#9562)。要恢复之前的行为,请使用新的jax.test_util.with_config
装饰器@jtu.with_config(jax_numpy_rank_promotion='allow') class MyTestCase(jtu.JaxTestCase): ...
添加了
jax.scipy.linalg.schur()
,jax.scipy.linalg.sqrtm()
,jax.scipy.signal.csd()
,jax.scipy.signal.stft()
,jax.scipy.signal.welch()
。
jax 0.3.0 (2022年2月10日)#
jaxlib 0.3.0 (2022年2月10日)#
更改
构建 jaxlib 现在需要 Bazel 5.0.0。
jaxlib 版本已提升至 0.3.0。请参阅 设计文档 以了解详细说明。
jax 0.2.28 (2022年2月1日)#
-
jax.jit(f).lower(...).compiler_ir()
现在默认使用 MHLO dialect,如果没有传递dialect=
参数。jax.jit(f).lower(...).compiler_ir(dialect='mhlo')
现在返回 MLIRir.Module
对象,而不是其字符串表示形式。
jaxlib 0.1.76 (2022年1月27日)#
新功能
包含用于 NVidia 计算能力 8.0 GPU (例如 A100) 的预编译 SASS。移除了用于计算能力 6.1 的预编译 SASS,以避免增加计算能力的数量:计算能力为 6.1 的 GPU 可以使用 6.0 SASS。
使用 jaxlib 0.1.76,JAX 默认使用 MHLO MLIR dialect 作为其主要目标编译器 IR。
重大更改
根据 弃用政策,已停止支持 NumPy 1.18。请升级到受支持的 NumPy 版本。
错误修复
修复了由不同路径构造的看似相同的 pytreedef 对象无法比较相等的错误 (#9066)。
JAX jit 缓存要求两个静态参数具有相同的类型才能命中缓存 (#9311)。
jax 0.2.27 (2022年1月18日)#
重大更改
根据 弃用政策,已停止支持 NumPy 1.18。请升级到受支持的 NumPy 版本。
简化了 host_callback 原语,取消了对 hcb.id_tap 和 id_print 的特殊自动微分处理。从现在开始,只 tap primals。可以通过设置
JAX_HOST_CALLBACK_AD_TRANSFORMS
环境变量或--jax_host_callback_ad_transforms
标志来获得旧的行为(在有限的时间内)。此外,还添加了关于如何使用 JAX 自定义 AD API 实现旧行为的文档 (#8678)。排序现在匹配 NumPy 对于
0.0
和NaN
的行为,无论位表示形式如何。特别是,0.0
和-0.0
现在被视为等价,而之前-0.0
被视为小于0.0
。此外,所有NaN
表示形式现在都被视为等价,并排序到数组的末尾。之前,负NaN
值被排序到数组的前面,并且具有不同内部位表示形式的NaN
值不被视为等价,并根据这些位模式进行排序 (#9178)。jax.numpy.unique()
现在以与 NumPy 1.21 及更高版本中的np.unique
相同的方式处理NaN
值:在唯一化输出中最多出现一个NaN
值 (#9184)。
错误修复
host_callback 现在支持 ad_checkpoint.checkpoint (#8907)。
新功能
添加
jax.block_until_ready
({jax-issue}`#8941)添加了一个新的调试标志/环境变量
JAX_DUMP_IR_TO=/path
。如果设置,JAX 会将其为每个计算生成的 MHLO/HLO IR 转储到给定路径下的文件中。将
jax.ensure_compile_time_eval
添加到公共 API 中 (#7987)。jax2tf 现在支持标志 jax2tf_associative_scan_reductions,以更改关联归约(例如 jnp.cumsum)的降低行为,使其在 CPU 和 GPU 上表现得像 JAX(使用关联扫描)。有关更多详细信息,请参阅 jax2tf README (#9189)。
jaxlib 0.1.75 (2021年12月8日)#
新功能
支持 Python 3.10。
jax 0.2.26 (2021年12月8日)#
jaxlib 0.1.74 (2021年11月17日)#
启用了 GPU 之间的对等复制。以前,GPU 复制通过主机 bounced,这通常较慢。
为 JAX 使用添加了实验性的 MLIR Python 绑定。
jax 0.2.25 (2021年11月10日)#
jax 0.2.24 (2021年10月19日)#
jaxlib 0.1.73 (2021年10月18日)#
jaxlib GPU
cuda11
wheels 现在支持多个 cuDNN 版本。cuDNN 8.2 或更高版本。如果您的 cuDNN 安装足够新,我们建议使用 cuDNN 8.2 wheel,因为它支持额外的功能。
cuDNN 8.0.5 或更高版本。
重大更改
GPU jaxlib 的安装命令如下
pip install --upgrade pip # Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer. pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html # Installs the wheel compatible with Cuda 11 and cudnn 8.2 or newer. pip install jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_releases.html # Installs the wheel compatible with Cuda 11 and cudnn 8.0.5 or newer. pip install jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_releases.html
jax 0.2.22 (2021年10月12日)#
重大更改
现在,
jax.pmap
的静态参数必须是可哈希的。不可哈希的静态参数长期以来在
jax.jit
上是被禁止的,但它们仍然在jax.pmap
上被允许;jax.pmap
使用对象标识比较不可哈希的静态参数。这种行为是一个潜在的陷阱,因为使用对象标识比较参数会导致每次对象标识更改时都重新编译。相反,我们现在禁止不可哈希的参数:如果
jax.pmap
的用户想要通过对象标识比较静态参数,他们可以在他们的对象上定义__hash__
和__eq__
方法来做到这一点,或者将他们的对象包装在一个具有这些操作和对象标识语义的对象中。另一种选择是使用functools.partial
将不可哈希的静态参数封装到函数对象中。jax.util.partial
是一个意外的导出,现在已被移除。请改用 Python 标准库中的functools.partial
。
弃用
函数
jax.ops.index_update
,jax.ops.index_add
等已被弃用,并将在未来的 JAX 版本中移除。请使用 JAX 数组上的.at
属性 来代替,例如x.at[idx].set(y)
。目前,这些函数会产生DeprecationWarning
。
新功能
当使用 jaxlib 0.1.72 或更高版本时,改进
pmap
调度时间的优化 C++ 代码路径现在是默认设置。可以使用--experimental_cpp_pmap
标志(或JAX_CPP_PMAP
环境变量)禁用此功能。jax.numpy.unique
现在支持可选的fill_value
参数 (#8121)
jaxlib 0.1.72 (2021年10月12日)#
重大更改
已停止支持 CUDA 10.2 和 CUDA 10.1。Jaxlib 现在支持 CUDA 11.1+。
错误修复
修复了 https://github.com/jax-ml/jax/issues/7461,该问题由于 XLA 编译器内部不正确的缓冲区别名导致所有平台上的输出错误。
jax 0.2.21 (2021年9月23日)#
重大更改
jax.api
已被移除。作为jax.api.*
提供的函数是jax.*
中函数的别名;请改用jax.*
中的函数。jax.partial
和jax.lax.partial
是意外的导出,现在已被移除。请改用 Python 标准库中的functools.partial
。布尔标量索引现在会引发
TypeError
;以前这会静默地返回错误的结果 (#7925)。更多
jax.numpy
函数现在需要类数组输入,如果传递列表将报错 (#7747 #7802 #7907)。有关此更改背后的原理的讨论,请参阅 #7737。当在
jax.jit
等转换内部时,jax.numpy.array
始终将其生成的数组 staged 到 traced 计算中。以前,即使在jax.jit
装饰器下,jax.numpy.array
有时也会生成 on-device 数组。此更改可能会破坏使用 JAX 数组执行形状或索引计算的代码,这些计算必须是静态已知的;解决方法是使用经典的 NumPy 数组来执行此类计算。jnp.ndarray
现在是 JAX 数组的真正基类。特别是,这意味着对于标准 numpy 数组x
,isinstance(x, jnp.ndarray)
现在将返回False
(#7927)。
新功能
添加了
jax.numpy.insert()
实现 (#7936)。
jax 0.2.20 (2021年9月2日)#
jaxlib 0.1.71 (2021年9月1日)#
重大更改
已停止支持 CUDA 11.0 和 CUDA 10.1。Jaxlib 现在支持 CUDA 10.2 和 CUDA 11.1+。
jax 0.2.19 (2021年8月12日)#
重大更改
根据 弃用政策,已停止支持 NumPy 1.17。请升级到受支持的 NumPy 版本。
在 JAX 数组上的许多运算符的实现周围添加了
jit
装饰器。这加快了常见运算符(例如+
)的调度时间。此更改对大多数用户来说应该是基本透明的。但是,有一个已知的行为更改,即当大型整数常量直接传递给 JAX 运算符时(例如,
x + 2**40
),现在可能会产生错误。解决方法是将常量强制转换为显式类型(例如,np.float64(2**40)
)。
新功能
改进了 jax2tf 中对于需要在数组计算中使用维度大小的操作(例如
jnp.mean
)的形状多态性的支持。 (#7317)
错误修复
修复了之前版本中的一些泄漏 trace 错误 (#7613)
jaxlib 0.1.70 (2021年8月9日)#
jax 0.2.18 (2021年7月21日)#
重大更改
根据 弃用政策,已停止支持 Python 3.6。请升级到受支持的 Python 版本。
最低 jaxlib 版本现在为 0.1.69。
已移除
jax.dlpack.from_dlpack()
的backend
参数。
新功能
添加了极分解 (
jax.scipy.linalg.polar()
)。
错误修复
收紧了对 lax.argmin 和 lax.argmax 的检查,以确保它们不会与无效的
axis
值或空归约维度一起使用。 (#7196)
jaxlib 0.1.69 (2021年7月9日)#
修复了 TFRT CPU 后端中导致结果不正确的错误。
jax 0.2.17 (2021年7月9日)#
错误修复
对于 jaxlib <= 0.1.68,默认使用较旧的 “stream_executor” CPU 运行时,以解决 #7229,该问题由于并发问题导致 CPU 上的输出错误。
新功能
新的 SciPy 函数
jax.scipy.special.sph_harm()
。反向模式自动微分函数 (
jax.grad()
,jax.value_and_grad()
,jax.vjp()
, 和jax.linear_transpose()
) 支持一个参数,该参数指示在反向传播中应求和哪些命名轴(如果在正向传播中对它们进行了广播)。这使得这些 API 可以在映射内部以非逐示例的方式使用(最初仅限jax.experimental.maps.xmap()
)(#6950)。
jax 0.2.16 (June 23 2021)#
jax 0.2.15 (June 23 2021)#
jaxlib 0.1.68 (June 23 2021)#
错误修复
修复了 TFRT CPU 后端中的一个错误,该错误会在将 TPU 缓冲区传输到 CPU 时导致 nans。
jax 0.2.14 (June 10 2021)#
新功能
jax2tf.convert()
现在支持pjit
和sharded_jit
。新的配置选项 JAX_TRACEBACK_FILTERING 控制 JAX 如何过滤回溯。
现在默认在足够新的 IPython 版本中启用使用
__tracebackhide__
的新回溯过滤模式。jax2tf.convert()
即使在算术运算中使用未知维度,也支持形状多态性,例如,jnp.reshape(-1)
(#6827)。jax2tf.convert()
在 TF ops 中生成带有位置信息的自定义属性。在 jax2tf 之后,XLA 生成的代码与 JAX/XLA 具有相同的位置信息。新的 SciPy 函数
jax.scipy.special.lpmn()
。
错误修复
jaxlib 0.1.67 (May 17 2021)#
jaxlib 0.1.66 (May 11 2021)#
新功能
CUDA 11.1 wheels 现在在所有 CUDA 11 版本 11.1 或更高版本上都受支持。
NVidia 现在承诺 CUDA 次要版本之间从 CUDA 11.1 开始兼容。这意味着 JAX 可以发布一个与 CUDA 11.2 和 11.3 兼容的 CUDA 11.1 wheel。
不再有针对 CUDA 11.2(或更高版本)的单独 jaxlib 版本;对于这些版本,请使用 CUDA 11.1 wheel (cuda111)。
Jaxlib 现在在 CUDA wheels 中捆绑了
libdevice.10.bc
。应该不再需要将 JAX 指向 CUDA 安装以查找此文件。为
jit()
实现添加了对静态关键字参数的自动支持。添加了对预转换异常跟踪的支持。
初步支持从
jit()
转换后的计算中修剪未使用的参数。修剪仍在进行中。改进了
PyTreeDef
对象的字符串表示。添加了对 XLA 的 variadic ReduceWindow 的支持。
错误修复
修复了当大量参数传递给计算时,远程云 TPU 支持中的一个错误。
修复了一个错误,该错误意味着 JAX 垃圾回收不会被
jit()
转换后的函数触发。
jax 0.2.13 (May 3 2021)#
新功能
当与 jaxlib 0.1.66 结合使用时,
jax.jit()
现在支持静态关键字参数。添加了一个新的static_argnames
选项来将关键字参数指定为静态参数。jax.nonzero()
有一个新的可选size
参数,允许它在jit
中使用 (#6501)jax.numpy.unique()
现在支持axis
参数 (#6532)。jax.experimental.host_callback.call()
现在支持pjit.pjit
(#6569)。添加了
jax.scipy.linalg.eigh_tridiagonal()
,用于计算三对角矩阵的特征值。目前仅支持特征值。异常中过滤和未过滤的堆栈跟踪的顺序已更改。附加到从 JAX 转换的代码抛出的异常的回溯现在被过滤,其中
UnfilteredStackTrace
异常包含原始跟踪作为过滤异常的__cause__
。过滤后的堆栈跟踪现在也适用于 Python 3.6。如果由反向模式自动微分转换的代码抛出异常,JAX 现在尝试将
JaxStackTraceBeforeTransformation
对象作为异常的__cause__
附加,该对象包含在前向传递中创建原始操作的堆栈跟踪。需要 jaxlib 0.1.66。
重大更改
以下函数名称已更改。仍然存在别名,因此这不应破坏现有代码,但别名最终将被删除,因此请更改您的代码。
host_id
–>process_index()
host_count
–>process_count()
host_ids
–>range(jax.process_count())
类似地,
local_devices()
的参数已从host_id
重命名为process_index
。除函数之外的
jax.jit()
的参数现在标记为仅关键字参数。此更改是为了防止在向jit
添加参数时意外中断。
错误修复
jaxlib 0.1.65 (April 7 2021)#
jax 0.2.12 (April 1 2021)#
新功能
新的性能分析 API:
jax.profiler.start_trace()
、jax.profiler.stop_trace()
和jax.profiler.trace()
jax.lax.reduce()
现在是可微分的。
重大更改
最低 jaxlib 版本现在为 0.1.64。
一些性能分析器 API 名称已更改。仍然存在别名,因此这不应破坏现有代码,但别名最终将被删除,因此请更改您的代码。
TraceContext
–>TraceAnnotation()
StepTraceContext
–>StepTraceAnnotation()
trace_function
–>annotate_function()
Omnistaging 不再可以禁用。有关更多信息,请参阅 omnistaging。
大于最大
int64
值的 Python 整数现在在所有情况下都会导致溢出,而不是在某些情况下静默转换为uint64
(#6047)。在 X64 模式之外,超出
int32
可表示范围的 Python 整数现在将导致OverflowError
,而不是静默截断其值。
错误修复
host_callback
现在支持参数和结果中的空数组 (#6262)。jax.random.randint()
裁剪而不是环绕超出范围的限制,现在可以生成指定 dtype 的完整范围内的整数 (#5868)
jax 0.2.11 (March 23 2021)#
新功能
错误修复
重大更改
最低 jaxlib 版本现在为 0.1.62。
jaxlib 0.1.64 (March 18 2021)#
jaxlib 0.1.63 (March 17 2021)#
jax 0.2.10 (March 5 2021)#
新功能
jax.scipy.stats.chi2()
现在作为具有 logpdf 和 pdf 方法的分布提供。jax.scipy.stats.betabinom()
现在作为具有 logpmf 和 pmf 方法的分布提供。添加了
jax.experimental.jax2tf.call_tf()
以从 JAX 调用 TensorFlow 函数 (#5627) 和 README)。扩展了
lax.pad
的批处理规则,以支持填充值的批处理。
错误修复
jax.numpy.take()
正确处理负索引 (#5768)
重大更改
调整了 JAX 的类型提升规则,以使类型提升更一致且对 JIT 不变。特别是,在适当的情况下,二元运算现在可以产生弱类型值。更改的主要用户可见效果是,某些操作产生的结果精度与以前不同;例如,表达式
jnp.bfloat16(1) + 0.1 * jnp.arange(10)
之前返回float64
数组,现在返回bfloat16
数组。JAX 的类型提升行为在 类型提升语义 中描述。jax.numpy.linspace()
现在计算整数值的下限,即向 -inf 而不是 0 舍入。此更改是为了与 NumPy 1.20.0 匹配。jax.numpy.i0()
不再接受复数。以前,该函数计算复数参数的绝对值。此更改是为了与 NumPy 1.20.0 的语义匹配。几个
jax.numpy
函数不再接受元组或列表来代替数组参数:jax.numpy.pad()
、:funcjax.numpy.ravel
、jax.numpy.repeat()
、jax.numpy.reshape()
。通常,jax.numpy
函数应与标量或数组参数一起使用。
jaxlib 0.1.62 (March 9 2021)#
新功能
默认情况下,现在构建 jaxlib wheels 以要求 x86-64 机器上的 AVX 指令。如果您想在不支持 AVX 的机器上使用 JAX,则可以使用
--target_cpu_features
标志使用build.py
从源代码构建 jaxlib。--target_cpu_features
也取代了--enable_march_native
。
jaxlib 0.1.61 (February 12 2021)#
jaxlib 0.1.60 (February 3 2021)#
错误修复
修复了将 CPU DeviceArrays 转换为 NumPy 数组时的内存泄漏。内存泄漏存在于 jaxlib 版本 0.1.58 和 0.1.59 中。
bool
、int8
和uint8
现在被认为可以安全地转换为bfloat16
NumPy 扩展类型。
jax 0.2.9 (January 26 2021)#
新功能
扩展
jax.experimental.loops
模块,以支持 pytrees。改进了错误检查和错误消息。添加
jax.experimental.enable_x64()
和jax.experimental.disable_x64()
。这些是上下文管理器,允许在会话中临时启用/禁用 X64 模式。
重大更改
jax.ops.segment_sum()
现在删除超出范围的段 ID,而不是将它们包装到段 ID 空间中。这样做是为了性能原因。
jaxlib 0.1.59 (January 15 2021)#
jax 0.2.8 (January 12 2021)#
新功能
添加
jax.closure_convert()
以用于高阶自定义导数函数。 (#5244)添加
jax.experimental.host_callback.call()
以在主机上调用自定义 Python 函数,并将结果返回到设备计算。 (#5243)
错误修复
重大更改
jax.numpy.pad
现在接受关键字参数。位置参数constant_values
已被删除。此外,传递不受支持的关键字参数会引发错误。针对
jax.experimental.host_callback.id_tap()
的更改 (#5243)删除了对
jax.experimental.host_callback.id_tap()
的kwargs
的支持。(此支持已弃用几个月。)更改了
jax.experimental.host_callback.id_print()
的元组打印,以使用“(”而不是“[“。更改了存在 JVP 时
jax.experimental.host_callback.id_print()
以打印原始值和切线的对。以前,原始值和切线有两个单独的打印操作。host_callback.outfeed_receiver
已被删除(它不是必需的,并且在几个月前已弃用)。
新功能
用于调试
inf
的新标志,类似于NaN
的标志 (#5224)。
jax 0.2.7 (Dec 4 2020)#
新功能
添加
jax.device_put_replicated
为
jax.experimental.sharded_jit
添加多主机支持为微分
jax.numpy.linalg.eig
计算的特征值添加支持为在 Windows 平台上构建添加支持
在
jax.pmap
中为通用 in_axes 和 out_axes 添加支持为
jax.numpy.linalg.slogdet
添加复数支持
错误修复
修复了零点处
jax.numpy.sinc
的高于二阶的导数修复了转置规则中围绕符号零的一些难以命中的错误
重大更改
jax.experimental.optix
已被删除,取而代之的是独立的optax
Python 包。使用非元组序列索引 JAX 数组现在会引发
TypeError
。自 v1.16 起,Numpy 中已弃用此类型的索引,自 v0.2.4 起,JAX 中也已弃用。请参阅 #4564。
jax 0.2.6 (Nov 18 2020)#
新功能
为 jax.experimental.jax2tf 转换器添加了形状多态跟踪的支持。请参阅 README.md。
重大更改清理
在 jax.jit 和 xla_computation 的不可哈希静态参数上引发错误。请参阅 cb48f42。
提高类型提升行为的一致性 (#4744)
将复数 Python 标量添加到 JAX 浮点数时,会尊重 JAX 浮点数的精度。例如,
jnp.float32(1) + 1j
现在返回complex64
,而之前返回complex128
。涉及 uint64、有符号 int 和第三种类型的 3 个或更多项的类型提升结果现在独立于参数的顺序。例如:
jnp.result_type(jnp.uint64, jnp.int64, jnp.float16)
和jnp.result_type(jnp.float16, jnp.uint64, jnp.int64)
都返回float16
,而之前第一个返回float64
,第二个返回float16
。
(未记录的)
jax.lax_linalg
线性代数模块的内容现在公开为jax.lax.linalg
。jax.random.PRNGKey
现在在 JIT 编译内外产生相同的结果 (#4877)。这需要在一些特定情况下更改给定种子的结果使用
jax_enable_x64=False
,作为 Python 整数传递的负种子现在在 JIT 模式之外返回不同的结果。例如,jax.random.PRNGKey(-1)
之前返回[4294967295, 4294967295]
,现在返回[0, 4294967295]
。这与 JIT 中的行为相匹配。超出
int64
可表示范围的种子在 JIT 之外现在会导致OverflowError
而不是TypeError
。这与 JIT 中的行为相匹配。
要恢复以前为负整数返回的键,其中
jax_enable_x64=False
在 JIT 之外,您可以使用key = random.PRNGKey(-1).at[0].set(0xFFFFFFFF)
DeviceArray 现在在尝试访问其值但在已被删除时,引发
RuntimeError
而不是ValueError
。
jaxlib 0.1.58 (January 12ish 2021)#
修复了一个错误,该错误意味着 JAX 有时返回平台特定的类型(例如,
np.cint
)而不是标准类型(例如,np.int32
)。 (#4903)修复了常量折叠某些 int16 操作时的崩溃。 (#4971)
为
pytree.flatten()
添加了is_leaf
谓词。
jaxlib 0.1.57 (November 12 2020)#
修复了 GPU wheels 中的 manylinux2010 合规性问题。
将 CPU FFT 实现从 Eigen 切换到 PocketFFT。
修复了 bfloat16 值的哈希未正确初始化且可能更改的错误 (#4651)。
添加了在将数组传递给 DLPack 时保留所有权的支持 (#4636)。
修复了批处理三角解的错误,其大小大于 128 但不是 128 的倍数。
修复了在多个 GPU 上执行并发 FFT 的错误 (#3518)。
修复了 profiler 中工具丢失的错误 (#4427)。
放弃了对 CUDA 10.0 的支持。
jax 0.2.5 (October 27 2020)#
改进
确保
check_jaxpr
不执行 FLOPS。请参阅 #4650。扩展了 jax2tf 转换的 JAX primitives 集。请参阅 primitives_with_limited_support.md。
jax 0.2.4 (October 19 2020)#
jaxlib 0.1.56 (October 14, 2020)#
jax 0.2.3 (October 14 2020)#
之所以这么快发布另一个版本,是因为我们需要暂时回滚一个新的 jit fastpath,同时我们研究性能下降问题
jax 0.2.2 (October 13 2020)#
jax 0.2.1 (October 6 2020)#
改进
作为 omnistaging 的好处,即使
jax.experimental.host_callback.id_print()
/jax.experimental.host_callback.id_tap()
的结果未在计算中使用,host_callback 函数也会按程序顺序执行。
jax (0.2.0) (September 23 2020)#
改进
默认启用 Omnistaging。请参阅 #3370 和 omnistaging
jax (0.1.77) (September 15 2020)#
重大更改
jax.experimental.host_callback.id_tap()
的新的简化接口 (#4101)
jaxlib 0.1.55 (September 8, 2020)#
更新 XLA
修复 DLPackManagedTensorToBuffer 中的错误 (#4196)
jax 0.1.76 (September 8, 2020)#
jax 0.1.75 (July 30, 2020)#
Bug 修复
使 jnp.abs() 适用于无符号输入 (#3914)
改进
在标志后添加了“Omnistaging”行为,默认情况下禁用 (#3370)
jax 0.1.74 (July 29, 2020)#
新功能
BFGS (#3101)
TPU 支持半精度算术 (#3878)
Bug 修复
防止一些意外的 dtype 警告 (#3874)
修复了自定义导数中的多线程错误 (#3845, #3869)
改进
更快的 searchsorted 实现 (#3873)
更好的 jax.numpy 排序算法的测试覆盖率 (#3836)
jaxlib 0.1.52 (July 22, 2020)#
更新 XLA。
jax 0.1.73 (July 22, 2020)#
最低 jaxlib 版本现在为 0.1.51。
新功能
jax.image.resize. (#3703)
hfft 和 ihfft (#3664)
jax.numpy.intersect1d (#3726)
jax.numpy.lexsort (#3812)
lax.scan
和scan
primitive 支持unroll
参数,用于在降低到 XLA 时进行循环展开 (#3738)。
Bug 修复
修复了 reduction 重复轴错误 (#3618)
修复了 lax.pad 的形状规则,用于大小为 0 的输入维度。 (#3608)
使 psum 转置处理零余切 (#3653)
修复了在对大小为 0 轴进行 reduce-prod 的 JVP 时出现的形状错误。 (#3729)
支持通过 jax.lax.all_to_all 进行微分 (#3733)
解决 jax.scipy.special.zeta 中的 nan 问题 (#3777)
改进
jax2tf 的多项改进
使用单次变参归约重新实现 argmin/argmax。 (#3611)
默认启用 XLA SPMD 分区。 (#3151)
添加对 0 维转置卷积的支持 (#3643)
使 LU 梯度适用于低秩矩阵 (#3610)
在 jet 中支持 multiple_results 和自定义 JVP (#3657)
推广 reduce-window 填充以支持 (lo, hi) 对。 (#3728)
在 CPU 和 GPU 上实现复数卷积。 (#3735)
使 jnp.take 适用于空数组的空切片。 (#3751)
放宽 dot_general 的维度顺序规则。 (#3778)
为 GPU 启用缓冲区捐赠。 (#3800)
为 reduce window op 添加对基础扩张和窗口扩张的支持… (#3803)
jaxlib 0.1.51 (2020 年 7 月 2 日)#
更新 XLA。
为 host_callback 添加新的运行时支持。
jax 0.1.72 (2020 年 6 月 28 日)#
jax 0.1.71 (2020 年 6 月 25 日)#
jaxlib 0.1.50 (2020 年 6 月 25 日)#
添加对 CUDA 11.0 的支持。
放弃对 CUDA 9.2 的支持(我们仅维护对最近四个 CUDA 版本的支持。)
更新 XLA。
jaxlib 0.1.49 (2020 年 6 月 19 日)#
错误修复
修复可能导致编译缓慢的构建问题 (tensorflow/tensorflow)
jaxlib 0.1.48 (2020 年 6 月 12 日)#
新功能
添加对快速回溯收集的支持。
添加对设备上堆分析的初步支持。
为
bfloat16
类型实现np.nextafter
。CPU 和 GPU 上 FFT 的 Complex128 支持。
错误修复
改进了 GPU 上 float64
tanh
的精度。GPU 上的 float64 散点操作更快。
CPU 上的复数矩阵乘法应该更快。
CPU 上的稳定排序现在应该真正稳定了。
CPU 后端中的并发错误修复。
jax 0.1.70 (2020 年 6 月 8 日)#
jax 0.1.69 (2020 年 6 月 3 日)#
jax 0.1.68 (2020 年 5 月 21 日)#
jax 0.1.67 (2020 年 5 月 12 日)#
新功能
支持使用
axis_index_groups
对 pmapped 轴的子集进行归约 #2382。从编译代码中打印和调用主机端 Python 函数的实验性支持。请参阅 id_print 和 id_tap (#3006)。
值得注意的更改
已收紧从
jax.numpy
导出的名称的可见性。这可能会破坏以前意外使用导出的名称的代码。
jaxlib 0.1.47 (2020 年 5 月 8 日)#
修复了出队时的崩溃。
jax 0.1.66 (2020 年 5 月 5 日)#
jaxlib 0.1.46 (2020 年 5 月 5 日)#
修复了 Mac OS X 上线性代数函数的崩溃问题 (#432)。
修复了当操作系统或虚拟机监控程序禁用 AVX512 指令时,使用 AVX512 指令导致的非法指令崩溃 (#2906)。
jax 0.1.65 (2020 年 4 月 30 日)#
jaxlib 0.1.45 (2020 年 4 月 21 日)#
修复了段错误:#2755
将 Sort HLO 上的 is_stable 选项引入到 Python。
jax 0.1.64 (2020 年 4 月 21 日)#
新功能
为函数式索引更新添加语法糖 #2684。
为
jax.experimental.jet()
添加更多原始规则。
错误修复
更好的错误信息
改进了
lax.while_loop()
的反向模式微分的错误消息 #2129。
jaxlib 0.1.44 (2020 年 4 月 16 日)#
修复了当存在多个不同型号的 GPU 时,JAX 只会编译适用于第一个 GPU 的程序的问题。
修复了
batch_group_count
卷积的错误。为更多 GPU 版本添加了预编译的 SASS,以避免启动 PTX 编译挂起。
jax 0.1.63 (2020 年 4 月 12 日)#
从 #2026 添加了
jax.custom_jvp
和jax.custom_vjp
,请参阅 教程笔记本。弃用了jax.custom_transforms
并从文档中删除(尽管它仍然有效)。添加
scipy.sparse.linalg.cg
#2566。更改了 Tracers 的打印方式,以显示更多有用的调试信息 #2591。
使
jax.numpy.isclose
正确处理nan
和inf
#2501。为
jax.experimental.jet
添加了几个新规则 #2537。修复了未提供
scale
/center
时的jax.experimental.stax.BatchNorm
。修复了
jax.numpy.einsum
中一些缺失的广播情况 #2512。根据并行前缀扫描实现
jax.numpy.cumsum
和jax.numpy.cumprod
#2596,并使reduce_prod
可微分至任意阶 #2597。将
batch_group_count
添加到conv_general_dilated
#2635。为
test_util.check_grads
添加文档字符串 #2656。添加
callback_transform
#2665。实现
rollaxis
、convolve
/correlate
1d & 2d、copysign
、trunc
、roots
和quantile
/percentile
插值选项。
jaxlib 0.1.43 (2020 年 3 月 31 日)#
修复了 GPU 上 Resnet-50 的性能回归问题。
jax 0.1.62 (2020 年 3 月 21 日)#
JAX 已停止支持 Python 3.5。请升级到 Python 3.6 或更高版本。
删除了内部函数
lax._safe_mul
,该函数实现了约定0. * nan == 0.
。此更改意味着某些程序在微分时会产生 nan,而以前会产生正确的值,尽管它可以确保为其他程序生成 nan 而不是静默地生成不正确的结果。有关详细信息,请参阅 #2447 和 #1052。添加了
all_gather
并行便利函数。核心代码中添加了更多类型注释。
jaxlib 0.1.42 (2020 年 3 月 19 日)#
jaxlib 0.1.41 由于 API 不兼容而破坏了云 TPU 支持。此版本再次修复了它。
JAX 已停止支持 Python 3.5。请升级到 Python 3.6 或更高版本。
jax 0.1.61 (2020 年 3 月 17 日)#
修复了 Python 3.5 支持。这将是最后一个支持 Python 3.5 的 JAX 或 jaxlib 版本。
jax 0.1.60 (2020 年 3 月 17 日)#
新功能
jax.pmap()
具有static_broadcast_argnums
参数,该参数允许用户指定应视为编译时常量并应广播到所有设备的参数。它的工作方式类似于jax.jit()
中的static_argnums
。改进了当追踪器错误地保存在全局状态时的错误消息。
添加了
jax.nn.one_hot()
实用函数。添加了
jax.experimental.jet
,用于指数级更快的更高阶自动微分。为
jax.lax.broadcast_in_dim()
的参数添加了更多正确性检查。
最低 jaxlib 版本现在为 0.1.41。
jaxlib 0.1.40 (2020 年 3 月 4 日)#
在 Jaxlib 中为 TensorFlow profiler 添加了实验性支持,这允许从 TensorBoard 跟踪 CPU 和 GPU 计算。
包括通过 NCCL 通信的多主机 GPU 计算的原型支持。
提高了 GPU 上 NCCL 集体操作的性能。
添加了 TopK、CustomCallWithoutLayout、CustomCallWithLayout、IGammaGradA 和 RandomGamma 实现。
支持 XLA 编译时已知的设备分配。
jax 0.1.59 (2020 年 2 月 11 日)#
重大更改
最低 jaxlib 版本现在为 0.1.38。
通过删除
Jaxpr.freevars
和Jaxpr.bound_subjaxprs
简化了Jaxpr
。调用原语(xla_call
、xla_pmap
、sharded_call
和remat_call
)获得了一个新的参数call_jaxpr
,其中包含一个完全封闭的(没有constvars
)jaxpr。此外,还为原语添加了一个新的字段call_primitive
。
新功能
lax.cond
的反向模式自动微分(例如grad
),使其现在在两种模式下都可微分 (#2091)JAX 现在支持 DLPack,它允许与其他库(如 PyTorch)以零拷贝方式共享 CPU 和 GPU 数组。
JAX GPU DeviceArrays 现在支持
__cuda_array_interface__
,这是另一种零拷贝协议,用于与其他库(如 CuPy 和 Numba)共享 GPU 数组。JAX CPU 设备缓冲区现在实现了 Python 缓冲区协议,该协议允许 JAX 和 NumPy 之间进行零拷贝缓冲区共享。
添加了 JAX_SKIP_SLOW_TESTS 环境变量以跳过已知速度较慢的测试。
jaxlib 0.1.39 (2020 年 2 月 11 日)#
更新 XLA。
jaxlib 0.1.38 (2020 年 1 月 29 日)#
不再支持 CUDA 9.0。
CUDA 10.2 wheels 现在默认构建。
jax 0.1.58 (2020 年 1 月 28 日)#
重大更改
JAX 已放弃 Python 2 支持,因为 Python 2 于 2020 年 1 月 1 日达到生命周期结束。请更新到 Python 3.5 或更高版本。
新功能
while 循环的前向模式自动微分 (
jvp
) (#1980)
新的 NumPy 和 SciPy 函数
GPU 上的批量 Cholesky 分解现在使用更高效的批量内核。
值得注意的错误修复#
随着 Python 3 的升级,JAX 不再依赖
fastcache
,这应该有助于安装。