变更日志#
最好在此处查看。 有关特定于实验性 Pallas API 的更改,请参阅 Pallas 变更日志。
JAX 遵循基于努力的版本控制; 有关此版本控制和 JAX 的 API 兼容性策略的讨论,请参阅 API 兼容性。 有关 Python 和 NumPy 版本支持策略,请参阅 Python 和 NumPy 版本支持策略。
未发布#
弃用
JAX 0.7.0(2025 年 7 月 22 日)#
新功能
添加了
jax.P
,它是jax.sharding.PartitionSpec
的别名。
重大更改
JAX 默认从 GSPMD 迁移到 Shardy。 有关更多信息,请参阅迁移指南。
JAX autodiff 默认切换到使用直接线性化(而不是通过 JVP 和部分求值来实现线性化)。 有关更多信息,请参阅 迁移指南。
jax.stages.OutInfo
已替换为jax.ShapeDtypeStruct
。jax.jit()
现在要求通过位置传递fun
,并通过关键字传递其他参数。 从 v0.7.x 开始,否则会导致错误。 这在 v0.6.x 中引发了 DeprecationWarning。最低 Python 版本现在是 3.11。 3.11 将保持最低支持版本,直到 2026 年 7 月。
布局 API 重命名
Layout
、.layout
、.input_layouts
和.output_layouts
已重命名为Format
、.format
、.input_formats
和.output_formats
DeviceLocalLayout
、.device_local_layout
已重命名为Layout
和.layout
jax.experimental.shard
模块已被删除,所有 API 都已移动到jax.sharding
端点。 因此,请使用jax.sharding.reshard
、jax.sharding.auto_axes
和jax.sharding.explicit_axes
代替其实验性端点。在 JAX 0.6 中弃用后,删除了
lax.infeed
和lax.outfeed
。transfer_to_infeed
和transfer_from_outfeed
方法也已从Device
对象中删除。jax.extend.core.primitives.pjit_p
原语已重命名为jit_p
,其name
属性已从"pjit"
更改为"jit"
。 这会影响 jaxpr 的字符串表示形式。 同一个原语不再从jax.experimental.pjit
模块导出。删除了(未记录)函数
jax.extend.backend.add_clear_backends_callback
。 用户应改用jax.extend.backend.register_backend_cache
。
弃用
jax.dlpack.SUPPORTED_DTYPES
已弃用; 请使用新的jax.dlpack.is_supported_dtype()
函数。按照 SciPy 中的类似弃用,已弃用
jax.scipy.special.sph_harm()
; 请改用jax.scipy.special.sph_harm_y()
。从
jax.interpreters.xla
中,删除了之前弃用的符号abstractify
和pytype_aval_mappings
。jax.interpreters.xla.canonicalize_dtype()
已弃用。 对于规范化 dtype,请首选jax.dtypes.canonicalize_dtype()
。 对于检查对象是否是有效的 jax 输入,请首选jax.core.valid_jaxtype()
。从
jax.core
中,删除了之前弃用的符号AxisName
、ConcretizationTypeError
、axis_frame
、call_p
、closed_call_p
、get_type
、trace_state_clean
、typematch
和typecheck
。从
jax.lib.xla_client
中,删除了之前弃用的符号DeviceAssignment
、get_topology_for_devices
和mlir_api_version
。在 v0.5.0 中弃用后,删除了
jax.extend.ffi
。 请改用jax.ffi
。jax.lib.xla_bridge.get_compile_options()
已弃用,并替换为jax.extend.backend.get_compile_options()
。
JAX 0.6.2(2025 年 6 月 17 日)#
新功能
添加了
jax.tree.broadcast()
,它实现了 pytree 前缀广播助手。
更改
最低 NumPy 版本为 1.26,最低 SciPy 版本为 1.12。
JAX 0.6.1(2025 年 5 月 21 日)#
新功能
添加了
jax.lax.axis_size()
,它返回给定名称的映射轴的大小。
更改
重新启用了 CUDA 包依赖项版本的额外检查,该检查在之前的版本中被意外禁用。
JAX 每晚构建包现在发布到构件注册表。 若要安装这些包,请参阅 JAX 安装指南。
jax.sharding.PartitionSpec
不再继承自元组。jax.ShapeDtypeStruct
现在是不可变的。 请使用.update
方法更新您的ShapeDtypeStruct
,而不是进行就地更新。
弃用
jax.custom_derivatives.custom_jvp_call_jaxpr_p
已弃用,将在 JAX v0.7.0 中删除。
JAX 0.6.0(2025 年 4 月 16 日)#
重大更改
jax.numpy.array()
不再接受None
。 此行为自 2023 年 11 月起已弃用,现在已删除。删除了
config.jax_data_dependent_tracing_fallback
配置选项,该选项在 v0.4.36 中临时添加,以允许用户选择退出新的“无堆栈”跟踪机制。删除了
config.jax_eager_pmap
配置选项。禁止在应用后续包装器的情况下,在
jax.jit
的结果上调用lower
和trace
AOT API。 之前这样做是可行的,但会默默地忽略包装器。 解决方法是在包装器中最后应用jax.jit
,对于jax.pmap
也是如此。 请参阅 #27873。删除了
jax
的cuda12_pip
extra; 请改用pip install jax[cuda12]
。
更改
最低 CuDNN 版本为 v9.8。
JAX 现在使用 CUDA 12.8 构建。 所有 CUDA 12.1 或更高版本仍然受支持。
JAX 包 extras 现在已更新为使用短划线而不是下划线,以与 PEP 685 对齐。 例如,如果您之前使用
pip install jax[cuda12_local]
安装 JAX,请改为运行pip install jax[cuda12-local]
。jax.jit()
现在要求通过位置传递fun
,并通过关键字传递其他参数。 否则会导致 v0.6.X 中出现 DeprecationWarning,并从 v0.7.X 开始出现错误。
弃用
jax.tree_util.build_tree()
已弃用。 请改用jax.tree.unflatten()
。使用 XLA 的 FFI 为 CPU 和 GPU 设备实现了主机回调处理程序,并使用 XLA 的自定义调用删除了现有的 CPU/GPU 处理程序。
jax.lib.xla_extension
中的所有 API 现在都已弃用。jax.interpreters.mlir.hlo
和jax.interpreters.mlir.func_dialect
是意外导出的,已被删除。 如果需要,它们可以从jax.extend.mlir
中获得。jax.interpreters.mlir.custom_call
已弃用。 应改用jax.ffi
提供的 API。不再支持将
jax.ffi.ffi_call()
与内联参数一起使用的已弃用方法。ffi_call()
现在无条件地返回一个可调用对象。弃用了
jax.lib.xla_client
中的以下导出:get_topology_for_devices
、heap_profile
、mlir_api_version
、Client
、CompileOptions
、DeviceAssignment
、Frame
、HloSharding
、OpSharding
、Traceback
。弃用了
jax.util
中的以下内部 API:HashableFunction
、as_hashable_function
、cache
、safe_map
、safe_zip
、split_dict
、split_list
、split_list_checked
、split_merge
、subvals
、toposort
、unzip2
、wrap_name
和wraps
。jax.dlpack.to_dlpack
已弃用。 您通常可以将 JAXArray
直接传递给另一个框架的from_dlpack
函数。 如果您需要to_dlpack
的功能,请使用数组的__dlpack__
属性。jax.lax.infeed
、jax.lax.infeed_p
、jax.lax.outfeed
和jax.lax.outfeed_p
已弃用,将在 JAX v0.7.0 中删除。删除了几个先前弃用的 API,包括
从
jax.lib.xla_client
中:ArrayImpl
、FftType
、PaddingType
、PrimitiveType
、XlaBuilder
、dtype_to_etype
、ops
、register_custom_call_target
、shape_from_pyval
、Shape
、XlaComputation
。从
jax.lib.xla_extension
中:ArrayImpl
、XlaRuntimeError
。从
jax
中:jax.treedef_is_leaf
、jax.tree_flatten
、jax.tree_map
、jax.tree_leaves
、jax.tree_structure
、jax.tree_transpose
和jax.tree_unflatten
。 替换项可以在jax.tree
或jax.tree_util
中找到。来自
jax.core
:AxisSize
、ClosedJaxpr
、EvalTrace
、InDBIdx
、InputType
、Jaxpr
、JaxprEqn
、Literal
、MapPrimitive
、OpaqueTraceState
、OutDBIdx
、Primitive
、Token
、TRACER_LEAK_DEBUGGER_WARNING
、Var
、concrete_aval
、dedup_referents
、escaped_tracer_error
、extend_axis_env_nd
、full_lower
、get_referent
、jaxpr_as_fun
、join_effects
、lattice_join
、leaked_tracer_error
、maybe_find_leaked_tracers
、raise_to_shaped
、raise_to_shaped_mappings
、reset_trace_state
、str_eqn_compact
、substitute_vars_in_output_ty
、typecompat
和used_axis_names_jaxpr
。大多数没有公共替代品,但有些可在jax.extend.core
中找到。pure_callback()
和ffi_call()
的vectorized
参数。请改用vmap_method
参数。
jax 0.5.3(2025 年 3 月 19 日)#
新特性
为
jax.lax.dynamic_slice()
、jax.lax.dynamic_update_slice()
和相关函数添加了allow_negative_indices
选项。默认值为 true,与当前行为匹配。如果设置为 false,JAX 不需要发出钳制负索引的代码,从而提高代码大小。为
jax.random.categorical()
添加了replace
选项,以启用不放回采样。
jax 0.5.2(2025 年 3 月 4 日)#
0.5.1 的补丁版本
Bug 修复
修复了 TPU 指标记录和
tpu-info
,它们在 0.5.1 中被破坏
jax 0.5.1(2025 年 2 月 24 日)#
重大更改
jit 跟踪缓存现在以输入 NamedShardings 为键。以前,跟踪缓存根本不包含分片信息(尽管后续的 jit 缓存确实包含,例如降级和编译缓存),因此两种不同类型的等效分片不会重新跟踪,但现在它们会重新跟踪。例如
@jax.jit def f(x): return x # inp1.sharding is of type SingleDeviceSharding inp1 = jnp.arange(8) f(inp1) mesh = jax.make_mesh((1,), ('x',)) # inp2.sharding is of type NamedSharding inp2 = jax.device_put(jnp.arange(8), NamedSharding(mesh, P('x'))) f(inp2) # tracing cache miss
在上面的示例中,调用
f(inp1)
然后调用f(inp2)
将导致跟踪缓存未命中,因为在跟踪时分片在抽象值上发生了更改。
新特性
添加了一个实验性的
jax.experimental.custom_dce.custom_dce()
装饰器,以支持在 JAX 级别死代码消除 (DCE) 下自定义不透明函数的行为。有关更多详细信息,请参阅 #25956。在
jax.lax
中添加了底层归约 API:jax.lax.reduce_sum()
、jax.lax.reduce_prod()
、jax.lax.reduce_max()
、jax.lax.reduce_min()
、jax.lax.reduce_and()
、jax.lax.reduce_or()
和jax.lax.reduce_xor()
。jax.lax.linalg.qr()
和jax.scipy.linalg.qr()
现在支持在 CPU 和 GPU 上进行列旋转。请参阅 #20282 和添加了
jax.random.multinomial()
。#25955 了解更多详细信息。
更改
JAX_CPU_COLLECTIVES_IMPLEMENTATION
和JAX_NUM_CPU_DEVICES
现在可以用作环境变量。以前,它们只能通过 jax.config 或标志指定。JAX_CPU_COLLECTIVES_IMPLEMENTATION
现在默认为'gloo'
,这意味着多进程 CPU 通信可以开箱即用。jax[tpu]
TPU 扩展不再依赖libtpu-nightly
包。如果您的机器上存在此包,可以安全地将其删除;JAX 现在改用libtpu
。
弃用
内部函数
linear_util.wrap_init
和构造函数core.Jaxpr
现在必须接受非空的core.DebugInfo
kwarg。在有限的时间内,如果使用jax.extend.linear_util.wrap_init
但没有调试信息,则会打印DeprecationWarning
。此操作的下游影响是,其他几个内部函数需要调试信息。此更改不会影响公共 API。有关更多详细信息,请参阅 https://github.com/jax-ml/jax/issues/26480。在
jax.numpy.ndim()
、jax.numpy.shape()
和jax.numpy.size()
中,非类数组输入(例如列表、元组等)现在已弃用。
Bug 修复
TPU 运行时启动和关闭时间在 TPU v5e 及更高版本上应得到显著改善(从大约 17 秒到大约 8 秒)。如果尚未设置,您可能需要在 VM 映像中启用透明大页(
sudo sh -c 'echo always > /sys/kernel/mm/transparent_hugepage/enabled'
)。我们希望在未来的版本中进一步改进这一点。如果未设置或设置为 -1,即如果未启用 LRU 逐出策略,则持久编译缓存不再写入访问时间文件。在使用具有大规模网络存储的缓存时,这应该可以提高性能。
jax 0.5.0(2025 年 1 月 17 日)#
从这个版本开始,JAX 现在使用 基于努力的版本控制。由于此版本对 PRNG 密钥语义进行了重大更改,可能需要用户更新其代码,因此我们正在增加 JAX 的“meso”版本以表示这一点。
重大更改
默认启用
jax_threefry_partitionable
(请参阅 更新说明)。此版本放弃了对 Mac x86 wheel 的支持。Mac ARM 当然仍然受支持。有关最近的讨论,请参阅 https://github.com/jax-ml/jax/discussions/22936。
以下两个关键因素促成了这一决定
Mac x86 构建(仅)存在许多测试失败和崩溃。我们宁愿不发布版本,也不愿发布一个损坏的版本。
Mac x86 硬件已停产,目前开发人员无法轻易获得。因此,即使我们想解决这种问题也很困难。
如果社区愿意帮助支持该平台,我们乐于重新添加对 Mac x86 的支持:特别是,我们需要 JAX 测试套件在 Mac x86 上干净地通过,然后我们才能再次发布版本。
更改
最低 NumPy 版本现在为 1.25。NumPy 1.25 将在 2025 年 6 月之前保持最低支持版本。
最低 SciPy 版本现在为 1.11。SciPy 1.11 将在 2025 年 6 月之前保持最低支持版本。
jax.numpy.einsum()
现在默认为optimize='auto'
而不是optimize='optimal'
。这避免了在多个参数的情况下指数级缩放跟踪时间 (#25214)。jax.numpy.linalg.solve()
不再支持右侧的批处理 1D 参数。要在这些情况下恢复以前的行为,请使用solve(a, b[..., None]).squeeze(-1)
。
新特性
jax.numpy.fft.fftn()
、jax.numpy.fft.rfftn()
、jax.numpy.fft.ifftn()
和jax.numpy.fft.irfftn()
现在支持超过 3 个维度的转换,这以前是限制。有关更多详细信息,请参阅 #25606。通过新的
jax.ffi.register_ffi_type_id()
函数添加了对 FFI 中用户定义状态的支持。AOT 降级
.as_text()
方法现在支持debug_info
选项以在输出中包含调试信息,例如,源位置。
弃用
来自
jax.interpreters.xla
,abstractify
和pytype_aval_mappings
现在已弃用,已被jax.core
中具有相同名称的符号替换。jax.scipy.special.lpmn()
和jax.scipy.special.lpmn_values()
已弃用,因为它们已在 SciPy v1.15.0 中弃用。没有计划用新的 API 替换这些已弃用的函数。jax.extend.ffi
子模块已移动到jax.ffi
,并且以前的导入路径已弃用。
删除
jax_enable_memories
标志已删除,并且该标志的行为默认开启。从
jax.lib.xla_client
中,先前已弃用的Device
和XlaRuntimeError
符号已被删除;请改用jax.Device
和jax.errors.JaxRuntimeError
。jax.experimental.array_api
模块已在 JAX v0.4.32 中弃用后被删除。自该版本以来,jax.numpy
直接支持数组 API。
jax 0.4.38(2024 年 12 月 17 日)#
重大更改
XlaExecutable.cost_analysis
现在返回dict[str, float]
(而不是单元素list[dict[str, float]]
)。
更改
添加了
jax.tree.flatten_with_path
和jax.tree.map_with_path
作为对应tree_util
函数的快捷方式。
弃用
内部
jax.core
命名空间中的许多 API 已弃用。大多数都是空操作、很少使用或可以通过jax.extend.core
中具有相同名称的 API 替换;有关这些半公共扩展的兼容性保证的信息,请参阅jax.extend
的文档。删除了几个先前弃用的 API,包括
来自
jax.core
:check_eqn
、check_type
、check_valid_jaxtype
和non_negative_dim
。来自
jax.lib.xla_bridge
:xla_client
和default_backend
。来自
jax.lib.xla_client
:_xla
和bfloat16
。来自
jax.numpy
:round_
。
新特性
jax.export.export()
可用于设备多态导出,其中分片使用jax.sharding.AbstractMesh()
构造。请参阅 jax.export 文档。添加了
jax.lax.split()
。这是jax.numpy.split()
的原始版本,添加的原因是在自动微分期间产生更紧凑的转置。
jax 0.4.37(2024 年 12 月 9 日)#
这是 jax 0.4.36 的补丁版本。在此版本中仅发布了“jax”。
Bug 修复
修复了一个错误,如果参数命名为
f
(#25329),则jit
会出错。修复了一个 bug,如果用户为 flatten 和 flatten_with_path 注册了具有不同辅助数据的 pytree 节点类,则会在
jax.lax.while_loop()
中抛出index out of range
错误。固定了一个新的 libtpu 版本 (0.0.6),该版本修复了 TPU v6e 上的编译器 bug。
jax 0.4.36(2024 年 12 月 5 日)#
重大更改
此版本发布了“stackless”,这是 JAX 跟踪机制的一个内部更改。我们使跟踪调度纯粹是上下文的函数,而不是上下文和数据的函数。这让我们删除了大量用于管理数据相关跟踪的机制:级别、子级别、
post_process_call
、new_base_main
、custom_bind
等。此更改应仅影响使用 JAX 内部结构的用户。如果您确实使用了 JAX 内部结构,则可能需要更新您的代码(请参阅 https://github.com/jax-ml/jax/commit/c36e1f7c1ad4782060cbc8e8c596d85dfb83986f,了解有关如何执行此操作的线索)。使用 JAX 库也可能存在版本偏差问题。如果您发现此更改破坏了您不使用 JAX 内部结构的代码,请尝试使用
config.jax_data_dependent_tracing_fallback
标志作为解决方法,如果您需要帮助更新您的代码,请提交一个 bug。自 2024 年 7 月 JAX 版本 0.4.31 以来,带有
native_serialization=False
或带有enable_xla=False
的jax.experimental.jax2tf.convert()
已弃用。现在我们删除了对这些用例的支持。带有本机序列化的jax2tf
仍将受到支持。在
jax.interpreters.xla
中,在 JAX v0.4.31 中弃用后,xb
、xc
和xe
符号已被删除。请改用xb = jax.lib.xla_bridge
、xc = jax.lib.xla_client
和xe = jax.lib.xla_extension
。已删除已弃用的模块
jax.experimental.export
。它在 JAX v0.4.30 中被jax.export
替换。请参阅 迁移指南,了解有关迁移到新 API 的信息。在 v0.4.27 中弃用后,已删除
jax.nn.softmax()
和jax.nn.log_softmax()
的initial
参数。现在,对类型化 PRNG 密钥(即由 :func:
jax.random.key
生成的密钥)调用np.asarray
会引发错误。以前,这会返回一个标量对象数组。删除了
jax.export
中的以下已弃用的方法和函数jax.export.DisabledSafetyCheck.shape_assertions
:它已经没有效果了。jax.export.Exported.lowering_platforms
:请使用platforms
。jax.export.Exported.mlir_module_serialization_version
:请使用calling_convention_version
。jax.export.Exported.uses_shape_polymorphism
:请使用uses_global_constants
。jax.export.export()
的lowering_platforms
kwarg:请改用platforms
。
已删除
jax.export.symbolic_args_specs()
中的 kwargssymbolic_scope
和symbolic_constraints
。它们已在 2024 年 6 月弃用。请改用scope
和constraints
。自 0.4.30 版本以来已弃用的跟踪器哈希处理现在会导致
TypeError
。重构:JAX 构建 CLI (build/build.py) 现在使用子命令结构并替换以前的 build.py 用法。运行
python build/build.py --help
了解更多详细信息。新子命令选项的简要概述build
:构建 JAX wheel 包。例如,python build/build.py build --wheels=jaxlib,jax-cuda-pjrt
requirements_update
:更新 requirements_lock.txt 文件。
jax.scipy.linalg.toeplitz()
现在对多维输入执行隐式批处理。要恢复以前的行为,您可以在函数输入上调用jax.numpy.ravel()
。jax.scipy.special.gamma()
和jax.scipy.special.gammasgn()
现在为负整数输入返回 NaN,以匹配 SciPy 的行为,网址为 https://github.com/scipy/scipy/pull/21827。在 v0.4.26 中弃用后,已删除
jax.clear_backends
。我们从保证导出稳定性的自定义调用列表中删除了自定义调用“__gpu$xla.gpu.triton”。这是因为此自定义调用依赖于 Triton IR,而 Triton IR 不能保证稳定。如果您需要导出使用此自定义调用的代码,可以使用
disabled_checks
参数。有关更多详细信息,请参阅文档。
新特性
jax.jit()
有一个新的compiler_options: dict[str, Any]
参数,用于将编译选项传递给 XLA。目前它是没有文档的,并且可能会发生变化。jax.tree_util.register_dataclass()
现在允许通过dataclasses.field()
内联声明元数据字段。有关示例,请参阅函数文档。GPU 上现在支持
jax.lax.linalg.eig()
和相关的jax.numpy
函数(jax.numpy.linalg.eig()
和jax.numpy.linalg.eigvals()
)。有关更多详细信息,请参阅 #24663。添加了两个新的配置标志
jax_exec_time_optimization_effort
和jax_memory_fitting_effort
,以控制编译器花费在最小化执行时间和内存使用量上的精力。有效值为 -1.0 到 1.0,默认为 0.0。
Bug 修复
修复了一个 bug,其中 LU 和 QR 分解的 GPU 实现会导致批处理大小接近 int32 最大值时的索引溢出。有关更多详细信息,请参阅 #24843。
弃用
jax.lib.xla_extension.ArrayImpl
和jax.lib.xla_client.ArrayImpl
已弃用;请改用jax.Array
。jax.lib.xla_extension.XlaRuntimeError
已弃用;请改用jax.errors.JaxRuntimeError
。
jax 0.4.35 (2024 年 10 月 22 日)#
重大更改
jax.numpy.isscalar()
现在对任何零维类数组对象返回 True。之前,它仅对具有弱 dtype 的零维类数组对象返回 True。jax.experimental.host_callback
自 2024 年 3 月起已弃用,JAX 版本为 0.4.26。现在我们已将其删除。有关替代方案的讨论,请参见 #20385。
更改
jax.lax.FftType
作为 FFT 操作枚举的公共名称引入。半公共 APIjax.lib.xla_client.FftType
已弃用。TPU:JAX 现在从
libtpu
包而不是libtpu-nightly
安装 TPU 支持。在接下来的几个版本中,JAX 将固定一个空的libtpu-nightly
版本以及libtpu
以简化过渡;该依赖项将在 2025 年第一季度删除。
弃用
半公共 API
jax.lib.xla_client.PaddingType
已弃用。没有 JAX API 使用此类型,因此没有替代方案。jax.pure_callback()
和jax.extend.ffi.ffi_call()
在vmap
下的默认行为已弃用,因此这些函数的vectorized
参数也已弃用。应使用vmap_method
参数以获得更好的定义行为。有关更多详细信息,请参见 #23881 中的讨论。半公共 API
jax.lib.xla_client.register_custom_call_target
已弃用。请改用 JAX FFI。半公共 API
jax.lib.xla_client.dtype_to_etype
、jax.lib.xla_client.ops
、jax.lib.xla_client.shape_from_pyval
、jax.lib.xla_client.PrimitiveType
、jax.lib.xla_client.Shape
、jax.lib.xla_client.XlaBuilder
和jax.lib.xla_client.XlaComputation
已弃用。请改用 StableHLO。
jax 0.4.34 (2024 年 10 月 4 日)#
新功能
此版本包括 Python 3.13 的 wheels。尚不支持自由线程模式。
jax.errors.JaxRuntimeError
已添加为以前私有的XlaRuntimeError
类型的公共别名。
重大更改
jax_pmap_no_rank_reduction
标志默认设置为True
。pmap 结果上的 array[0] 现在引入了一个 reshape(请改用 array[0:1])。
每个分片形状(可通过 jax_array.addressable_shards 或 jax_array.addressable_data(0) 访问)现在具有前导 (1, …)。相应地更新直接访问分片的代码。每个分片形状的秩现在与全局形状的秩匹配,这与 jit 的行为相同。这避免了在将结果从 pmap 传递到 jit 时进行代价高昂的 reshape。
jax.experimental.host_callback
自 2024 年 3 月起已弃用,JAX 版本为 0.4.26。现在我们将--jax_host_callback_legacy
配置值的默认值设置为True
,这意味着如果您的代码使用jax.experimental.host_callback
API,这些 API 调用将根据新的jax.experimental.io_callback
API 实现。如果这破坏了您的代码,您可以在很短的时间内将--jax_host_callback_legacy
设置为True
。很快我们将删除该配置选项,因此您应该改为过渡到使用新的 JAX 回调 API。有关讨论,请参见 #20385。
弃用
在
jax.numpy.trim_zeros()
中,非类数组参数或具有ndim != 1
的类数组参数现已弃用,并且将来会导致错误。内部美观打印工具
jax.core.pp_*
在 JAX v0.4.30 中弃用后已删除。jax.lib.xla_client.Device
已弃用;请改用jax.Device
。jax.lib.xla_client.XlaRuntimeError
已弃用。请改用jax.errors.JaxRuntimeError
。
删除
jax.xla_computation
已删除。自 0.4.30 JAX 版本中弃用以来已经过去了 3 个月。请使用 AOT API 来获得与jax.xla_computation
相同的功能。jax.xla_computation(fn)(*args, **kwargs)
可以替换为jax.jit(fn).lower(*args, **kwargs).compiler_ir('hlo')
。您还可以使用
jax.stages.Lowered
的.out_info
属性来获取输出信息(如树结构、形状和 dtype)。对于跨后端降低,您可以将
jax.xla_computation(fn, backend='tpu')(*args, **kwargs)
替换为jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo')
。
jax.ShapeDtypeStruct
不再接受named_shape
参数。该参数仅由xmap
使用,该参数已在 0.4.31 中删除。jax.tree.map(f, None, non-None)
,以前发出DeprecationWarning
,现在会在未来的 jax 版本中引发错误。None
只是其自身的树前缀。要保留当前行为,您可以要求jax.tree.map
将None
视为叶值,方法是编写:jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None)
。jax.sharding.XLACompatibleSharding
已被删除。请使用jax.sharding.Sharding
。
Bug 修复
修复了如果提供非布尔输入并指定
dtype=bool
,jax.numpy.cumsum()
会产生不正确输出的错误。编辑
jax.numpy.ldexp()
的实现以获得正确的梯度。
jax 0.4.33 (2024 年 9 月 16 日)#
这是在 jax 0.4.32 之上的一个补丁版本,修复了在该版本中发现的两个错误。
在 JAX 0.4.32 固定的 libtpu 版本中发现了一个仅限 TPU 的数据损坏错误,该错误仅在同一作业中存在多个 TPU 切片时才会显现,例如,如果在多个 v5e 切片上进行训练。此版本通过固定 libtpu
的固定版本来修复该问题。
此版本修复了 CPU 上 F64 tanh 的不准确结果 (#23590)。
jax 0.4.32 (2024 年 9 月 11 日)#
注意:此版本已从 PyPi 中撤回,因为 TPU 上存在数据损坏错误。有关更多详细信息,请参见 0.4.33 发行说明。
新功能
添加了 Foreign function interface (FFI) 中支持的
jax.extend.ffi.ffi_call()
和jax.extend.ffi.ffi_lowering()
,用于从 JAX 与自定义 C++ 和 CUDA 代码进行接口。
更改
jax_enable_memories
标志默认设置为True
。jax.numpy
现在支持 Python Array API Standard 的 v2023.12 版本。有关更多信息,请参见 Python Array API standard。CPU 后端上的计算现在可以在更多情况下异步调度。以前,非并行计算总是同步调度的。您可以通过设置
jax.config.update('jax_cpu_enable_async_dispatch', False)
来恢复旧的行为。添加了新的
jax.process_indices()
函数来替换 JAX v0.2.13 中已弃用的jax.host_ids()
函数。为了与
numpy.fabs
的行为保持一致,jax.numpy.fabs
已被修改为不再支持complex dtypes
。jax.tree_util.register_dataclass
现在检查data_fields
和meta_fields
是否包括所有带有init=True
的数据类字段,并且只有在nodetype
是数据类时才包括它们。几个
jax.numpy
函数现在具有完整的ufunc
接口,包括add
、multiply
、bitwise_and
、bitwise_or
、bitwise_xor
、logical_and
和logical_and
。在未来的版本中,我们计划将这些扩展到其他 ufuncs。添加了
jax.lax.optimization_barrier()
,它允许用户防止编译器优化(如公共子表达式消除)并控制调度。
重大更改
MHLO MLIR 方言 (
jax.extend.mlir.mhlo
) 已被删除。请改用stablehlo
方言。
弃用
在 JAX v0.4.27 之后,不再允许使用
jax.numpy.clip()
和jax.numpy.hypot()
的复杂输入。已弃用以下 API
jax.lib.xla_bridge.xla_client
:直接使用jax.lib.xla_client
。jax.lib.xla_bridge.get_backend
:使用jax.extend.backend.get_backend()
。jax.lib.xla_bridge.default_backend
:使用jax.extend.backend.default_backend()
。
jax.experimental.array_api
模块已弃用,并且不再需要导入它才能使用 Array API。jax.numpy
直接支持数组 API;有关更多信息,请参见 Python Array API standard。内部实用程序
jax.core.check_eqn
、jax.core.check_type
和jax.core.check_valid_jaxtype
现已弃用,并将``` 删除在未来。jax.numpy.round_
已弃用,原因是 NumPy 2.0 中删除了相应的 API。请改用jax.numpy.round()
。将 DLPack capsule 传递给
jax.dlpack.from_dlpack()
已弃用。jax.dlpack.from_dlpack()
的参数应该是来自另一个框架的数组,该框架实现了__dlpack__
协议。
jaxlib 0.4.32 (2024 年 9 月 11 日)#
注意:此版本已从 PyPi 中撤回,因为 TPU 上存在数据损坏错误。有关更多详细信息,请参见 0.4.33 发行说明。
重大更改
此版本的 jaxlib 切换到了 CPU 后端的新版本,该版本应该编译得更快并更好地利用并行性。如果您因该更改遇到任何问题,您可以通过设置环境变量
XLA_FLAGS=--xla_cpu_use_thunk_runtime=false
来暂时启用旧的 CPU 后端。如果需要这样做,请提交包含重现说明的 JAX 错误。添加了 Hermetic CUDA 支持。 Hermetic CUDA 使用特定的可下载 CUDA 版本,而不是用户本地安装的 CUDA。 Bazel 将下载 CUDA、CUDNN 和 NCCL 发行版,然后将 CUDA 库和工具用作各种 Bazel 目标中的依赖项。这使得 JAX 及其支持的 CUDA 版本的构建更具可重复性。
更改
添加了 SparseCore 分析。
JAX 现在支持在 TPUv5p 芯片上分析 SparseCore。这些跟踪将在 Tensorboard Profiler 的 TraceViewer 中查看。
jax 0.4.31 (2024 年 7 月 29 日)#
删除
xmap 已被删除。请使用
shard_map()
作为替代。
更改
最低 CuDNN 版本为 v9.1。以前的版本也是如此,但我们现在正式声明了此版本约束。
最低 Python 版本现在为 3.10。 3.10 将在 2025 年 7 月之前仍然是最低支持的版本。
最低 NumPy 版本现在为 1.24。 NumPy 1.24 将在 2024 年 12 月之前仍然是最低支持的版本。
最低 SciPy 版本现在为 1.10。 SciPy 1.10 将在 2025 年 1 月之前仍然是最低支持的版本。
jax.numpy.ceil()
、jax.numpy.floor()
和jax.numpy.trunc()
现在返回与输入相同 dtype 的输出,即不再将整数或布尔输入向上转换为浮点数。libdevice.10.bc
不再与 CUDA wheels 打包在一起。它必须作为本地 CUDA 安装的一部分安装,或通过 NVIDIA 的 CUDA pip wheels 安装。jax.experimental.pallas.BlockSpec
现在希望在index_map
*之前* 传递block_shape
。旧的参数顺序已弃用,将在未来的版本中删除。更新了 gpu 设备的 repr,使其与 TPU/CPU 更加一致。例如,
cuda(id=0)
现在将是CudaDevice(id=0)
。作为 JAX 的 Array API 支持的一部分,向
jax.Array
添加了device
属性和to_device
方法。
弃用
删除了许多以前弃用的与多态形状相关的内部 API。从
jax.core
:删除了canonicalize_shape
、dimension_as_value
、definitely_equal
和symbolic_equal_dim
。HLO 降低规则不应再将单例 ir.Values 包装在元组中。而是返回未包装的单例 ir.Values。对包装值的支持将在未来的 JAX 版本中删除。
jax.experimental.jax2tf.convert()
与native_serialization=False
或enable_xla=False
现在已弃用,此支持将在未来的版本中删除。自 JAX 0.4.16(2023 年 9 月)以来,本机序列化一直是默认设置。以前弃用的函数
jax.random.shuffle
已被删除;请改用jax.random.permutation
和independent=True
。
jaxlib 0.4.31 (2024 年 7 月 29 日)#
Bug 修复
修复了一个错误,该错误意味着 jit 调度快速路径错误地处理了 jit 的负 static_argnums。
修复了一个错误,该错误意味着成批奇异矩阵的三角求解会产生无意义的有限值,而不是 inf 或 nan (#3589, #15429)。
jax 0.4.30 (2024 年 6 月 18 日)#
更改
JAX 支持 ml_dtypes >= 0.2。在 0.4.29 版本中,ml_dtypes 版本已提升到 0.4.0,但在此版本中已回滚,以便同时使用 TensorFlow 和 JAX 的用户有更多时间迁移到更新的 TensorFlow 版本。
jax.experimental.mesh_utils
现在可以为 TPU v5e 创建一个高效的网格。jax 现在直接依赖于 jaxlib。 CUDA 插件切换启用了此更改:不再有多个 jaxlib 变体。您可以使用
pip install jax
安装仅 CPU 的 jax,无需任何额外功能。添加了用于导出和序列化 JAX 函数的 API。这过去存在于
jax.experimental.export
(已弃用)中,现在将位于jax.export
中。请参阅文档。
弃用
内部美观打印工具
jax.core.pp_*
已弃用,将在未来的版本中删除。追踪器的哈希已弃用,将在未来的 JAX 版本中导致
TypeError
。以前是这种情况,但在最近几个 JAX 版本中存在一个无意中的回归。jax.experimental.export
已弃用。请改用jax.export
。请参阅迁移指南。在大多数情况下,用数组代替 dtype 现已弃用;例如,对于数组
x
和y
,x.astype(y)
将引发警告。要消除它,请使用x.astype(y.dtype)
。jax.xla_computation
已弃用,将在未来的版本中删除。请使用 AOT API 来获得与jax.xla_computation
相同的功能。jax.xla_computation(fn)(*args, **kwargs)
可以替换为jax.jit(fn).lower(*args, **kwargs).compiler_ir('hlo')
。您还可以使用
jax.stages.Lowered
的.out_info
属性来获取输出信息(如树结构、形状和 dtype)。对于跨后端降低,您可以将
jax.xla_computation(fn, backend='tpu')(*args, **kwargs)
替换为jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo')
。
jaxlib 0.4.30 (2024 年 6 月 18 日)#
已删除对整体 CUDA jaxlibs 的支持。您必须使用基于插件的安装(
pip install jax[cuda12]
或pip install jax[cuda12_local]
)。
jax 0.4.29 (2024 年 6 月 10 日)#
更改
我们预计这将是 JAX 和 jaxlib 支持整体 CUDA jaxlib 的最后一个版本。未来的版本将使用 CUDA 插件 jaxlib(例如
pip install jax[cuda12]
)。JAX 现在需要 ml_dtypes 版本 0.4.0 或更高版本。
删除了对旧的
jax.experimental.export
API 用法的向后兼容性支持。不再可能使用from jax.experimental.export import export
,而应使用from jax.experimental import export
。自 0.4.24 以来,已删除的功能已弃用。向
jax.tree.all()
&jax.tree_util.tree_all()
添加了is_leaf
参数。
弃用
jax.sharding.XLACompatibleSharding
已弃用。请使用jax.sharding.Sharding
。jax.experimental.Exported.in_shardings
已重命名为jax.experimental.Exported.in_shardings_hlo
。out_shardings
相同。旧名称将在 3 个月后删除。删除了许多以前弃用的 API
jax.numpy.linalg.matrix_rank()
的tol
参数已被弃用,并将很快删除。请改用rtol
。jax.numpy.linalg.pinv()
的rcond
参数已被弃用,并将很快删除。请改用rtol
。已删除弃用的
jax.config
子模块。要配置 JAX,请使用import jax
,然后通过jax.config
引用 config 对象。jax.random
API 不再接受批处理键,以前一些 API 无意中这样做了。展望未来,我们建议在这种情况下显式使用jax.vmap()
。在
jax.scipy.special.beta()
中,x
和y
参数已重命名为a
和b
,以便与其他beta
API 保持一致。
新功能
添加了
jax.experimental.Exported.in_shardings_jax()
以构造可与 JAX API 一起使用的分片,这些分片来自存储在Exported
对象中的 HloShardings。
jaxlib 0.4.29 (2024 年 6 月 10 日)#
Bug 修复
修复了一个错误,该错误导致 XLA 错误地分片了一些连接操作,这表现为累积减少的不正确输出 (#21403)。
修复了一个错误,该错误导致 XLA:CPU 错误地编译了某些 matmul 融合 (https://github.com/openxla/xla/pull/13301)。
修复了 GPU 上的编译器崩溃 (https://github.com/jax-ml/jax/issues/21396)。
弃用
jax.tree.map(f, None, non-None)
现在会发出DeprecationWarning
,并在未来的 jax 版本中引发错误。None
只是其自身的树前缀。要保留当前行为,您可以要求jax.tree.map
将None
视为叶值,方法是编写:jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None)
。
jax 0.4.28 (2024 年 5 月 9 日)#
Bug 修复
撤销了对
make_jaxpr
的一项更改,该更改破坏了 Equinox (#21116)。
弃用 & 移除
现在已删除
jax.numpy.sort()
和jax.numpy.argsort()
的kind
参数。请改用stable=True
或stable=False
。从
jax.experimental.pallas.gpu
模块中删除了get_compute_capability
。请改用由jax.devices()
或jax.local_devices()
返回的 GPU 设备的compute_capability
属性。jax.numpy.reshape()
的newshape
参数已被弃用,即将移除。请改用shape
。
更改
此版本的最低 jaxlib 版本为 0.4.27。
jaxlib 0.4.28 (2024 年 5 月 9 日)#
Bug 修复
修复了在 Python 3.10 或更早版本中 Array 和 JIT Python 对象类型名称中的内存损坏错误。
修复了 CUDA 12.4 下的警告
'+ptx84' is not a recognized feature for this target'
。修复了 CPU 上编译缓慢的问题。
更改
Windows 版本现在使用 Clang 而不是 MSVC 构建。
jax 0.4.27 (2024 年 5 月 7 日)#
新功能
添加了
jax.numpy.unstack()
和jax.numpy.cumulative_sum()
,遵循了它们在 array API 2023 标准中的添加,该标准即将被 NumPy 采用。添加了一个新的配置选项
jax_cpu_collectives_implementation
,用于选择 CPU 后端使用的跨进程集合操作的实现。可用的选项有'none'
(默认)、'gloo'
和'mpi'
(需要 jaxlib 0.4.26)。如果设置为'none'
,则禁用跨进程集合操作。
更改
jax.pure_callback()
、jax.experimental.io_callback()
和jax.debug.callback()
现在使用jax.Array
而不是np.ndarray
。您可以通过在将参数传递给回调之前通过jax.tree.map(np.asarray, args)
转换参数来恢复旧的行为。complex_arr.astype(bool)
现在遵循与 NumPy 相同的语义,在complex_arr
等于0 + 0j
时返回 False,否则返回 True。core.Token
现在是一个非平凡的类,它包装了一个jax.Array
。它可以被创建并传入和传出计算以建立依赖关系。单例对象core.token
已被删除,用户现在应该创建并使用新的core.Token
对象。在 GPU 上,Threefry PRNG 实现默认不再降低到内核调用。这种选择可以提高运行时内存使用率,但会带来编译时成本。可以通过
jax.config.update('jax_threefry_gpu_kernel_lowering', True)
恢复产生内核调用的先前行为。如果新的默认设置导致问题,请提交错误报告。否则,我们计划在未来的版本中删除此标志。
弃用 & 移除
Pallas 现在专门使用 XLA 在 GPU 上编译内核。旧的通过 Triton Python API 的降级通道已被删除,
JAX_TRITON_COMPILE_VIA_XLA
环境变量不再有任何作用。jax.numpy.clip()
有一个新的参数签名:a
、a_min
和a_max
已被弃用,取而代之的是x
(仅限位置参数)、min
和max
(#20550)。JAX 数组的
device()
方法已移除,此前已在 JAX v0.4.21 中弃用。请改用arr.devices()
。jax.nn.softmax()
和jax.nn.log_softmax()
的initial
参数已被弃用;现在支持 softmax 的空输入,而无需设置此参数。在
jax.jit()
中,传递无效的static_argnums
或static_argnames
现在会导致错误,而不是警告。最低 jaxlib 版本现在为 0.4.23。
当传递复数值输入时,
jax.numpy.hypot()
函数现在会发出弃用警告。当弃用完成后,这将引发错误。按照 NumPy 中的类似更改,
jax.numpy.nonzero()
、jax.numpy.where()
和相关函数的标量参数现在会引发错误。配置选项
jax_cpu_enable_gloo_collectives
已弃用。请改用jax.config.update('jax_cpu_collectives_implementation', 'gloo')
。jax.Array.device_buffer
和jax.Array.device_buffers
方法已移除,此前已在 JAX v0.4.22 中弃用。请改用jax.Array.addressable_shards
和jax.Array.addressable_data()
。jax.numpy.where
的condition
、x
和y
参数现在仅限位置参数,此前已在 JAX v0.4.21 中弃用关键字参数。jax.lax.linalg
中函数的非数组参数现在必须通过关键字指定。以前,这会引发 DeprecationWarning。现在需要在几个 :func:
jax.numpy
API 中使用类似数组的参数,包括apply_along_axis()
、apply_over_axes()
、inner()
、outer()
、cross()
、kron()
和lexsort()
。
Bug 修复
当
copy=True
时,jax.numpy.astype()
现在将始终返回副本。以前,当输出数组与输入数组具有相同的数据类型时,不会创建副本。这可能会导致一些内存使用量增加。默认值设置为copy=False
以保持向后兼容性。
jaxlib 0.4.27 (2024 年 5 月 7 日)#
jax 0.4.26 (2024 年 4 月 3 日)#
新功能
添加了
jax.numpy.trapezoid()
,遵循了 NumPy 2.0 中此函数的添加。
更改
复数值的
jax.numpy.geomspace()
现在选择与 NumPy 2.0 一致的对数螺旋分支。lax.rng_bit_generator
的行为,进而'rbg'
和'unsafe_rbg'
PRNG 实现,在jax.vmap
下 已更改,以便在键上进行映射只会导致从批处理中的第一个键生成随机数。文档现在使用
jax.random.key
来构造 PRNG 键数组,而不是jax.random.PRNGKey
。
弃用 & 移除
jax.tree_map()
已弃用;请改用jax.tree.map
,或者为了与旧 JAX 版本向后兼容,请使用jax.tree_util.tree_map()
。jax.clear_backends()
已弃用,因为它不一定能实现其名称所暗示的功能,并且可能导致意外的后果,例如,它不会销毁现有的后端并释放相应的已拥有资源。如果您只想清理编译缓存,请使用jax.clear_caches()
。为了向后兼容,或者您确实需要切换/重新初始化默认后端,请使用jax.extend.backend.clear_backends()
。jax.experimental.maps
模块和jax.experimental.maps.xmap
已弃用。请使用jax.experimental.shard_map
或jax.vmap
以及spmd_axis_name
参数来表达 SPMD 设备并行计算。jax.experimental.host_callback
模块已弃用。请改用新的 JAX 外部回调。添加了JAX_HOST_CALLBACK_LEGACY
标志以帮助过渡到新的回调。有关讨论,请参见 #20385。现在,将无法转换为 JAX 数组的参数传递给
jax.numpy.array_equal()
和jax.numpy.array_equiv()
会导致异常。已删除已弃用的标志
jax_parallel_functions_output_gda
。此标志已被长期弃用且没有任何作用;它的使用是一个空操作。先前已弃用的导入
jax.interpreters.ad.config
和jax.interpreters.ad.source_info_util
现已删除。请改用jax.config
和jax.extend.source_info_util
。JAX 导出不再支持较旧的序列化版本。自 2023 年 10 月 27 日起已支持版本 9,自 2024 年 2 月 1 日起已成为默认版本。请参阅 版本描述。此更改可能会破坏设置了低于 9 的特定 JAX 序列化版本的客户端。
jaxlib 0.4.26 (2024 年 4 月 3 日)#
更改
JAX 现在仅支持 CUDA 12.1 或更高版本。已删除对 CUDA 11.8 的支持。
JAX 现在支持 NumPy 2.0。
jax 0.4.25 (2024 年 2 月 26 日)#
新特性
添加了 CUDA Array Interface 导入支持(需要 jaxlib 0.4.24)。
JAX 数组现在支持 NumPy 风格的标量布尔索引,例如
x[True]
或x[False]
。添加了
jax.tree
模块,它提供了一个更方便的接口来引用jax.tree_util
中的函数。jax.tree.transpose()
(即jax.tree_util.tree_transpose()
)现在接受inner_treedef=None
,在这种情况下,内部 treedef 将自动推断。
更改
Pallas 现在使用 XLA 而不是 Triton Python API 来编译 Triton 内核。您可以通过将
JAX_TRITON_COMPILE_VIA_XLA
环境变量设置为"0"
来恢复旧的行为。v0.4.24 中删除的
jax.interpreters.xla
中的几个已弃用的 API 已在 v0.4.25 中重新添加,包括backend_specific_translations
、translations
、register_translation
、xla_destructure
、TranslationRule
、TranslationContext
和XLAOp
。这些仍然被认为是已弃用的,并且将来在有更好的替代品可用时将被再次删除。有关讨论,请参阅 #19816。
弃用 & 移除
jax.numpy.linalg.solve()
现在显示对具有b.ndim > 1
的批处理 1D 求解的弃用警告。将来,这些将被视为批处理 2D 求解。无论数组的大小如何,将非标量数组转换为 Python 标量现在都会引发错误。以前,在大小为 1 的非标量数组的情况下会引发弃用警告。这遵循 NumPy 中的类似弃用。
以下之前已弃用的配置 API 已按照标准的 3 个月弃用周期(请参阅API 兼容性)删除。其中包括
jax.config.config
对象和jax.config
的define_*_state
和DEFINE_*
方法。
通过
import jax.config
导入jax.config
子模块已弃用。要配置 JAX,请使用import jax
,然后通过jax.config
引用 config 对象。最低 jaxlib 版本现在为 0.4.20。
jaxlib 0.4.25 (2024 年 2 月 26 日)#
jax 0.4.24 (2024 年 2 月 6 日)#
更改
JAX 降低到 StableHLO 不再依赖于物理设备。如果您的原语包装了 custom_partitioning 或 JAX 回调到降低规则中,即传递给
mlir.register_lowering
的rule
参数的函数,则将您的原语添加到jax._src.dispatch.prim_requires_devices_during_lowering
集中。这是必需的,因为 custom_partitioning 和 JAX 回调需要物理设备才能在降低期间创建Sharding
。这是一个临时状态,直到我们可以在没有物理设备的情况下创建Sharding
。jax.numpy.argsort()
和jax.numpy.sort()
现在支持stable
和descending
参数。形状多态性处理的几个更改(在
jax.experimental.jax2tf
和jax.experimental.export
中使用)更清晰的符号表达式漂亮打印 (#19227)
添加了指定维度变量的符号约束的能力。这使得形状多态性更具表现力,并提供了一种解决不等式推理限制的方法。请参阅 https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints。
通过添加符号约束 (#19235),我们现在认为来自不同范围的维度变量是不同的,即使它们具有相同的名称。来自不同范围的符号表达式不能交互,例如,在算术运算中。范围由
jax.experimental.jax2tf.convert()
、jax.experimental.export.symbolic_shape()
、jax.experimental.export.symbolic_args_specs()
引入。可以使用e.scope
读取符号表达式e
的范围,并将其传递到上述函数中,以指示它们在给定范围内构造符号表达式。请参阅 https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints。简化和更快的相等性比较,我们认为两个符号维度相等,如果它们的差异的标准化形式减少到 0 (#19231; 请注意,这可能会导致用户可见的行为更改)
改进了不确定的不等式比较的错误消息 (#19235)。
core.non_negative_dim
API(最近引入)已弃用,并引入了core.max_dim
和core.min_dim
(#18953) 以表达符号维度的max
和min
。您可以使用core.max_dim(d, 0)
代替core.non_negative_dim(d)
。shape_poly.is_poly_dim
已弃用,建议使用export.is_symbolic_dim
(#19282)。export.args_specs
已弃用,建议使用export.symbolic_args_specs ({jax-issue}
#19283`)。shape_poly.PolyShape
和jax2tf.PolyShape
已弃用,对多态形状规范使用字符串 (#19284)。JAX 默认的本机序列化版本现在为 9。这与
jax.experimental.jax2tf
和jax.experimental.export
相关。请参阅版本号说明。
重构了
jax.experimental.export
的 API。您现在应该使用from jax.experimental import export
,而不是from jax.experimental.export import export
。旧的导入方式将在 3 个月的弃用期内继续有效。具有
return_inverse = True
的jax.numpy.unique()
返回重塑为输入维度的逆索引,遵循numpy.unique()
在 NumPy 2.0 中的类似更改。jax.numpy.sign()
现在为非零复数输入返回x / abs(x)
。这与 NumPy 2.0 中numpy.sign()
的行为一致。具有
return_sign=True
的jax.scipy.special.logsumexp()
现在对复数符号使用 NumPy 2.0 约定x / abs(x)
。这与 SciPy v1.13 中scipy.special.logsumexp()
的行为一致。JAX 现在支持 bool DLPack 类型,用于导入和导出。以前,bool 值无法导入,并且导出为整数。
弃用 & 移除
许多以前已弃用的函数已删除,遵循标准的 3+ 个月弃用周期(请参阅API 兼容性)。其中包括
从
jax.core
中:TracerArrayConversionError
、TracerIntegerConversionError
、UnexpectedTracerError
、as_hashable_function
、collections
、dtypes
、lu
、map
、namedtuple
、partial
、pp
、ref
、safe_zip
、safe_map
、source_info_util
、total_ordering
、traceback_util
、tuple_delete
、tuple_insert
和zip
。从
jax.lax
中:dtypes
、itertools
、naryop
、naryop_dtype_rule
、standard_abstract_eval
、standard_naryop
、standard_primitive
、standard_unop
、unop
和unop_dtype_rule
。jax.linear_util
子模块及其所有内容。jax.prng
子模块及其所有内容。来自
jax.random
:PRNGKeyArray
、KeyArray
、default_prng_impl
、threefry_2x32
、threefry2x32_key
、threefry2x32_p
、rbg_key
和unsafe_rbg_key
。来自
jax.tree_util
:register_keypaths
、AttributeKeyPathEntry
和GetItemKeyPathEntry
。来自
jax.interpreters.xla
:backend_specific_translations
、translations
、register_translation
、xla_destructure
、TranslationRule
、TranslationContext
、axis_groups
、ShapedArray
、ConcreteArray
、AxisEnv
、backend_compile
和XLAOp
。来自
jax.numpy
:NINF
、NZERO
、PZERO
、row_stack
、issubsctype
、trapz
和in1d
。来自
jax.scipy.linalg
:tril
和triu
。
之前已弃用的方法
PRNGKeyArray.unsafe_raw_array
已被移除。请使用jax.random.key_data()
代替。bool(empty_array)
现在会引发错误,而不是返回False
。之前会引发弃用警告,并且遵循 NumPy 中的类似更改。对 mhlo MLIR 方言的支持已被弃用。JAX 不再使用 mhlo 方言,而是使用 stablehlo。将来会删除引用“mhlo”的 API。请改用“stablehlo”方言。
jax.random
:将批处理键直接传递给随机数生成函数(例如bits()
、gamma()
等)已被弃用,并将发出FutureWarning
。使用jax.vmap
进行显式批处理。jax.lax.tie_in()
已被弃用:自 JAX v0.2.0 以来,它一直是一个空操作。
jaxlib 0.4.24 (2024 年 2 月 6 日)#
更改
JAX 现在支持 CUDA 12.3 和 CUDA 11.8。已删除对 CUDA 12.2 的支持。
cost_analysis
现在可以与交叉编译的Compiled
对象一起使用(即,当使用带有拓扑对象的.lower().compile()
时,例如,从非 TPU 计算机编译用于 Cloud TPU)。添加了 CUDA Array Interface 导入支持(需要 jax 0.4.25)。
jax 0.4.23 (2023 年 12 月 13 日)#
jaxlib 0.4.23 (2023 年 12 月 13 日)#
修复了 GPU 编译器在编译期间导致冗长日志记录的错误。
jax 0.4.22 (2023 年 12 月 13 日)#
弃用
JAX 数组的
device_buffer
和device_buffers
属性已被弃用。显式缓冲区已被更灵活的数组分片接口取代,但可以通过以下方式恢复先前的输出arr.device_buffer
变为arr.addressable_data(0)
arr.device_buffers
变为[x.data for x in arr.addressable_shards]
jaxlib 0.4.22 (2023 年 12 月 13 日)#
jax 0.4.21 (2023 年 12 月 4 日)#
新特性
添加了
jax.nn.squareplus
。
更改
最低 jaxlib 版本现在为 0.4.19。
发布的 wheels 现在使用 clang 而不是 gcc 构建。
强制在调用
jax.distributed.initialize()
之前未初始化设备后端。自动执行云 TPU 环境中
jax.distributed.initialize()
的参数。
弃用
先前已弃用的
sym_pos
参数已从jax.scipy.linalg.solve()
中移除。请改用assume_a='pos'
。将
None
直接或在列表或元组中传递给jax.array()
或jax.asarray()
已被弃用,现在会引发FutureWarning
。目前它被转换为 NaN,未来会引发TypeError
。为了与
numpy.where
匹配,按关键字参数将condition
、x
和y
参数传递给jax.numpy.where
已被弃用。将无法转换为 JAX 数组的参数传递给
jax.numpy.array_equal()
和jax.numpy.array_equiv()
已被弃用,现在会引发DeprecationWaning
。目前,这些函数返回 False,将来会引发异常。JAX 数组的
device()
方法已被弃用。根据上下文,可以用以下方法之一代替jax.Array.devices()
返回数组使用的所有设备的集合。jax.Array.sharding
提供数组使用的分片配置。
jaxlib 0.4.21 (2023 年 12 月 4 日)#
更改
为了准备添加分布式 CPU 支持,JAX 现在将 CPU 设备与 GPU 和 TPU 设备相同对待,即
jax.devices()
包括分布式作业中存在的所有设备,甚至是不属于当前进程的设备。jax.local_devices()
仍然只包括属于当前进程的设备,因此如果对jax.devices()
的更改破坏了您的代码,您很可能想要改用jax.local_devices()
。CPU 设备现在在分布式作业中接收全局唯一的 ID 号;以前 CPU 设备会接收进程本地 ID 号。
每个 CPU 设备的
process_index
现在将与同一进程中的任何 GPU 或 TPU 设备匹配;以前 CPU 设备的process_index
始终为 0。
在 NVIDIA GPU 上,JAX 现在更喜欢 Jacobi SVD 求解器,用于最大 1024x1024 的矩阵。Jacobi 求解器似乎比非 Jacobi 版本更快。
Bug 修复
修复了将具有非有限值的数组传递给非对称特征分解时出现的错误/挂起 (#18226)。现在,具有非有限值的数组会生成充满 NaN 的数组作为输出。
jax 0.4.20 (2023 年 11 月 2 日)#
jaxlib 0.4.20 (2023 年 11 月 2 日)#
Bug 修复
修复了 E4M3 和 E5M2 float8 类型之间的一些类型混淆。
jax 0.4.19 (2023 年 10 月 19 日)#
新特性
添加了
jax.typing.DTypeLike
,可用于注释可转换为 JAX dtypes 的对象。添加了
jax.numpy.fill_diagonal
。
更改
JAX 现在需要 SciPy 1.9 或更高版本。
Bug 修复
在多控制器分布式 JAX 程序中,只有进程 0 会写入持久编译缓存条目。如果缓存位于 GCS 等网络文件系统上,则可以修复写入争用。
在确定已安装的这些库的版本是否至少与 JAX 构建时使用的版本一样新时,cusolver 和 cufft 的版本检查不再考虑补丁版本。
jaxlib 0.4.19 (2023 年 10 月 19 日)#
更改
如果安装了通过 pip 安装的 NVIDIA CUDA 库(nvidia-… 包),包括
LD_LIBRARY_PATH
中命名的安装,jaxlib 现在始终优先使用它们。如果这导致问题,并且目的是使用系统安装的 CUDA,则解决方法是删除 pip 安装的 CUDA 库包。
jax 0.4.18 (2023 年 10 月 6 日)#
jaxlib 0.4.18 (2023 年 10 月 6 日)#
更改
CUDA jaxlibs 现在依赖用户来安装兼容的 NCCL 版本。如果使用推荐的
cuda12_pip
安装,则应自动安装 NCCL。当前,需要 NCCL 2.16 或更高版本。我们现在提供 Linux aarch64 wheels,包括有和没有 NVIDIA GPU 支持的。
jax.Array.item()
现在支持可选的索引参数。
弃用
jax.lax
中的许多内部实用程序和意外导出已被弃用,并将在未来的版本中删除。jax.lax.dtypes
:请改用jax.dtypes
。jax.lax.itertools
:请改用itertools
。naryop
、naryop_dtype_rule
、standard_abstract_eval
、standard_naryop
、standard_primitive
、standard_unop
、unop
和unop_dtype_rule
是内部实用程序,现在已被弃用,没有替代方法。
Bug 修复
修复了 Cloud TPU 回归,其中由于 smem 导致编译 OOM。
jax 0.4.17 (2023 年 10 月 3 日)#
新功能
添加了新的
jax.numpy.bitwise_count()
函数,与最近添加到 NumPy 的类似函数的 API 匹配。
弃用
删除了已弃用的模块
jax.abstract_arrays
及其所有内容。jax.random
中的命名键构造函数已被弃用。请将impl
参数传递给jax.random.PRNGKey()
或jax.random.key()
代替random.threefry2x32_key(seed)
变为random.PRNGKey(seed, impl='threefry2x32')
random.rbg_key(seed)
变为random.PRNGKey(seed, impl='rbg')
random.unsafe_rbg_key(seed)
变为random.PRNGKey(seed, impl='unsafe_rbg')
更改
CUDA:JAX 现在验证它找到的 CUDA 库至少与 JAX 构建时使用的 CUDA 库一样新。如果找到较旧的库,JAX 会引发异常,因为这比神秘的故障和崩溃更好。
删除了“未找到 GPU/TPU”警告。相反,如果在 Linux 上发现 NVIDIA GPU 或 Google TPU 但未使用,并且未指定
--jax_platforms
,则发出警告。jax.scipy.stats.mode()
现在如果模式是在大小为 0 的轴上获取的,则返回 0 计数,与 SciPy 1.11 中的scipy.stats.mode
的行为匹配。大多数
jax.numpy
函数和属性现在都有完全定义的类型存根。以前,许多这些函数和属性都被静态类型检查器(如mypy
和pytype
)视为Any
。
jaxlib 0.4.17 (2023 年 10 月 3 日)#
更改
在此版本中添加了 Python 3.12 wheels。
CUDA 12 wheels 现在需要 CUDA 12.2 或更高版本以及 cuDNN 8.9.4 或更高版本。
Bug 修复
修复了初始化 JAX CPU 后端时来自 ABSL 的日志垃圾邮件。
jax 0.4.16 (2023 年 9 月 18 日)#
更改
添加了
jax.numpy.ufunc
,以及jax.numpy.frompyfunc()
,它可以将任何标量值函数转换为类似于numpy.ufunc()
的对象,并具有outer()
、reduce()
、accumulate()
、at()
和reduceat()
等方法 (#17054)。不在 IPython 下运行时:当引发异常时,JAX 现在会从回溯中过滤掉其所有内部帧。(没有先前出现的“未过滤的堆栈跟踪”。)这应该会产生更友好的回溯。有关示例,请参见此处。可以通过设置
JAX_TRACEBACK_FILTERING=remove_frames
(对于两个单独的未过滤/过滤回溯,这是旧行为)或JAX_TRACEBACK_FILTERING=off
(对于一个未过滤的回溯)来更改此行为。jax2tf 默认序列化版本现在为 7,这引入了新的形状安全断言。
传递给
jax.sharding.Mesh
的设备应该是可哈希的。这专门适用于模拟设备或用户创建的设备。jax.devices()
已经是可哈希的。
重大更改
jax2tf 现在默认使用原生序列化。有关详细信息以及覆盖默认值的机制,请参见jax2tf 文档。
选项
--jax_coordination_service
已被移除。它现在始终为True
。jax.jaxpr_util
已从公共 JAX 命名空间中移除。JAX_USE_PJRT_C_API_ON_TPU
不再有效(即,它始终默认为 true)。2021 年 12 月引入的向后兼容性标志
--jax_host_callback_ad_transforms
已被移除。
弃用
根据 NumPy NEP-52,已弃用多个
jax.numpy
APIjax.numpy.NINF
已被弃用。请改用-jax.numpy.inf
。jax.numpy.PZERO
已被弃用。请改用0.0
。jax.numpy.NZERO
已被弃用。请改用-0.0
。jax.numpy.issubsctype(x, t)
已被弃用。请使用jax.numpy.issubdtype(x.dtype, t)
。jax.numpy.row_stack
已被弃用。请改用jax.numpy.vstack
。jax.numpy.in1d
已被弃用。请改用jax.numpy.isin
。jax.numpy.trapz
已被弃用。请改用jax.scipy.integrate.trapezoid
。
按照 SciPy 的做法,
jax.scipy.linalg.tril
和jax.scipy.linalg.triu
已被弃用。请改用jax.numpy.tril
和jax.numpy.triu
。jax.lax.prod
在 JAX v0.4.11 中弃用后已被移除。请改用内置的math.prod
。与定义自定义 JAX 基元的 HLO 降级规则相关的
jax.interpreters.xla
中的多个导出已被弃用。应改用jax.interpreters.mlir
中的 StableHLO 降级实用程序定义自定义基元。以下先前已弃用的函数在三个月的弃用期后已被移除
jax.abstract_arrays.ShapedArray
:请使用jax.core.ShapedArray
。jax.abstract_arrays.raise_to_shaped
:请使用jax.core.raise_to_shaped
。jax.numpy.alltrue
:请使用jax.numpy.all
。jax.numpy.sometrue
:请使用jax.numpy.any
。jax.numpy.product
:请使用jax.numpy.prod
。jax.numpy.cumproduct
:请使用jax.numpy.cumprod
。
弃用/移除
内部子模块
jax.prng
现在已被弃用。其内容可在jax.extend.random
中找到。内部子模块路径
jax.linear_util
已被弃用。请改用jax.extend.linear_util
(jax.extend:扩展模块的一部分)jax.random.PRNGKeyArray
和jax.random.KeyArray
已被弃用。请使用jax.Array
进行类型注释,使用jax.dtypes.issubdtype(arr.dtype, jax.dtypes.prng_key)
进行类型化 prng 键的运行时检测。方法
PRNGKeyArray.unsafe_raw_array
已被弃用。请使用jax.random.key_data()
代替。jax.experimental.pjit.with_sharding_constraint
已被弃用。请改用jax.lax.with_sharding_constraint
。内部实用程序
jax.core.is_opaque_dtype
和jax.core.has_opaque_dtype
已被移除。不透明 dtypes 已重命名为扩展 dtypes;请改用jnp.issubdtype(dtype, jax.dtypes.extended)
(自 jax v0.4.14 起可用)。实用程序
jax.interpreters.xla.register_collective_primitive
已被移除。此实用程序在最近的 JAX 版本中没有任何作用,并且可以安全地删除对其的调用。内部子模块路径
jax.linear_util
已被弃用。请改用jax.extend.linear_util
(jax.extend:扩展模块的一部分)
jaxlib 0.4.16 (2023 年 9 月 18 日)#
更改
通过实验性 jax sparse API 进行的稀疏 CSR 矩阵乘法不再在 NVIDIA GPU 上使用确定性算法。进行此更改是为了提高与 CUDA 12.2.1 的兼容性。
Bug 修复
修复了由于与乱序部分和 IMAGE_REL_AMD64_ADDR32NB 重定位相关的致命 LLVM 错误导致的 Windows 上的崩溃 (https://github.com/openxla/xla/commit/cb732a921f0c4184995cbed82394931011d12bd4)。
jax 0.4.14 (2023 年 7 月 27 日)#
更改
jax.jit
将donate_argnames
作为参数。其语义与static_argnames
相似。如果未提供 donate_argnums 和 donate_argnames,则不会捐赠任何参数。如果未提供 donate_argnums 但提供了 donate_argnames,反之亦然,则 JAX 使用inspect.signature(fun)
来查找与 donate_argnames 对应的任何位置参数(反之亦然)。如果同时提供了 donate_argnums 和 donate_argnames,则不使用 inspect.signature,并且只会捐赠 donate_argnums 或 donate_argnames 中列出的实际参数。jax.random.gamma()
已重构为更高效的算法,具有更强大的端点行为 (#16779)。这意味着对于给定的key
,在 JAX v0.4.13 和 v0.4.14 之间返回的值序列将为gamma
和相关的采样器(包括jax.random.ball()
、jax.random.beta()
、jax.random.chisquare()
、jax.random.dirichlet()
、jax.random.generalized_normal()
、jax.random.loggamma()
、jax.random.t()
)而更改。
删除
自
in_axis_resources
和out_axis_resources
被弃用以来已经超过 3 个月,因此已从 pjit 中删除。请使用in_shardings
和out_shardings
作为替代。这是一个安全且简单的名称替换。它不会更改当前 pjit 的任何语义,也不会破坏任何代码。您仍然可以将PartitionSpecs
传递给 in_shardings 和 out_shardings。
弃用
根据 https://jax.net.cn/en/latest/deprecation.html,已停止支持 Python 3.8
根据 https://jax.net.cn/en/latest/deprecation.html,JAX 现在需要 NumPy 1.22 或更高版本
在 JAX 0.4.7 版本中弃用后,不再支持按位置将可选参数传递给
jax.numpy.ndarray.at()
。例如,请使用x.at[i].get(indices_are_sorted=True)
,而不是x.at[i].get(True)
以下
jax.Array
方法在 JAX v0.4.5 中弃用后已被删除jax.Array.broadcast
:请改用jax.lax.broadcast()
。jax.Array.broadcast_in_dim
:请改用jax.lax.broadcast_in_dim()
。jax.Array.split
:请改用jax.numpy.split()
。
以下 API 在先前弃用后已被删除
jax.ad
:请使用jax.interpreters.ad
。jax.curry
:请使用curry = lambda f: partial(partial, f)
。jax.partial_eval
:请使用jax.interpreters.partial_eval
。jax.pxla
:请使用jax.interpreters.pxla
。jax.xla
:请使用jax.interpreters.xla
。jax.ShapedArray
:请使用jax.core.ShapedArray
。jax.interpreters.pxla.device_put
:请使用jax.device_put()
。jax.interpreters.pxla.make_sharded_device_array
:请使用jax.make_array_from_single_device_arrays()
。jax.interpreters.pxla.ShardedDeviceArray
:请使用jax.Array
。jax.numpy.DeviceArray
:请使用jax.Array
。jax.stages.Compiled.compiler_ir
:请使用jax.stages.Compiled.as_text()
。
重大更改
JAX 现在需要 ml_dtypes 0.2.0 或更高版本。
为了修复一个边缘情况,如果第二个和第三个参数是可调用的,即使其他操作数也是可调用的,则对具有五个参数的
jax.lax.cond()
的调用将始终解析为“公共操作数”cond
行为(如文档所述)。请参阅 #16413。已删除已弃用的配置选项
jax_array
和jax_jit_pjit_api_merge
,它们没有任何作用。这些选项在许多版本中默认都为 true。
新功能
JAX 现在支持一个配置标志 –jax_serialization_version 和一个 JAX_SERIALIZATION_VERSION 环境变量来控制序列化版本 (#16746)。
如果序列化版本至少为 7,则存在形状多态性的 jax2tf 现在会生成代码来检查某些形状约束。请参阅 https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism。
jaxlib 0.4.14 (2023 年 7 月 27 日)#
弃用
根据 https://jax.net.cn/en/latest/deprecation.html,已停止支持 Python 3.8
jax 0.4.13 (2023 年 6 月 22 日)#
更改
jax.jit
现在允许将None
传递给in_shardings
和out_shardings
。语义如下对于 in_shardings,JAX 会将其标记为复制,但此行为将来可能会更改。
对于 out_shardings,我们将依靠 XLA GSPMD 分区器来确定输出分片。
jax.experimental.pjit.pjit
也允许将None
传递给in_shardings
和out_shardings
。语义如下如果未提供 mesh 上下文管理器,则 JAX 可以自由选择其想要的任何分片。
对于 in_shardings,JAX 会将其标记为复制,但此行为将来可能会更改。
对于 out_shardings,我们将依靠 XLA GSPMD 分区器来确定输出分片。
如果提供了 mesh 上下文管理器,则 None 将意味着该值将在 mesh 的所有设备上复制。
Executable.cost_analysis() 适用于 Cloud TPU
如果正在使用非白名单
jaxlib
插件,则添加了一个警告。添加了
jax.tree_util.tree_leaves_with_path
。None
不是jax.experimental.multihost_utils.host_local_array_to_global_array
或jax.experimental.multihost_utils.global_array_to_host_local_array
的有效输入。如果您想复制您的输入,请使用jax.sharding.PartitionSpec()
。
Bug 修复
修复了 CUDA 12 版本中不正确的 wheel 名称 (#16362);正确的 wheel 名称是
cudnn89
,而不是cudnn88
。
弃用
jax.experimental.jax2tf.convert()
的native_serialization_strict_checks
参数已弃用,转而使用新的native_serializaation_disabled_checks
(#16347)。
jaxlib 0.4.13 (2023 年 6 月 22 日)#
更改
将 Windows 仅 CPU wheel 添加到
jaxlib
Pypi 版本。
Bug 修复
__cuda_array_interface__
在以前的 jaxlib 版本中已损坏,现在已修复 (#16440)。并发 CUDA 内核跟踪现在默认在 NVIDIA GPU 上启用。
jax 0.4.12 (2023 年 6 月 8 日)#
更改
弃用
jax.abstract_arrays
及其内容现已弃用。请参阅 :mod:jax.core
中的相关功能。jax.numpy.alltrue
:请使用jax.numpy.all
。这遵循了 NumPy 1.25.0 中numpy.alltrue
的弃用。jax.numpy.sometrue
:请使用jax.numpy.any
。这遵循了 NumPy 1.25.0 中numpy.sometrue
的弃用。jax.numpy.product
:请使用jax.numpy.prod
。这遵循了 NumPy 1.25.0 中numpy.product
的弃用。jax.numpy.cumproduct
:请使用jax.numpy.cumprod
。这遵循了 NumPy 1.25.0 中numpy.cumproduct
的弃用。jax.sharding.OpShardingSharding
已被删除,因为它已被弃用 3 个月。
jaxlib 0.4.12 (2023 年 6 月 8 日)#
更改
包括 Hopper (SM 版本 9.0+) GPU 的 PTX/SASS。以前版本的 jaxlib 应该可以在 Hopper 上工作,但在第一次执行 JAX 操作时会有很长的 JIT 编译延迟。
Bug 修复
修复了 Python 3.11 下 JAX 生成的 Python 回溯中不正确的源行信息。
修复了在 JAX 生成的 Python 回溯中打印帧的局部变量时发生的崩溃 (#16027)。
jax 0.4.11 (2023 年 5 月 31 日)#
弃用
根据 API 兼容性 策略,以下 API 在 3 个月的弃用期后已被删除
jax.experimental.PartitionSpec
:请使用jax.sharding.PartitionSpec
。jax.experimental.maps.Mesh
:请使用jax.sharding.Mesh
jax.experimental.pjit.NamedSharding
:请使用jax.sharding.NamedSharding
。jax.experimental.pjit.PartitionSpec
:请使用jax.sharding.PartitionSpec
。jax.experimental.pjit.FROM_GDA
。请改为传递分片的jax.Array
对象作为输入,并删除pjit
的可选in_shardings
参数。jax.interpreters.pxla.PartitionSpec
:请使用jax.sharding.PartitionSpec
。jax.interpreters.pxla.Mesh
:请使用jax.sharding.Mesh
jax.interpreters.xla.Buffer
:请使用jax.Array
。jax.interpreters.xla.Device
:请使用jax.Device
。jax.interpreters.xla.DeviceArray
:请使用jax.Array
。jax.interpreters.xla.device_put
:请使用jax.device_put
。jax.interpreters.xla.xla_call_p
:请使用jax.experimental.pjit.pjit_p
。已删除
with_sharding_constraint
的axis_resources
参数。请改用shardings
。
jaxlib 0.4.11 (2023 年 5 月 31 日)#
更改
将
memory_stats()
方法添加到Device
。如果支持,这将返回一个包含字符串统计名称和 int 值的字典,例如"bytes_in_use"
,如果平台不支持内存统计信息,则返回 None。返回的确切统计信息可能因平台而异。目前仅在 Cloud TPU 上实现。重新添加了对 CPU 设备上 Python 缓冲区协议 (
memoryview
) 的支持。
jax 0.4.10 (2023 年 5 月 11 日)#
jaxlib 0.4.10 (2023 年 5 月 11 日)#
更改
修复了
'apple-m1' is not a recognized processor for this target (ignoring processor)
问题,该问题阻止了先前版本在 Mac M1 上运行。
jax 0.4.9 (2023 年 5 月 9 日)#
更改
已删除标志 experimental_cpp_jit、experimental_cpp_pjit 和 experimental_cpp_pmap。它们现在始终处于启用状态。
提高了 TPU 上奇异值分解 (SVD) 的准确性(需要 jaxlib 0.4.9)。
弃用
jax.experimental.gda_serialization
已弃用,并已重命名为jax.experimental.array_serialization
。请更改您的导入以使用jax.experimental.array_serialization
。pjit 的
in_axis_resources
和out_axis_resources
参数已弃用。请分别使用in_shardings
和out_shardings
。已删除函数
jax.numpy.msort
。它自 JAX v0.4.1 起已被弃用。请改用jnp.sort(a, axis=0)
。已从
jax.xla_computation
中删除in_parts
和out_parts
参数,因为它们仅与 sharded_jit 一起使用,而 sharded_jit 早已消失。已从
jax.xla_computation
中删除instantiate_const_outputs
参数,因为它已经很长时间未使用了。
jaxlib 0.4.9 (2023 年 5 月 9 日)#
jax 0.4.8 (2023 年 3 月 29 日)#
重大更改
Cloud TPU 运行时的主要组件已升级。这在 Cloud TPU 上启用了以下新功能
jax.debug.print()
、jax.debug.callback()
和jax.debug.breakpoint()
现在可以在 Cloud TPU 上工作自动 TPU 内存碎片整理
新的运行时组件不再支持 Cloud TPU 上的
jax.experimental.host_callback()
。如果新的jax.debug
API 不足以满足您的用例,请在 JAX 问题跟踪器上提交问题。旧的运行时组件将在至少未来三个月内可用,方法是设置环境变量
JAX_USE_PJRT_C_API_ON_TPU=false
。如果您发现出于任何原因需要禁用新运行时,请在 JAX 问题跟踪器上告诉我们。
更改
最低 jaxlib 版本已从 0.4.6 提升到 0.4.7。
弃用
已停止支持 CUDA 11.4。JAX GPU wheel 仅支持 CUDA 11.8 和 CUDA 12。如果 jaxlib 从源代码构建,则较旧的 CUDA 版本可能会工作。
pmap 的
global_arg_shapes
参数仅适用于 sharded_jit,并且已从 pmap 中删除。请迁移到 pjit 并从 pmap 中删除 global_arg_shapes。
jax 0.4.7 (2023 年 3 月 27 日)#
更改
根据 https://jax.net.cn/en/latest/jax_array_migration.html#jax-array-migration,
jax.config.jax_array
无法再被禁用。jax.config.jax_jit_pjit_api_merge
无法再被禁用。jax.experimental.jax2tf.convert()
现在支持native_serialization
参数,该参数使用 JAX 的本机降级到 StableHLO 来获取整个 JAX 函数的 StableHLO 模块,而不是将每个 JAX 原语降级到 TensorFlow op。这简化了内部结构,并提高了您序列化的内容与 JAX 本机语义相匹配的信心。请参阅 文档。作为此更改的一部分,配置标志--jax2tf_default_experimental_native_lowering
已重命名为--jax2tf_native_serialization
。JAX 现在依赖于
ml_dtypes
,其中包含 NumPy 类型的定义,例如 bfloat16。这些定义以前是 JAX 的内部定义,但已拆分为一个单独的软件包,以便于与其他项目共享。JAX 现在需要 NumPy 1.21 或更高版本和 SciPy 1.7 或更高版本。
弃用
类型
jax.numpy.DeviceArray
已弃用。请改用jax.Array
,它是它的别名。类型
jax.interpreters.pxla.ShardedDeviceArray
已弃用。请改用jax.Array
。按位置将附加参数传递给
jax.numpy.ndarray.at()
已弃用。例如,请使用x.at[i].get(indices_are_sorted=True)
,而不是x.at[i].get(True)
jax.interpreters.xla.device_put
已弃用。请使用jax.device_put
。jax.interpreters.pxla.device_put
已弃用。请使用jax.device_put
。jax.experimental.pjit.FROM_GDA
已弃用。请传入分片的 jax.Arrays 作为输入,并从 pjit 中删除in_shardings
参数,因为它是可选的。
jaxlib 0.4.7 (2023 年 3 月 27 日)#
更改
jaxlib 现在依赖于
ml_dtypes
,其中包含 NumPy 类型的定义,例如 bfloat16。这些定义以前是 JAX 的内部定义,但已拆分为一个单独的软件包,以便于与其他项目共享。
jax 0.4.6 (2023 年 3 月 9 日)#
更改
jax.tree_util
现在包含一组 API,允许用户为其自定义 pytree 节点定义键。这包括tree_flatten_with_path
,它展平树,不仅返回每个叶子,还返回它们的键路径。tree_map_with_path
,它可以映射一个将键路径作为参数的函数。register_pytree_with_keys
,用于注册键路径和叶子在自定义 pytree 节点中的外观。keystr
,它很好地打印键路径。
jax2tf.call_tf()
有一个新的参数output_shape_dtype
(默认None
),可用于声明结果的输出形状和类型。这使jax2tf.call_tf()
能够在形状多态性的存在下工作。(#14734)。
弃用
jax.tree_util
中的旧键路径 API 已弃用,并将于 2023 年 3 月 10 日起 3 个月后删除register_keypaths
:请改用jax.tree_util.register_pytree_with_keys()
。AttributeKeyPathEntry
:请改用GetAttrKey
。GetitemKeyPathEntry
:请改用SequenceKey
或DictKey
。
jaxlib 0.4.6 (2023 年 3 月 9 日)#
jax 0.4.5 (2023 年 3 月 2 日)#
弃用
jax.sharding.OpShardingSharding
已重命名为jax.sharding.GSPMDSharding
。jax.sharding.OpShardingSharding
将于 2023 年 2 月 17 日起 3 个月后删除。以下
jax.Array
方法已弃用,将于 2023 年 2 月 23 日起 3 个月后删除jax.Array.broadcast
:请改用jax.lax.broadcast()
。jax.Array.broadcast_in_dim
:请改用jax.lax.broadcast_in_dim()
。jax.Array.split
:请改用jax.numpy.split()
。
jax 0.4.4 (2023 年 2 月 16 日)#
更改
已合并
jit
和pjit
的实现。合并 jit 和 pjit 更改了 JAX 的内部结构,而不影响 JAX 的公共 API。以前,jit
是一种最终样式原语。最终样式意味着 jaxpr 的创建尽可能延迟,并且转换堆叠在彼此之上。通过jit
-pjit
实现合并,jit
成为一种初始样式原语,这意味着我们尽早跟踪到 jaxpr。有关更多信息,请参阅 autodidax 中的此部分。移动到初始样式应该简化 JAX 的内部结构,并使开发动态形状等功能更容易。您只能通过环境变量禁用它,例如os.environ['JAX_JIT_PJIT_API_MERGE'] = '0'
。合并必须通过环境变量禁用,因为它在导入时影响 JAX,因此需要在导入 jax 之前禁用它。已弃用
with_sharding_constraint
的axis_resources
参数。请改用shardings
。如果您使用axis_resources
作为 arg,则无需进行任何更改。如果您使用它作为 kwarg,请改用shardings
。axis_resources
将在 2023 年 2 月 13 日起 3 个月后删除。添加了
jax.typing
模块,其中包含 JAX 函数的类型注释工具。以下名称已被弃用
jax.xla.Device
和jax.interpreters.xla.Device
:请使用jax.Device
。jax.experimental.maps.Mesh
。请改用jax.sharding.Mesh
。jax.experimental.pjit.NamedSharding
:请使用jax.sharding.NamedSharding
。jax.experimental.pjit.PartitionSpec
:请使用jax.sharding.PartitionSpec
。jax.interpreters.pxla.Mesh
:请使用jax.sharding.Mesh
。jax.interpreters.pxla.PartitionSpec
:请使用jax.sharding.PartitionSpec
。
重大更改
与相应的 NumPy API 一致,现在要求像 :func:
jax.numpy.sum
这样的归约函数的initial
参数必须是标量。之前将输出广播到非标量initial
值的行为是一个无意的实现细节 (#14446)。
jaxlib 0.4.4 (2023 年 2 月 16 日)#
重大更改
默认
jaxlib
版本已删除对 NVIDIA Kepler 系列 GPU 的支持。如果需要 Kepler 支持,仍然可以从源代码构建具有 Kepler 支持的jaxlib
(通过build.py
的--cuda_compute_capabilities=sm_35
选项),但请注意 CUDA 12 已完全停止对 Kepler GPU 的支持。
jax 0.4.3 (2023 年 2 月 8 日)#
重大更改
删除了
jax.scipy.linalg.polar_unitary()
,这是一个已弃用的 JAX 扩展到 scipy API。请改用jax.scipy.linalg.polar()
。
更改
jaxlib 0.4.3 (2023 年 2 月 8 日)#
jax.Array
现在具有非阻塞is_ready()
方法,如果数组已准备好,则返回True
(另请参阅jax.block_until_ready()
)。
jax 0.4.2 (2023 年 1 月 24 日)#
重大更改
更改
jaxlib 0.4.2 (2023 年 1 月 24 日)#
更改
设置 JAX_USE_PJRT_C_API_ON_TPU=1 以启用新的 Cloud TPU 运行时,该运行时具有自动设备内存碎片整理功能。
jax 0.4.1 (2022 年 12 月 13 日)#
更改
根据 JAX 的 Python 和 NumPy 版本支持策略,已停止支持 Python 3.7。
我们引入了
jax.Array
,它是一种统一的数组类型,涵盖 JAX 中的DeviceArray
、ShardedDeviceArray
和GlobalDeviceArray
类型。jax.Array
类型有助于使并行性成为 JAX 的核心功能,简化和统一 JAX 的内部结构,并允许我们统一jit
和pjit
。在 JAX 0.4 中默认启用了jax.Array
,并对pjit
API 进行了一些重大更改。jax.Array 迁移指南可以帮助您将代码库迁移到jax.Array
。您还可以查看 分布式数组和自动并行化教程,以了解新概念。PartitionSpec
和Mesh
现在已超出实验阶段。新的 API 端点是jax.sharding.PartitionSpec
和jax.sharding.Mesh
。jax.experimental.maps.Mesh
和jax.experimental.PartitionSpec
已弃用,并将于 3 个月后删除。with_sharding_constraint
的新公共端点是jax.lax.with_sharding_constraint
。如果将 ABSL 标志与
jax.config
一起使用,则在最初从 ABSL 标志填充 JAX 配置选项后,将不再读取或写入 ABSL 标志值。此更改提高了读取jax.config
选项的性能,这些选项在 JAX 中被广泛使用。jax2tf.call_tf 函数现在对 TF 降低使用与嵌入 JAX 计算所用平台相同的第一个 TF 设备。之前,它使用的是 JAX 默认后端的第 0 个设备。
许多
jax.numpy
函数现在将其参数标记为仅限位置参数,与 NumPy 匹配。现在已弃用
jnp.msort
,原因是 numpy 1.24 中已弃用np.msort
。根据 API 兼容性 策略,它将在未来的版本中删除。可以用jnp.sort(a, axis=0)
替换它。
jaxlib 0.4.1(2022 年 12 月 13 日)#
更改
根据 JAX 的 Python 和 NumPy 版本支持策略,已停止支持 Python 3.7。
XLA_PYTHON_CLIENT_MEM_FRACTION=.XX
的行为已更改为分配总 GPU 内存的 XX%,而不是以前使用当前可用 GPU 内存来计算预分配的行为。有关更多详细信息,请参阅 GPU 内存分配。已删除已弃用的方法
.block_host_until_ready()
。请改用.block_until_ready()
。
jax 0.4.0(2022 年 12 月 12 日)#
该版本已被撤回。
jaxlib 0.4.0(2022 年 12 月 12 日)#
该版本已被撤回。
jax 0.3.25(2022 年 11 月 15 日)#
更改
jax.numpy.linalg.pinv()
现在支持hermitian
选项。jax.scipy.linalg.hessenberg()
现在仅在 CPU 上受支持。需要 jaxlib > 0.3.24。添加了新函数
jax.lax.linalg.hessenberg()
、jax.lax.linalg.tridiagonal()
和jax.lax.linalg.householder_product()
。Householder 约简当前仅限于 CPU,而三对角约简仅在 CPU 和 GPU 上受支持。对于非正方形矩阵,
svd
和jax.numpy.linalg.pinv
的梯度现在计算得更经济。
重大更改
删除了
jax_experimental_name_stack
配置选项。将字符串
axis_names
参数转换为jax.experimental.maps.Mesh
构造函数中的单例元组,而不是将字符串解包为字符轴名称序列。
jaxlib 0.3.25(2022 年 11 月 15 日)#
更改
添加了对 CPU 和 GPU 上的三对角约简的支持。
添加了对 CPU 上的上 Hessenberg 约简的支持。
错误
修复了一个错误,该错误意味着 JAX 捕获的回溯中的帧在 Python 3.10+ 下被错误地映射到源行
jax 0.3.24(2022 年 11 月 4 日)#
更改
JAX 应该可以更快地导入。我们现在延迟导入 scipy,这占 JAX 导入时间的很大一部分。
设置环境变量
JAX_PERSISTENT_CACHE_MIN_COMPILE_TIME_SECS=$N
可用于限制写入持久缓存的缓存条目数。默认情况下,编译需要 1 秒或更长时间的计算将被缓存。如果在 TPU 上未指定顺序,则
pmap
使用的默认设备顺序现在与单进程作业的jax.devices()
匹配。以前,这两个顺序不同,这可能会导致不必要的副本或内存不足错误。要求顺序一致可以简化问题。
重大更改
jax.numpy.gradient()
现在的行为类似于jax.numpy
中的大多数其他函数,并且禁止传递列表或元组来代替数组 (#12958)jax.numpy.linalg
和jax.numpy.fft
中的函数现在统一要求输入为类数组:即,列表和元组不能代替数组使用。部分 #7737。
弃用
jax.sharding.MeshPspecSharding
已重命名为jax.sharding.NamedSharding
。jax.sharding.MeshPspecSharding
名称将在 3 个月后删除。
jaxlib 0.3.24(2022 年 11 月 4 日)#
更改
缓冲区捐赠现在可以在 CPU 上运行。这可能会破坏在 CPU 上标记缓冲区进行捐赠但依赖于未实现捐赠的代码。
jax 0.3.23(2022 年 10 月 12 日)#
更改
更新 Colab TPU 驱动程序版本以用于新的 jaxlib 版本。
jax 0.3.22(2022 年 10 月 11 日)#
更改
在 TPU 初始化中添加
JAX_PLATFORMS=tpu,cpu
作为默认设置,以便如果 TPU 无法初始化,JAX 将引发错误而不是回退到 CPU。设置JAX_PLATFORMS=''
以覆盖此行为并自动选择可用的后端(原始默认值),或者设置JAX_PLATFORMS=cpu
以始终使用 CPU,无论 TPU 是否可用。
弃用
JAX v0.3.8 中弃用的几个测试实用程序现在已从
jax.test_util
中删除。
jaxlib 0.3.22(2022 年 10 月 11 日)#
jax 0.3.21(2022 年 9 月 30 日)#
jax 0.3.20(2022 年 9 月 28 日)#
jaxlib 0.3.20(2022 年 9 月 28 日)#
jax 0.3.19(2022 年 9 月 27 日)#
修复了所需的 jaxlib 版本。
jax 0.3.18(2022 年 9 月 26 日)#
更改
提前降低和编译功能(在 #7733 中跟踪)是稳定和公共的。请参阅 概述和
jax.stages
的 API 文档。引入了
jax.Array
,旨在用于 JAX 中数组类型的isinstance
检查和类型注释。请注意,这包括对jax.numpy.ndarray
的isinstance
工作方式的一些细微更改,因为jax.numpy.ndarray
现在是jax.Array
的简单别名。
重大更改
不再将
jax._src
导入到公共jax
命名空间中。这可能会破坏正在使用 JAX 内部组件的用户。jax.soft_pmap
已被删除。请改用pjit
或xmap
。jax.soft_pmap
没有文档记录。如果它有文档记录,则会提供弃用期。
jax 0.3.17(2022 年 8 月 31 日)#
错误
修复了指数为零的
lax.pow
梯度中的极端情况问题 (#12041)
重大更改
jax.checkpoint()
,也称为jax.remat()
,在之前的版本弃用后,不再支持concrete
选项;请参阅 JEP 11830。
更改
添加了
jax.pure_callback()
,它允许从编译的函数(例如,使用jax.jit
或jax.pmap
修饰的函数)回调到纯 Python 函数。
弃用
已删除已弃用的
DeviceArray.tile()
方法。使用jax.numpy.tile()
(#11944)。DeviceArray.to_py()
已被弃用。请改用np.asarray(x)
。
jax 0.3.16#
重大更改
根据 弃用策略,已删除对 NumPy 1.19 的支持。请升级到 NumPy 1.20 或更高版本。
更改
添加了
jax.debug
,其中包括用于运行时值调试的实用程序,例如jax.debug.print()
和jax.debug.breakpoint()
。为 运行时值调试 添加了新文档
弃用
jax.mask()
jax.shapecheck()
API 已被删除。请参阅 #11557。jax.experimental.loops
已被删除。有关替代 API,请参阅 #10278。jax.tree_util.tree_multimap()
已被删除。自 JAX 版本 0.3.5 以来已被弃用,并且jax.tree_util.tree_map()
是直接替代品。删除了
jax.experimental.stax
;它长期以来都是jax.example_libraries.stax
的已弃用别名。删除了
jax.experimental.optimizers
;它长期以来都是jax.example_libraries.optimizers
的已弃用别名。jax.checkpoint()
,也称为jax.remat()
,具有默认情况下打开的新实现,这意味着旧实现已被弃用;请参阅 JEP 11830。
jax 0.3.15(2022 年 7 月 22 日)#
更改
JaxTestCase
和JaxTestLoader
已从jax.test_util
中删除。这些类自 v0.3.1 以来已被弃用 (#11248)。添加了
jax.scipy.gaussian_kde
(#11237)。JAX 数组和内置集合 (
dict
、list
、set
、tuple
) 之间的二元运算现在在所有情况下都会引发TypeError
。以前,某些情况(特别是相等和不等式)会返回与 NumPy 中类似运算不一致的布尔标量 (#11234)。作为顶级 JAX 包导入访问的几个
jax.tree_util
例程现在已被弃用,并将根据 API 兼容性 策略在未来的 JAX 版本中删除已弃用
jax.treedef_is_leaf()
,转而使用jax.tree_util.treedef_is_leaf()
已弃用
jax.tree_flatten()
,转而使用jax.tree_util.tree_flatten()
已弃用
jax.tree_leaves()
,转而使用jax.tree_util.tree_leaves()
已弃用
jax.tree_structure()
,转而使用jax.tree_util.tree_structure()
已弃用
jax.tree_transpose()
,转而使用jax.tree_util.tree_transpose()
已弃用
jax.tree_unflatten()
,转而使用jax.tree_util.tree_unflatten()
已弃用
jax.scipy.linalg.solve()
的sym_pos
参数,转而使用assume_a='pos'
,这与scipy.linalg.solve()
中的类似弃用一致。
jaxlib 0.3.15(2022 年 7 月 22 日)#
jax 0.3.14(2022 年 6 月 27 日)#
重大更改
jax.experimental.compilation_cache.initialize_cache()
不再支持max_cache_size_ bytes
,也不会将其作为输入。当平台初始化失败时,
JAX_PLATFORMS
现在会引发异常。
更改
修复了与 NumPy 1.23 的兼容性问题。
jax.numpy.linalg.slogdet()
现在接受一个可选的method
参数,该参数允许在基于 LU 分解的实现和基于 QR 分解的实现之间进行选择。jax.numpy.linalg.qr()
现在支持mode="raw"
。pickle
、copy.copy
和copy.deepcopy
现在在使用 jax 数组时具有更完整的支持 (#10659)。特别是pickle
和deepcopy
以前在使用DeviceArray
时返回np.ndarray
对象;现在返回DeviceArray
对象。对于deepcopy
,复制的数组与原始数组位于同一设备上。对于pickle
,反序列化的数组将位于默认设备上。在函数转换(即跟踪代码)中,
deepcopy
和copy
以前是空操作。现在它们使用与DeviceArray.copy()
相同的机制。现在,在跟踪数组上调用
pickle
会导致显式的ConcretizationTypeError
。
在 TPU 上,奇异值分解 (SVD) 和对称/厄米特本征分解的实现应该会快得多,尤其是对于 1000x1000 左右或以上的矩阵。两者现在都使用频谱分治算法进行本征分解 (QDWH-eig)。
jax.numpy.ldexp()
不再默默地将所有输入提升为 float64,而是将 int32 或更小尺寸的整数输入提升为 float32 (#10921)。向
jax.profiler.start_trace()
和jax.profiler.start_trace()
添加create_perfetto_link
选项。使用后,分析器将生成指向 Perfetto UI 的链接以查看跟踪。更改了
jax.profiler.start_server(...)()
的语义,以全局存储 keepalive,而不是要求用户保留对其的引用。添加了
jax.random.ball()
。添加了
jax.default_device()
。添加了一个
python -m jax.collect_profile
脚本,用于手动捕获程序跟踪,作为 TensorBoard UI 的替代方案。添加了一个
jax.named_scope
上下文管理器,它将分析器元数据添加到 Python 程序(类似于jax.named_call
)。在散布更新操作(即 :attr:
jax.numpy.ndarray.at
)中,已弃用不安全的隐式 dtype 强制转换,现在会导致FutureWarning
。在未来的版本中,这将变成一个错误。不安全的隐式强制转换的一个示例是jnp.zeros(4, dtype=int).at[0].set(1.5)
,其中1.5
以前会被默默地截断为1
。jax.experimental.compilation_cache.initialize_cache()
现在支持 gcs 存储桶路径作为输入。当系数有前导零时,
jax.numpy.roots()
在strip_zeros=False
时现在的行为更好 (#11215)。
jaxlib 0.3.14(2022 年 6 月 27 日)#
-
x86-64 Mac 轮子现在需要 Mac OS 10.14 (Mojave) 或更高版本。Mac OS 10.14 于 2018 年发布,因此这应该不是一个非常繁重的要求。
捆绑的 NCCL 版本已更新至 2.12.12,修复了一些死锁。
Python flatbuffers 包不再是 jaxlib 的依赖项。
jax 0.3.13(2022 年 5 月 16 日)#
jax 0.3.12(2022 年 5 月 15 日)#
jax 0.3.11(2022 年 5 月 15 日)#
更改
jax.lax.eigh()
现在接受一个可选的sort_eigenvalues
参数,该参数允许用户选择不使用 TPU 上的特征值排序。
弃用
jax.lax.linalg
中函数的非数组参数现在标记为仅关键字参数。作为向后兼容步骤,按位置传递仅关键字参数会产生警告,但在未来的 JAX 版本中,按位置传递仅关键字参数将失败。但是,大多数用户应该更喜欢使用jax.numpy.linalg
。已弃用作为 scipy API 的 JAX 扩展的
jax.scipy.linalg.polar_unitary()
。请改用jax.scipy.linalg.polar()
。
jax 0.3.10 (2022 年 5 月 3 日)#
jaxlib 0.3.10 (2022 年 5 月 3 日)#
jax 0.3.9 (2022 年 5 月 2 日)#
更改
增加了对 GlobalDeviceArray 的完全异步检查点支持。
jax 0.3.8 (2022 年 4 月 29 日)#
更改
TPU 上的
jax.numpy.linalg.svd()
使用 qdwh-svd 求解器。TPU 上的
jax.numpy.linalg.cond()
现在接受复数输入。TPU 上的
jax.numpy.linalg.pinv()
现在接受复数输入。TPU 上的
jax.numpy.linalg.matrix_rank()
现在接受复数输入。jax.experimental.maps.mesh
已被删除。请使用jax.experimental.maps.Mesh
。有关更多信息,请参阅 https://jax.net.cn/en/latest/_autosummary/jax.experimental.maps.Mesh.html#jax.experimental.maps.Mesh。为了匹配
scipy.linalg.qr
的行为(#10452),当mode='r'
时,jax.scipy.linalg.qr()
现在返回长度为 1 的元组而不是原始数组。jax.numpy.take_along_axis()
现在接受一个可选的mode
参数,该参数指定越界索引的行为。默认情况下,对于越界索引将返回无效值(例如,NaN)。在 JAX 的早期版本中,无效索引被限制在范围内。可以通过传递mode="clip"
来恢复之前的行为。jax.numpy.take()
现在默认为mode="fill"
,它为越界索引返回无效值(例如,NaN)。Scatter 操作,例如
x.at[...].set(...)
,现在具有"drop"
语义。这对 scatter 操作本身没有影响,但这意味着在微分时,scatter 的梯度将为越界索引产生零余切。以前,越界索引在梯度中被限制在范围内,这在数学上是不正确的。如果
jax.numpy.take_along_axis()
的索引不是整数类型,现在会引发TypeError
,这与numpy.take_along_axis()
的行为相匹配。以前,非整数索引会被静默转换为整数。如果
jax.numpy.ravel_multi_index()
的dims
参数不是整数类型,现在会引发TypeError
,这与numpy.ravel_multi_index()
的行为相匹配。以前,非整数dims
会被静默转换为整数。如果
jax.numpy.split()
的axis
参数不是整数类型,现在会引发TypeError
,这与numpy.split()
的行为相匹配。以前,非整数axis
会被静默转换为整数。如果
jax.numpy.indices()
的维度不是整数类型,现在会引发TypeError
,这与numpy.indices()
的行为相匹配。以前,非整数维度会被静默转换为整数。如果
jax.numpy.diag()
的k
参数不是整数类型,现在会引发TypeError
,这与numpy.diag()
的行为相匹配。以前,非整数k
会被静默转换为整数。
弃用
jax.test_util
中提供的许多函数和对象现在已被弃用,并且在导入时会引发警告。这包括cases_from_list
、check_close
、check_eq
、device_under_test
、format_shape_dtype_string
、rand_uniform
、skip_on_devices
、with_config
、xla_bridge
和_default_tolerance
(#10389)。这些以及先前已弃用的JaxTestCase
、JaxTestLoader
和BufferDonationTestCase
将在未来的 JAX 版本中删除。这些实用程序中的大多数可以用对标准 python 和 numpy 测试实用程序的调用来替换,例如在unittest
、absl.testing
、numpy.testing
等中找到。JAX 特定的功能(例如设备检查)可以通过使用公共 API(例如jax.devices()
)来替换。许多已弃用的实用程序仍然存在于jax._src.test_util
中,但这些不是公共 API,因此可能会在未来的版本中更改或删除,恕不另行通知。
jax 0.3.7 (2022 年 4 月 15 日)#
更改
修复了如果传递给
jax.numpy.take_along_axis()
的索引被广播时的性能问题(#10281)。jax.scipy.special.expit()
和jax.scipy.special.logit()
现在要求它们的参数是标量或 JAX 数组。它们现在还将整数参数提升为浮点数。DeviceArray.tile()
方法已弃用,因为 numpy 数组没有tile()
方法。作为替代方法,请使用jax.numpy.tile()
(#10266)。
jaxlib 0.3.7 (2022 年 4 月 15 日)#
更改
Linux wheels 现在按照
manylinux2014
标准构建,而不是manylinux2010
。
jax 0.3.6 (2022 年 4 月 12 日)#
jax 0.3.5 (2022 年 4 月 7 日)#
更改
添加了
jax.random.loggamma()
并改进了小参数值下jax.random.beta()
和jax.random.dirichlet()
的行为 (#9906)。私有
lax_numpy
子模块不再在jax.numpy
命名空间中公开 (#10029)。添加了数组创建例程
jax.numpy.frombuffer()
、jax.numpy.fromfunction()
和jax.numpy.fromstring()
(#10049)。DeviceArray.copy()
现在返回DeviceArray
而不是np.ndarray
(#10069)jax.experimental.sharded_jit
已被弃用,并将很快被删除。
弃用
jax.nn.normalize()
正在被弃用。请改用jax.nn.standardize()
(#9899)。jax.tree_util.tree_multimap()
已被弃用。请改用jax.tree_util.tree_map()
(#5746)。jax.experimental.sharded_jit
已被弃用。请改用pjit
。
jaxlib 0.3.5 (2022 年 4 月 7 日)#
jax 0.3.4 (2022 年 3 月 18 日)#
jax 0.3.3 (2022 年 3 月 17 日)#
jax 0.3.2 (2022 年 3 月 16 日)#
更改
已删除在 0.2.22 中弃用的函数
jax.ops.index_update
、jax.ops.index_add
。请改用 JAX 数组上的.at
属性,例如,x.at[idx].set(y)
。将
jax.experimental.ann.approx_*_k
移动到jax.lax
中。这些函数是jax.lax.top_k
的优化替代方案。jax.numpy.broadcast_arrays()
和jax.numpy.broadcast_to()
现在需要标量或类数组输入,如果传递列表,则会失败(#7737 的一部分)。标准 jax[tpu] 安装现在可以与 Cloud TPU v4 VM 一起使用。
pjit
现在可以在 CPU 上运行(除了之前的 TPU 和 GPU 支持)。
jaxlib 0.3.2 (2022 年 3 月 16 日)#
更改
现在,
XlaComputation.as_hlo_text()
支持通过传递布尔标志print_large_constants=True
来打印大型常量。
弃用
JAX 数组上的
.block_host_until_ready()
方法已被弃用。请改用.block_until_ready()
。
jax 0.3.1 (2022 年 2 月 18 日)#
更改
jax.test_util.JaxTestCase
和jax.test_util.JaxTestLoader
现在已被弃用。建议的替代方法是直接使用parametrized.TestCase
。对于依赖于自定义断言(例如JaxTestCase.assertAllClose()
)的测试,建议的替代方法是使用标准 numpy 测试实用程序(例如numpy.testing.assert_allclose()
),该实用程序可直接与 JAX 数组一起使用 (#9620)。jax.test_util.JaxTestCase
现在默认设置jax_numpy_rank_promotion='raise'
(#9562)。要恢复以前的行为,请使用新的jax.test_util.with_config
装饰器@jtu.with_config(jax_numpy_rank_promotion='allow') class MyTestCase(jtu.JaxTestCase): ...
添加了
jax.scipy.linalg.schur()
、jax.scipy.linalg.sqrtm()
、jax.scipy.signal.csd()
、jax.scipy.signal.stft()
、jax.scipy.signal.welch()
。
jax 0.3.0 (2022 年 2 月 10 日)#
jaxlib 0.3.0 (2022 年 2 月 10 日)#
更改
现在需要 Bazel 5.0.0 才能构建 jaxlib。
jaxlib 版本已提升至 0.3.0。有关说明,请参阅设计文档。
jax 0.2.28 (2022 年 2 月 1 日)#
-
如果未传递
dialect=
,则jax.jit(f).lower(...).compiler_ir()
现在默认为 MHLO 方言。jax.jit(f).lower(...).compiler_ir(dialect='mhlo')
现在返回 MLIRir.Module
对象,而不是其字符串表示形式。
jaxlib 0.1.76 (2022 年 1 月 27 日)#
新功能
包括用于 NVidia 计算能力 8.0 GPUS(例如 A100)的预编译 SASS。删除了计算能力 6.1 的预编译 SASS,以便不增加计算能力的数量:计算能力 6.1 的 GPU 可以使用 6.0 SASS。
使用 jaxlib 0.1.76,JAX 默认使用 MHLO MLIR 方言作为其主要目标编译器 IR。
重大更改
根据弃用策略,已删除对 NumPy 1.18 的支持。请升级到受支持的 NumPy 版本。
Bug 修复
修复了通过不同路由构建的表面上相同的 pytreedef 对象无法比较为相等的问题 (#9066)。
JAX jit 缓存要求两个静态参数具有相同的类型才能命中缓存 (#9311)。
jax 0.2.27 (2022 年 1 月 18 日)#
重大更改
根据弃用策略,已删除对 NumPy 1.18 的支持。请升级到受支持的 NumPy 版本。
host_callback 原语已得到简化,以删除 hcb.id_tap 和 id_print 的特殊自动微分处理。从现在开始,仅点击原始值。可以通过设置
JAX_HOST_CALLBACK_AD_TRANSFORMS
环境变量或--jax_host_callback_ad_transforms
标志来获取旧行为(在有限时间内)。此外,还添加了有关如何使用 JAX 自定义 AD API 实现旧行为的文档 (#8678)。现在,无论位表示形式如何,排序都与 NumPy 中
0.0
和NaN
的行为相匹配。特别是,0.0
和-0.0
现在被视为等效,而以前-0.0
被视为小于0.0
。此外,所有NaN
表示形式现在都被视为等效,并排序到数组的末尾。以前,负NaN
值被排序到数组的前面,并且具有不同内部位表示形式的NaN
值未被视为等效,并根据这些位模式进行排序 (#9178)。jax.numpy.unique()
现在以与 NumPy 版本 1.21 及更高版本中的np.unique
相同的方式处理NaN
值:最多一个NaN
值将出现在唯一化输出中 (#9184)。
Bug 修复
host_callback 现在支持 ad_checkpoint.checkpoint (#8907)。
新功能
添加
jax.block_until_ready
({jax-issue}`#8941)添加了一个新的调试标志/环境变量
JAX_DUMP_IR_TO=/path
。如果设置,JAX 会将其为每个计算生成的 MHLO/HLO IR 转储到给定路径下的文件中。将
jax.ensure_compile_time_eval
添加到公共 api 中 (#7987)。jax2tf 现在支持一个标志 jax2tf_associative_scan_reductions,用于更改关联归约(例如 jnp.cumsum)的降低方式,使其行为类似于 CPU 和 GPU 上的 JAX(使用关联扫描)。有关更多详细信息,请参阅 jax2tf README (#9189)。
jaxlib 0.1.75 (2021 年 12 月 8 日)#
新功能
支持 python 3.10。
jax 0.2.26 (2021 年 12 月 8 日)#
jaxlib 0.1.74 (2021 年 11 月 17 日)#
启用了 GPU 之间的对等复制。以前,GPU 复制通过主机进行中转,这通常较慢。
添加了实验性 MLIR Python 绑定,供 JAX 使用。
jax 0.2.25 (2021 年 11 月 10 日)#
jax 0.2.24 (2021 年 10 月 19 日)#
jaxlib 0.1.73 (2021 年 10 月 18 日)#
现在为 jaxlib GPU
cuda11
wheels 支持多个 cuDNN 版本。cuDNN 8.2 或更高版本。如果您的 cuDNN 安装足够新,我们建议使用 cuDNN 8.2 wheel,因为它支持其他功能。
cuDNN 8.0.5 或更高版本。
重大更改
GPU jaxlib 的安装命令如下
pip install --upgrade pip # Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer. pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html # Installs the wheel compatible with Cuda 11 and cudnn 8.2 or newer. pip install jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_releases.html # Installs the wheel compatible with Cuda 11 and cudnn 8.0.5 or newer. pip install jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_releases.html
jax 0.2.22 (2021 年 10 月 12 日)#
重大更改
现在,
jax.pmap
的静态参数必须是可哈希的。不可哈希的静态参数长期以来在
jax.jit
上是不允许的,但它们仍然在jax.pmap
上是允许的;jax.pmap
使用对象标识比较不可哈希的静态参数。此行为是一个脚枪,因为使用对象标识比较参数会导致每次对象标识更改时都重新编译。相反,我们现在禁止不可哈希的参数:如果
jax.pmap
的用户想要通过对象标识比较静态参数,他们可以在他们的对象上定义执行此操作的__hash__
和__eq__
方法,或者将他们的对象包装在一个具有这些具有对象标识语义的操作的对象中。另一种选择是使用functools.partial
将不可哈希的静态参数封装到函数对象中。jax.util.partial
是一个意外导出,现在已被删除。请改用 Python 标准库中的functools.partial
。
弃用
函数
jax.ops.index_update
,jax.ops.index_add
等已弃用,将在未来的 JAX 版本中移除。请使用.at
JAX 数组上的属性,例如,x.at[idx].set(y)
。目前,这些函数会产生一个DeprecationWarning
。
新功能
当使用 jaxlib 0.1.72 或更高版本时,经过优化的 C++ 代码路径现在是
pmap
调度时间的默认值。可以使用--experimental_cpp_pmap
标志(或JAX_CPP_PMAP
环境变量)禁用此功能。jax.numpy.unique
现在支持可选的fill_value
参数 (#8121)
jaxlib 0.1.72 (2021 年 10 月 12 日)#
重大更改
已删除对 CUDA 10.2 和 CUDA 10.1 的支持。Jaxlib 现在支持 CUDA 11.1+。
Bug 修复
修复了 https://github.com/jax-ml/jax/issues/7461,由于 XLA 编译器内部不正确的缓冲区别名,导致所有平台上的输出错误。
jax 0.2.21 (2021 年 9 月 23 日)#
重大更改
jax.api
已被移除。作为jax.api.*
提供的函数是jax.*
中函数的别名;请使用jax.*
中的函数代替。jax.partial
和jax.lax.partial
是意外导出的,现在已被移除。请使用 Python 标准库中的functools.partial
代替。布尔标量索引现在会引发
TypeError
;以前这会在静默情况下返回错误的结果 (#7925)。现在,更多
jax.numpy
函数需要类似数组的输入,如果传递列表则会出错 (#7747 #7802 #7907)。有关此更改的基本原理的讨论,请参阅 #7737。当在诸如
jax.jit
的转换中时,jax.numpy.array
始终将其生成的数组暂存到跟踪的计算中。以前,即使在jax.jit
装饰器下,jax.numpy.array
有时也会生成一个设备上数组。此更改可能会破坏使用 JAX 数组执行形状或索引计算的代码,这些计算必须静态已知;解决方法是使用经典的 NumPy 数组执行此类计算。jnp.ndarray
现在是 JAX 数组的真正基类。特别是,这意味着对于标准 numpy 数组x
,isinstance(x, jnp.ndarray)
现在将返回False
(#7927)。
新功能
添加了
jax.numpy.insert()
实现 (#7936)。
jax 0.2.20 (2021 年 9 月 2 日)#
jaxlib 0.1.71 (2021 年 9 月 1 日)#
重大更改
已删除对 CUDA 11.0 和 CUDA 10.1 的支持。Jaxlib 现在支持 CUDA 10.2 和 CUDA 11.1+。
jax 0.2.19 (2021 年 8 月 12 日)#
重大更改
根据 弃用策略,已删除对 NumPy 1.17 的支持。请升级到支持的 NumPy 版本。
jit
装饰器已添加到 JAX 数组上许多运算符的实现中。这加快了常见运算符(例如+
)的调度时间。对于大多数用户来说,此更改应该在很大程度上是透明的。但是,有一个已知的行为更改,即当大型整数常量直接传递给 JAX 运算符时(例如,
x + 2**40
),现在可能会产生错误。解决方法是将常量转换为显式类型(例如,np.float64(2**40)
)。
新功能
改进了 jax2tf 中对需要在数组计算中使用维度大小的操作的形状多态性的支持,例如,
jnp.mean
。 (#7317)
Bug 修复
修复了之前版本中的一些泄漏跟踪错误 (#7613)
jaxlib 0.1.70 (2021 年 8 月 9 日)#
jax 0.2.18 (2021 年 7 月 21 日)#
重大更改
根据 弃用策略,已删除对 Python 3.6 的支持。请升级到支持的 Python 版本。
最低 jaxlib 版本现在为 0.1.69。
jax.dlpack.from_dlpack()
的backend
参数已被移除。
新功能
添加了极分解 (
jax.scipy.linalg.polar()
)。
Bug 修复
收紧了对 lax.argmin 和 lax.argmax 的检查,以确保它们不会与无效的
axis
值或空的缩减维度一起使用。 (#7196)
jaxlib 0.1.69 (2021 年 7 月 9 日)#
修复了 TFRT CPU 后端中的错误,该错误会导致不正确的结果。
jax 0.2.17 (2021 年 7 月 9 日)#
Bug 修复
默认为较旧的“stream_executor”CPU 运行时,适用于 jaxlib <= 0.1.68,以解决 #7229,由于并发问题,该问题导致 CPU 上的输出错误。
新功能
新的 SciPy 函数
jax.scipy.special.sph_harm()
。反向模式自动微分函数 (
jax.grad()
、jax.value_and_grad()
、jax.vjp()
和jax.linear_transpose()
) 支持一个参数,该参数指示如果在前向传递中广播了哪些命名轴,则应在后向传递中对这些轴求和。这使得可以在映射中以非逐例的方式使用这些 API(最初只有jax.experimental.maps.xmap()
) (#6950)。
jax 0.2.16 (2021 年 6 月 23 日)#
jax 0.2.15 (2021 年 6 月 23 日)#
jaxlib 0.1.68 (2021 年 6 月 23 日)#
Bug 修复
修复了 TFRT CPU 后端中的错误,该错误会在将 TPU 缓冲区传输到 CPU 时得到 nans。
jax 0.2.14 (2021 年 6 月 10 日)#
新功能
jax2tf.convert()
现在支持pjit
和sharded_jit
。一个新的配置选项 JAX_TRACEBACK_FILTERING 控制 JAX 如何过滤回溯。
现在默认情况下,在足够新的 IPython 版本中,启用了一种使用
__tracebackhide__
的新回溯过滤模式。即使在算术运算中使用未知维度,
jax2tf.convert()
也支持形状多态性,例如,jnp.reshape(-1)
(#6827)。jax2tf.convert()
在 TF 操作中生成带有位置信息的自定义属性。jax2tf 生成的 XLA 代码与 JAX/XLA 具有相同的位置信息。新的 SciPy 函数
jax.scipy.special.lpmn()
。
Bug 修复
jaxlib 0.1.67 (2021 年 5 月 17 日)#
jaxlib 0.1.66 (2021 年 5 月 11 日)#
新功能
现在在所有 CUDA 11 版本 11.1 或更高版本上都支持 CUDA 11.1 wheels。
NVidia 现在承诺从 CUDA 11.1 开始的 CUDA 次要版本之间的兼容性。这意味着 JAX 可以发布一个与 CUDA 11.2 和 11.3 兼容的 CUDA 11.1 wheel。
不再有单独的 CUDA 11.2(或更高版本)的 jaxlib 版本;对于这些版本,请使用 CUDA 11.1 wheel (cuda111)。
Jaxlib 现在在 CUDA wheels 中捆绑了
libdevice.10.bc
。应该不需要将 JAX 指向 CUDA 安装来查找此文件。添加了对
jit()
实现的静态关键字参数的自动支持。添加了对预转换异常跟踪的支持。
对从
jit()
转换后的计算中修剪未使用的参数的初始支持。修剪仍在进行中。改进了
PyTreeDef
对象的字符串表示形式。添加了对 XLA 的可变 ReduceWindow 的支持。
Bug 修复
修复了将大量参数传递给计算时远程云 TPU 支持中的错误。
修复了一个错误,这意味着 JAX 垃圾回收没有被
jit()
转换后的函数触发。
jax 0.2.13 (2021 年 5 月 3 日)#
新功能
与 jaxlib 0.1.66 结合使用时,
jax.jit()
现在支持静态关键字参数。已添加新的static_argnames
选项以将关键字参数指定为静态。jax.nonzero()
有一个新的可选size
参数,允许它在jit
中使用 (#6501)jax.numpy.unique()
现在支持axis
参数 (#6532)。jax.experimental.host_callback.call()
现在支持pjit.pjit
(#6569)。添加了
jax.scipy.linalg.eigh_tridiagonal()
,它计算三对角矩阵的特征值。目前仅支持特征值。已更改异常中过滤和未过滤堆栈跟踪的顺序。附加到从 JAX 转换的代码抛出的异常的回溯现在被过滤,
UnfilteredStackTrace
异常包含原始跟踪作为过滤异常的__cause__
。过滤的堆栈跟踪现在也适用于 Python 3.6。如果抛出由反向模式自动微分转换的代码抛出的异常,JAX 现在尝试附加一个
JaxStackTraceBeforeTransformation
对象作为异常的__cause__
,该对象包含在前向传递中创建原始操作的堆栈跟踪。需要 jaxlib 0.1.66。
重大更改
以下函数名称已更改。仍然有别名,因此这不会破坏现有代码,但别名最终将被删除,因此请更改您的代码。
host_id
–>process_index()
host_count
–>process_count()
host_ids
–>range(jax.process_count())
同样,
local_devices()
的参数已从host_id
重命名为process_index
。jax.jit()
的函数以外的参数现在标记为仅关键字参数。此更改是为了防止在将参数添加到jit
时发生意外中断。
Bug 修复
jaxlib 0.1.65 (2021 年 4 月 7 日)#
jax 0.2.12 (2021 年 4 月 1 日)#
新功能
新的分析 API:
jax.profiler.start_trace()
、jax.profiler.stop_trace()
和jax.profiler.trace()
jax.lax.reduce()
现在是可微的。
重大更改
最低 jaxlib 版本现在为 0.1.64。
一些分析器 API 名称已更改。仍然有别名,因此这不会破坏现有代码,但别名最终将被删除,因此请更改您的代码。
TraceContext
–>TraceAnnotation()
StepTraceContext
–>StepTraceAnnotation()
trace_function
–>annotate_function()
Omnistaging 现在无法禁用。有关更多信息,请参阅 omnistaging。
大于最大
int64
值的 Python 整数现在将在所有情况下导致溢出,而不是在某些情况下静默转换为uint64
(#6047)。在 X64 模式之外,现在,超出
int32
可表示范围的 Python 整数将导致OverflowError
,而不是使其值被静默截断。
Bug 修复
host_callback
现在支持参数和结果中的空数组 (#6262)。jax.random.randint()
剪裁而不是包装越界限制,并且现在可以在指定 dtype 的整个范围内生成整数 (#5868)
jax 0.2.11 (2021 年 3 月 23 日)#
新功能
Bug 修复
重大更改
最低 jaxlib 版本现在为 0.1.62。
jaxlib 0.1.64 (2021 年 3 月 18 日)#
jaxlib 0.1.63 (2021 年 3 月 17 日)#
jax 0.2.10 (2021 年 3 月 5 日)#
新功能
jax.scipy.stats.chi2()
现在可以作为具有 logpdf 和 pdf 方法的分布使用。jax.scipy.stats.betabinom()
现在可以作为具有 logpmf 和 pmf 方法的分布使用。添加了
jax.experimental.jax2tf.call_tf()
以从 JAX 调用 TensorFlow 函数 (#5627) 和 README)。扩展了
lax.pad
的批处理规则以支持填充值的批处理。
Bug 修复
jax.numpy.take()
正确处理负索引 (#5768)
重大更改
JAX 的类型提升规则已调整为使类型提升更一致且对 JIT 不变。特别是,在适当的情况下,二元运算现在可能导致弱类型值。此更改的主要用户可见效果是某些操作导致输出的精度与以前不同;例如,表达式
jnp.bfloat16(1) + 0.1 * jnp.arange(10)
之前返回一个float64
数组,现在返回一个bfloat16
数组。JAX 的类型提升行为在 类型提升语义 中描述。jax.numpy.linspace()
现在计算整数值的下限,即四舍五入到 -inf 而不是 0。此更改是为了匹配 NumPy 1.20.0。jax.numpy.i0()
不再接受复数。以前,该函数计算复数参数的绝对值。此更改是为了匹配 NumPy 1.20.0 的语义。一些
jax.numpy
函数不再接受元组或列表来代替数组参数:jax.numpy.pad()
、:funcjax.numpy.ravel
、jax.numpy.repeat()
、jax.numpy.reshape()
。一般来说,jax.numpy
函数应与标量或数组参数一起使用。
jaxlib 0.1.62 (2021 年 3 月 9 日)#
新功能
默认情况下,现在构建 jaxlib wheels 以要求在 x86-64 机器上使用 AVX 指令。如果你想在不支持 AVX 的机器上使用 JAX,你可以使用
--target_cpu_features
标志从源代码构建 jaxlib 到build.py
。--target_cpu_features
也取代了--enable_march_native
。
jaxlib 0.1.61 (2021 年 2 月 12 日)#
jaxlib 0.1.60 (2021 年 2 月 3 日)#
Bug 修复
修复了将 CPU DeviceArrays 转换为 NumPy 数组时的内存泄漏。内存泄漏存在于 jaxlib 版本 0.1.58 和 0.1.59 中。
现在,
bool
、int8
和uint8
被认为是安全地转换为bfloat16
NumPy 扩展类型的。
jax 0.2.9 (2021 年 1 月 26 日)#
新功能
扩展了
jax.experimental.loops
模块,以支持 pytrees。改进了错误检查和错误消息。添加了
jax.experimental.enable_x64()
和jax.experimental.disable_x64()
。 这些是上下文管理器,允许在会话中临时启用/禁用 X64 模式。
重大更改
jax.ops.segment_sum()
现在会丢弃超出范围的段 ID,而不是将它们包装到段 ID 空间中。 这是出于性能原因而完成的。
jaxlib 0.1.59 (2021 年 1 月 15 日)#
jax 0.2.8 (2021 年 1 月 12 日)#
新功能
添加了
jax.closure_convert()
以用于高阶自定义导数函数。 (#5244)添加了
jax.experimental.host_callback.call()
以在主机上调用自定义 Python 函数并将结果返回给设备计算。 (#5243)
Bug 修复
重大更改
jax.numpy.pad
现在采用关键字参数。 位置参数constant_values
已被删除。 此外,传递不支持的关键字参数会引发错误。jax.experimental.host_callback.id_tap()
的更改 (#5243)删除了对
jax.experimental.host_callback.id_tap()
的kwargs
的支持。 (此支持已被弃用几个月。)将
jax.experimental.host_callback.id_print()
的元组打印更改为使用“(”而不是“[“。更改了存在 JVP 时的
jax.experimental.host_callback.id_print()
,以打印原始和切线的对。 以前,原始和切线有两个单独的打印操作。host_callback.outfeed_receiver
已被删除(没有必要,并且几个月前已被弃用)。
新功能
用于调试
inf
的新标志,类似于NaN
的标志 (#5224)。
jax 0.2.7 (2020 年 12 月 4 日)#
新功能
添加
jax.device_put_replicated
为
jax.experimental.sharded_jit
添加多主机支持添加对区分由
jax.numpy.linalg.eig
计算的特征值的支持添加对在 Windows 平台上构建的支持
在
jax.pmap
中添加对通用 in_axes 和 out_axes 的支持为
jax.numpy.linalg.slogdet
添加复数支持
Bug 修复
修复了零处
jax.numpy.sinc
的高于二阶的导数修复了转置规则中一些难以命中的关于符号零的错误
重大更改
jax.experimental.optix
已被删除,以支持独立的optax
Python 包。使用非元组序列对 JAX 数组进行索引现在会引发
TypeError
。 自 v1.16 起,Numpy 中已弃用此类索引,自 v0.2.4 起,JAX 中已弃用此类索引。 请参阅 #4564。
jax 0.2.6 (2020 年 11 月 18 日)#
新特性
为 jax.experimental.jax2tf 转换器添加了对形状多态跟踪的支持。 请参阅 README.md。
重大更改清理
在 jax.jit 和 xla_computation 的不可哈希静态参数上引发错误。 请参阅 cb48f42。
提高类型提升行为的一致性 (#4744)
将复数 Python 标量添加到 JAX 浮点数尊重 JAX 浮点数的精度。 例如,
jnp.float32(1) + 1j
现在返回complex64
,而之前它返回complex128
。涉及 uint64、有符号 int 和第三种类型的 3 个或更多术语的类型提升结果现在独立于参数的顺序。 例如:
jnp.result_type(jnp.uint64, jnp.int64, jnp.float16)
和jnp.result_type(jnp.float16, jnp.uint64, jnp.int64)
都返回float16
,而之前第一个返回float64
,第二个返回float16
。
(未记录的)
jax.lax_linalg
线性代数模块的内容现在公开公开为jax.lax.linalg
。jax.random.PRNGKey
现在在 JIT 编译内外产生相同的结果 (#4877)。 这需要在一些特定情况下更改给定种子的结果使用
jax_enable_x64=False
,作为 Python 整数传递的负种子现在在 JIT 模式之外返回不同的结果。 例如,jax.random.PRNGKey(-1)
之前返回[4294967295, 4294967295]
,现在返回[0, 4294967295]
。 这与 JIT 中的行为相匹配。JIT 外部
int64
可表示范围之外的种子现在会导致OverflowError
而不是TypeError
。 这与 JIT 中的行为相匹配。
要恢复之前为 JIT 外部
jax_enable_x64=False
的负整数返回的密钥,您可以使用key = random.PRNGKey(-1).at[0].set(0xFFFFFFFF)
DeviceArray 现在在尝试访问其值时如果已被删除,则会引发
RuntimeError
而不是ValueError
。
jaxlib 0.1.58 (大约 2021 年 1 月 12 日)#
修复了一个错误,该错误意味着 JAX 有时会返回特定于平台的类型(例如,
np.cint
)而不是标准类型(例如,np.int32
)。 (#4903)修复了常量折叠某些 int16 操作时发生的崩溃。 (#4971)
向
pytree.flatten()
添加了一个is_leaf
谓词。
jaxlib 0.1.57 (2020 年 11 月 12 日)#
修复了 GPU Wheels 中的 manylinux2010 合规性问题。
将 CPU FFT 实现从 Eigen 切换到 PocketFFT。
修复了一个错误,该错误导致 bfloat16 值的哈希未正确初始化并且可能会更改 (#4651)。
添加了在将数组传递给 DLPack 时保留所有权的支持 (#4636)。
修复了批处理三角求解的大小大于 128 但不是 128 的倍数时出现的错误。
修复了在多个 GPU 上执行并发 FFT 时出现的错误 (#3518)。
修复了分析器中缺少工具的错误 (#4427)。
放弃了对 CUDA 10.0 的支持。
jax 0.2.5 (2020 年 10 月 27 日)#
改进
确保
check_jaxpr
不执行 FLOPS。 请参阅 #4650。扩展了 jax2tf 转换的 JAX 原语集。 请参阅 primitives_with_limited_support.md。
jax 0.2.4 (2020 年 10 月 19 日)#
jaxlib 0.1.56 (2020 年 10 月 14 日)#
jax 0.2.3 (2020 年 10 月 14 日)#
这么快进行另一次发布的原因是我们暂时回滚了一个新的 jit 快速通道,同时我们研究了性能下降问题
jax 0.2.2 (2020 年 10 月 13 日)#
jax 0.2.1 (2020 年 10 月 6 日)#
改进
作为 omnistaging 的一项优势,即使计算中未使用
jax.experimental.host_callback.id_print()
/jax.experimental.host_callback.id_tap()
的结果,也会执行 host_callback 函数(按程序顺序)。
jax (0.2.0) (2020 年 9 月 23 日)#
改进
默认情况下启用 Omnistaging。 请参阅 #3370 和 omnistaging
jax (0.1.77) (2020 年 9 月 15 日)#
重大更改
jax.experimental.host_callback.id_tap()
的新简化接口 (#4101)
jaxlib 0.1.55 (2020 年 9 月 8 日)#
更新 XLA
修复 DLPackManagedTensorToBuffer 中的错误 (#4196)
jax 0.1.76 (2020 年 9 月 8 日)#
jax 0.1.75 (2020 年 7 月 30 日)#
Bug 修复
使 jnp.abs() 适用于无符号输入 (#3914)
改进
在标志后面添加了“Omnistaging”行为,默认情况下禁用 (#3370)
jax 0.1.74 (2020 年 7 月 29 日)#
新特性
BFGS (#3101)
对半精度算法的 TPU 支持 (#3878)
Bug 修复
防止一些意外的 dtype 警告 (#3874)
修复了自定义导数中的多线程错误 (#3845, #3869)
改进
更快的 searchsorted 实现 (#3873)
jax.numpy 排序算法的更好测试覆盖率 (#3836)
jaxlib 0.1.52 (2020 年 7 月 22 日)#
更新 XLA。
jax 0.1.73 (2020 年 7 月 22 日)#
最低 jaxlib 版本现在为 0.1.51。
新特性
jax.image.resize. (#3703)
hfft 和 ihfft (#3664)
jax.numpy.intersect1d (#3726)
jax.numpy.lexsort (#3812)
lax.scan
和scan
原语支持在降低到 XLA 时用于循环展开的unroll
参数 (#3738)。
Bug 修复
修复了缩减重复轴错误 (#3618)
修复了大小为 0 的输入尺寸的 lax.pad 的形状规则。 (#3608)
使 psum 转置处理零余切 (#3653)
修复了在对大小为 0 的轴进行reduce-prod 的 JVP 时出现的形状错误。 (#3729)
支持通过 jax.lax.all_to_all 进行区分 (#3733)
解决了 jax.scipy.special.zeta 中的 nan 问题 (#3777)
改进
对 jax2tf 的许多改进
使用单通道可变参数缩减重新实现 argmin/argmax。 (#3611)
默认情况下启用 XLA SPMD 分区。 (#3151)
添加了对 0d 转置卷积的支持 (#3643)
使 LU 梯度适用于低秩矩阵 (#3610)
支持 jet 中的 multiple_results 和自定义 JVP (#3657)
推广了 reduce-window 填充以支持 (lo, hi) 对。 (#3728)
在 CPU 和 GPU 上实现复数卷积。 (#3735)
使 jnp.take 适用于空数组的空切片。 (#3751)
放宽了 dot_general 的维度排序规则。 (#3778)
为 GPU 启用缓冲区捐赠。 (#3800)
添加了对基本膨胀和窗口膨胀的支持以减少窗口操作... (#3803)
jaxlib 0.1.51 (2020 年 7 月 2 日)#
更新 XLA。
为 host_callback 添加了新的运行时支持。
jax 0.1.72 (2020 年 6 月 28 日)#
jax 0.1.71 (2020 年 6 月 25 日)#
jaxlib 0.1.50 (2020 年 6 月 25 日)#
添加了对 CUDA 11.0 的支持。
放弃了对 CUDA 9.2 的支持(我们只维护对最近四个 CUDA 版本的支持。)
更新 XLA。
jaxlib 0.1.49 (2020 年 6 月 19 日)#
Bug 修复
修复了可能导致编译缓慢的构建问题 (tensorflow/tensorflow)
jaxlib 0.1.48 (2020 年 6 月 12 日)#
新功能
添加了对快速回溯收集的支持。
添加了对设备上堆分析的初步支持。
为
bfloat16
类型实现了np.nextafter
。CPU 和 GPU 上 FFT 的 Complex128 支持。
Bug 修复
改进了 GPU 上 float64
tanh
的精度。GPU 上的 float64 散射速度更快。
CPU 上的复数矩阵乘法应该更快。
CPU 上的稳定排序现在实际上应该是稳定的。
CPU 后端中的并发错误修复。
jax 0.1.70 (2020 年 6 月 8 日)#
jax 0.1.69 (2020 年 6 月 3 日)#
jax 0.1.68 (2020 年 5 月 21 日)#
jax 0.1.67 (2020 年 5 月 12 日)#
新功能
支持使用
axis_index_groups
在 pmapped 轴的子集上进行缩减 #2382。从编译代码打印和调用主机端 Python 函数的实验性支持。 请参阅 id_print 和 id_tap (#3006)。
显著变化
已收紧从
jax.numpy
导出的名称的可见性。 这可能会破坏利用先前意外导出的名称的代码。
jaxlib 0.1.47 (2020 年 5 月 8 日)#
修复了 outfeed 的崩溃。
jax 0.1.66 (2020 年 5 月 5 日)#
jaxlib 0.1.46 (2020 年 5 月 5 日)#
修复了 Mac OS X 上线性代数函数的崩溃问题 (#432)。
修复了由操作系统或虚拟机管理程序禁用 AVX512 指令时使用 AVX512 指令导致的非法指令崩溃 (#2906)。
jax 0.1.65 (2020 年 4 月 30 日)#
jaxlib 0.1.45 (2020 年 4 月 21 日)#
修复了段错误:#2755
将 Sort HLO 上的 is_stable 选项通过管道传输到 Python。
jax 0.1.64 (2020 年 4 月 21 日)#
新功能
为函数式索引更新添加了语法糖 #2684。
添加了
jax.numpy.unique()
#2760。添加了
jax.numpy.rint()
#2724。添加了
jax.numpy.rint()
#2724。为
jax.experimental.jet()
添加了更多原始规则。
Bug 修复
更好的错误
改进了
lax.while_loop()
的反向模式微分的错误消息 #2129。
jaxlib 0.1.44 (2020 年 4 月 16 日)#
修复了一个错误,如果存在不同型号的多个 GPU,JAX 将只会编译适合第一个 GPU 的程序。
batch_group_count
卷积的 Bugfix。为更多 GPU 版本添加了预编译的 SASS,以避免启动 PTX 编译挂起。
jax 0.1.63 (2020 年 4 月 12 日)#
从 #2026 添加了
jax.custom_jvp
和jax.custom_vjp
,请参阅 教程笔记本。 弃用了jax.custom_transforms
并将其从文档中删除(尽管它仍然有效)。添加了
scipy.sparse.linalg.cg
#2566。更改了 Tracers 的打印方式,以显示更多用于调试的有用信息 #2591。
使
jax.numpy.isclose
正确处理nan
和inf
#2501。为
jax.experimental.jet
添加了几个新规则 #2537。修复了未提供
scale
/center
时的jax.experimental.stax.BatchNorm
。修复了
jax.numpy.einsum
中一些丢失的广播情况 #2512。根据并行前缀扫描 #2596 实现了
jax.numpy.cumsum
和jax.numpy.cumprod
,并使reduce_prod
可微至任意阶 #2597。将
batch_group_count
添加到conv_general_dilated
#2635。为
test_util.check_grads
添加了文档字符串 #2656。添加了
callback_transform
#2665。实现了
rollaxis
、convolve
/correlate
1d 和 2d、copysign
、trunc
、roots
和quantile
/percentile
插值选项。
jaxlib 0.1.43 (2020 年 3 月 31 日)#
修复了 GPU 上 Resnet-50 的性能回归。
jax 0.1.62 (2020 年 3 月 21 日)#
JAX 已放弃对 Python 3.5 的支持。 请升级到 Python 3.6 或更高版本。
删除了内部函数
lax._safe_mul
,该函数实现了约定0. * nan == 0.
。 此更改意味着某些程序在微分时会产生 nan,而以前它们会产生正确的值,尽管它确保为其他程序生成 nan 而不是默默地不正确的结果。 有关详细信息,请参阅 #2447 和 #1052。添加了一个
all_gather
并行便捷函数。核心代码中更多类型注释。
jaxlib 0.1.42 (2020 年 3 月 19 日)#
jaxlib 0.1.41 由于 API 不兼容而破坏了云 TPU 支持。 此版本再次修复了它。
JAX 已放弃对 Python 3.5 的支持。 请升级到 Python 3.6 或更高版本。
jax 0.1.61 (2020 年 3 月 17 日)#
修复了 Python 3.5 支持。 这将是最后一个支持 Python 3.5 的 JAX 或 jaxlib 版本。
jax 0.1.60 (2020 年 3 月 17 日)#
新功能
jax.pmap()
具有static_broadcast_argnums
参数,该参数允许用户指定应被视为编译时常量并应广播到所有设备的参数。它的工作方式类似于jax.jit()
中的static_argnums
。改进了当追踪器错误地保存在全局状态时的错误消息。
添加了
jax.nn.one_hot()
实用函数。添加了
jax.experimental.jet
以实现指数级更快的高阶自动微分。为
jax.lax.broadcast_in_dim()
的参数添加了更多正确性检查。
最低 jaxlib 版本现在为 0.1.41。
jaxlib 0.1.40 (2020 年 3 月 4 日)#
在 Jaxlib 中添加了对 TensorFlow Profiler 的实验性支持,该支持允许从 TensorBoard 追踪 CPU 和 GPU 计算。
包括通过 NCCL 进行通信的多主机 GPU 计算的原型支持。
提高了 GPU 上 NCCL 集体操作的性能。
添加了 TopK、CustomCallWithoutLayout、CustomCallWithLayout、IGammaGradA 和 RandomGamma 实现。
支持在 XLA 编译时已知的设备分配。
jax 0.1.59 (2020 年 2 月 11 日)#
重大更改
最低 jaxlib 版本现在为 0.1.38。
通过删除
Jaxpr.freevars
和Jaxpr.bound_subjaxprs
简化了Jaxpr
。调用原语(xla_call
、xla_pmap
、sharded_call
和remat_call
)获得了一个新的参数call_jaxpr
,其中包含一个完全封闭的(没有constvars
)jaxpr。此外,向原语添加了一个新字段call_primitive
。
新功能
lax.cond
的反向模式自动微分(例如,grad
),使其现在在两种模式下都可微分(#2091)JAX 现在支持 DLPack,它允许以零拷贝的方式与其他库(例如 PyTorch)共享 CPU 和 GPU 数组。
JAX GPU DeviceArrays 现在支持
__cuda_array_interface__
,这是另一种零拷贝协议,用于与其他库(例如 CuPy 和 Numba)共享 GPU 数组。JAX CPU 设备缓冲区现在实现了 Python 缓冲区协议,该协议允许 JAX 和 NumPy 之间进行零拷贝缓冲区共享。
添加了 JAX_SKIP_SLOW_TESTS 环境变量以跳过已知缓慢的测试。
jaxlib 0.1.39 (2020 年 2 月 11 日)#
更新 XLA。
jaxlib 0.1.38 (2020 年 1 月 29 日)#
不再支持 CUDA 9.0。
现在默认构建 CUDA 10.2 wheels。
jax 0.1.58 (2020 年 1 月 28 日)#
重大更改
JAX 已放弃对 Python 2 的支持,因为 Python 2 已于 2020 年 1 月 1 日达到其生命周期结束。请更新到 Python 3.5 或更高版本。
新功能
while 循环的前向模式自动微分 (
jvp
)(#1980)
新的 NumPy 和 SciPy 函数
GPU 上的批处理 Cholesky 分解现在使用更高效的批处理内核。
值得注意的错误修复#
随着 Python 3 的升级,JAX 不再依赖
fastcache
,这应该有助于安装。