更改日志#
建议在此处查看:此处。有关 Pallas API 的特定更改,请参阅Pallas Changelog。
JAX 遵循基于工作量的版本控制;有关此以及 JAX API 兼容性策略的讨论,请参阅API 兼容性。有关 Python 和 NumPy 版本支持策略,请参阅Python 和 NumPy 版本支持策略。
未发布#
更改
jax.lax.linalg.eigh()现在接受一个implementation参数来选择 QR (CPU/GPU)、Jacobi (GPU/TPU) 和 QDWH (TPU) 实现。EighImplementation枚举从jax.lax.linalg公开导出。
JAX 0.8.0 (2025年10月15日)#
重大更改
JAX 将默认的
jax.pmap实现更改为基于jax.jit和jax.shard_map实现的。jax.pmap处于维护模式,我们鼓励所有新代码直接使用jax.shard_map。有关更多信息,请参阅迁移指南。jax.experimental.shard_map.shard_map的auto=参数已被移除。这意味着jax.experimental.shard_map.shard_map不再支持嵌套。如果要嵌套 shard_map 调用,请使用jax.shard_map。JAX 不再允许将支持
__jax_array__的对象直接传递给(例如)jit编译的函数。请先在它们上调用jax.numpy.asarray。jax.numpy.cov()现在返回 NaN 用于空数组(#32305),并且在单行设计矩阵方面与 NumPy 2.2 的行为一致(#32308)。JAX 不再接受
Array值,而是在需要dtype值的地方。请先调用这些值的.dtype。已移除已弃用的函数
jax.interpreters.mlir.custom_call()。已移除
jax.util、jax.extend.ffi和jax.experimental.host_callback模块。这些模块中的所有公共 API 在 v0.7.0 或更早版本中已被弃用并移除。已移除已弃用的符号
jax.custom_derivatives.custom_jvp_call_jaxpr_p。jax.experimental.multihost_utils.process_allgather当输入是 jax.Array 且非完全可寻址且tiled=False时会引发错误。要修复此问题,请将tiled=True传递给您的process_allgather调用。从
jax.experimental.compilation_cache中,已移除已弃用的符号is_initialized和initialize_cache。已移除已弃用的函数
jax.interpreters.xla.canonicalize_dtype()。jaxlib.hlo_helpers已被移除。请使用jax.ffi代替。选项
jax_cpu_enable_gloo_collectives已被移除。请使用jax_cpu_collectives_implementation代替。已移除
jax.numpy.percentile()和jax.numpy.quantile()中先前已弃用的interpolation参数;请使用method代替。已移除 JAX 内部的
for_loop原语。现在jax.lax.fori_loop()直接支持其功能,即在循环体内读写 ref。如果您需要帮助更新代码,请提交 bug。jax.numpy.trimzeros()现在会因非一维输入而报错。对
jax.numpy.sum()和其他归约操作的where参数现在要求为布尔值。非布尔值自 JAX v0.5.0 起已导致DeprecationWarning。已移除 {mod}
jax.dlpack, {mod}jax.errors, {mod}jax.lib.xla_bridge, {mod}jax.lib.xla_client, 和 {mod}jax.lib.xla_extension中的已弃用函数。jax.interpreters.mlir.dense_bool_array已移除。请使用 MLIR API 来构造属性。
更改
jax.numpy.linalg.eig()现在返回一个命名元组(带有eigenvalues和eigenvectors属性),而不是一个普通元组。jax.grad()和jax.vjp()现在始终将主元舍入到float32,除非启用了float64模式。jax.dlpack.from_dlpack()现在接受具有非默认布局的数组,例如转置的数组。NVIDIA GPU 上的默认非对称特征值分解现在使用 cusolver。通过
jax.lax.linalg.eig()的新implementation参数,仍然可以使用 magma 和 LAPACK 实现(#27265)。use_magma参数现在已弃用,改用implementation。jax.numpy.trim_zeros()现在遵循 NumPy 2.2,支持多维输入。
弃用
jax.experimental.enable_x64()和jax.experimental.disable_x64()已被弃用,改用新的非实验性上下文管理器jax.enable_x64()。jax.experimental.shard_map.shard_map()已被弃用;将来请使用jax.shard_map()。jax.experimental.pjit.pjit()已被弃用;将来请使用jax.jit()。
JAX 0.7.2 (2025年9月16日)#
重大更改
jax.dlpack.from_dlpack()不再接受 DLPack 胶囊。此行为已被弃用并已移除。该函数必须使用实现__dlpack__和__dlpack_device__的数组进行调用。
更改
现在支持的最低 NumPy 版本为 2.0。由于 NumPy 2.0 支持需要 SciPy 1.13,因此现在支持的最低 SciPy 版本为 1.13。
JAX 现在在其内部 jaxpr 表示中将常量表示为
TypedNdArray,这是一个私有的 JAX 类型,可模拟numpy.ndarray。此类型可能会通过custom_jvp规则等暴露给用户,并可能破坏使用isinstance(x, np.ndarray)的代码。如果这破坏了您的代码,您可以使用np.asarray(x)将这些数组转换为经典 NumPy 数组。
错误修复
arr.view(dtype=None)现在返回不变的数组,与 NumPy 的语义匹配。之前它返回具有 float 类型的数组。jax.random.randint现在为 8 位和 16 位整数类型生成更均匀分布的分布(#27742)。要恢复先前的有偏行为,您可以暂时将jax_safer_randint配置设置为False,但请注意,这是一个临时配置,将在未来版本中移除。
弃用
jax2tf.convert的enable_xla和native_serialization参数已被弃用,将在 JAX 的未来版本中移除。这些参数用于 jax2tf 与非原生序列化,而这部分功能已被移除。设置配置状态
jax_pmap_no_rank_reduction为False已被弃用。默认情况下,jax_pmap_no_rank_reduction将设置为True,并且jax.pmap分片不会降低其秩,保持与其包含数组相同的秩。
JAX 0.7.1 (2025年8月20日)#
新功能
JAX 现在提供 Python 3.14 和 3.14t 轮子。
JAX 现在在 Mac 上提供 Python 3.13t 和 3.14t 轮子。之前我们只在 Linux 上提供免费线程构建。
更改
公开了
jax.set_mesh,它充当全局设置器和上下文管理器。移除了jax.sharding.use_mesh,改用jax.set_mesh。JAX 现在使用 CUDA 12.9 构建。所有 CUDA 12.1 或更高版本仍然受支持。
jax.lax.dot()现在通过可选的dimension_numbers参数实现通用点积。
弃用
尝试导入
jax.experimental.host_callback现在会导致DeprecationWarning,并将在 JAX v0.8.0 中导致ImportError。自 JAX 0.4.35 版本以来,其 API 已引发NotImplementedError。在
jax.lax.dot()中,按位置传递precision和preferred_element_type参数已被弃用。请改为通过显式关键字传递它们。已从
jax.interpreters.ad、jax.interpreters.batching和jax.interpreters.partial_eval中弃用了几十个内部 API;它们很少或根本不在 JAX 本身之外使用,而且大多数都被弃用而没有公共替换。
JAX 0.7.0 (2025年7月22日)#
新功能
添加了
jax.P,它是jax.sharding.PartitionSpec的别名。jax.numpy.ndarray.at索引方法现在支持wrap_negative_indices参数,该参数默认为True以匹配当前行为(#29434)。
重大更改
JAX 正在默认从 GSPMD 迁移到 Shardy。有关更多信息,请参阅迁移指南。
JAX 自动微分正在默认切换为使用直接线性化(而不是通过 JVP 和部分求值实现线性化)。有关更多信息,请参阅迁移指南。
jax.stages.OutInfo已被jax.ShapeDtypeStruct替换。jax.jit()现在要求fun按位置传递,其他参数按关键字传递。否则将在 v0.7.x 中引发错误。这在 v0.6.x 中引发了 DeprecationWarning。最低 Python 版本现在是 3.11。3.11 将保持最低支持版本,直到 2026 年 7 月。
布局 API 重命名
Layout、.layout、.input_layouts和.output_layouts已重命名为Format、.format、.input_formats和.output_formatsDeviceLocalLayout、.device_local_layout已重命名为Layout和.layout
jax.experimental.shard模块已被删除,所有 API 已移至jax.sharding端点。因此,请使用jax.sharding.reshard、jax.sharding.auto_axes和jax.sharding.explicit_axes而不是它们的实验性端点。lax.infeed和lax.outfeed已被移除,此前已在 JAX 0.6 中弃用。Device对象上的transfer_to_infeed和transfer_from_outfeed方法也被移除。jax.extend.core.primitives.pjit_p原语已重命名为jit_p,其name属性已从"pjit"更改为"jit"。这会影响 jaxprs 的字符串表示。此原语不再从jax.experimental.pjit模块导出。(未记录的)函数
jax.extend.backend.add_clear_backends_callback已被移除。用户应改用jax.extend.backend.register_backend_cache。out_sharding参数已添加到x.at[y].set和x.at[y].add。先前的传播操作数分片行为已被移除。如果您需要保留先前的行为,请使用x.at[y].set/add(z, out_sharding=jax.typeof(x).sharding)以便散布操作需要集合通信。
弃用
jax.dlpack.SUPPORTED_DTYPES已被弃用;请使用新的jax.dlpack.is_supported_dtype()函数。jax.scipy.special.sph_harm()已被弃用,类似于 SciPy 中的类似弃用;请改用jax.scipy.special.sph_harm_y()。从
jax.interpreters.xla中,先前已弃用的符号abstractify和pytype_aval_mappings已被移除。jax.interpreters.xla.canonicalize_dtype()已被弃用。要规范化 dtype,请优先使用jax.dtypes.canonicalize_dtype()。要检查对象是否为有效的 jax 输入,请优先使用jax.core.valid_jaxtype()。从
jax.core中,先前已弃用的符号AxisName、ConcretizationTypeError、axis_frame、call_p、closed_call_p、get_type、trace_state_clean、typematch和typecheck已被移除。从
jax.lib.xla_client中,先前已弃用的符号DeviceAssignment、get_topology_for_devices和mlir_api_version已被移除。jax.extend.ffi已被移除,此前已在 v0.5.0 中弃用。请使用jax.ffi代替。jax.lib.xla_bridge.get_compile_options()已被弃用,并由jax.extend.backend.get_compile_options()替换。
JAX 0.6.2 (2025年6月17日)#
新功能
添加了
jax.tree.broadcast(),它实现了一个 pytree 前缀广播助手。
更改
最低 NumPy 版本为 1.26,最低 SciPy 版本为 1.12。
JAX 0.6.1 (2025年5月21日)#
新功能
添加了
jax.lax.axis_size(),它返回给定名称的映射轴的大小。
更改
CUDA 包依赖项的版本检查被重新启用,之前在一个之前的版本中被意外禁用。
JAX nightly 包现在已发布到 artifact registry。要安装这些包,请参阅JAX 安装指南。
jax.sharding.PartitionSpec不再继承自元组。jax.ShapeDtypeStruct现在是不可变的。请使用.update方法更新您的ShapeDtypeStruct,而不是进行原地更新。
弃用
jax.custom_derivatives.custom_jvp_call_jaxpr_p已被弃用,并将在 JAX v0.7.0 中移除。
JAX 0.6.0 (2025年4月16日)#
重大更改
jax.numpy.array()不再接受None。此行为自 2023 年 11 月起已被弃用,现已移除。移除了
config.jax_data_dependent_tracing_fallback配置选项,该选项是在 v0.4.36 中临时添加的,允许用户选择退出新的“无堆栈”跟踪机制。移除了
config.jax_eager_pmap配置选项。禁止对
jax.jit的结果调用lower和traceAOT API,如果后续应用了其他包装器。之前可以这样做,但会默默忽略包装器。解决方法是将jax.jit作为最后一个包装器应用,jax.pmap同理。请参阅#27873。jax的cuda12_pip附加包已移除;请改为使用pip install jax[cuda12]。
更改
最低 CuDNN 版本为 v9.8。
JAX 现在使用 CUDA 12.8 构建。所有 CUDA 12.1 或更高版本仍然受支持。
JAX 包附加包现已更新为使用破折号而不是下划线,以符合 PEP 685。例如,如果您之前使用
pip install jax[cuda12_local]来安装 JAX,请改用pip install jax[cuda12-local]。jax.jit()现在要求fun按位置传递,其他参数按关键字传递。否则将在 v0.6.X 中引发 DeprecationWarning,并在 v0.7.X 中开始引发错误。
弃用
jax.tree_util.build_tree()已被弃用。请改用jax.tree.unflatten()。使用 XLA 的 FFI 实现了一系列 CPU 和 GPU 设备的主机回调处理程序,并移除了使用 XLA 自定义调用的现有 CPU/GPU 处理程序。
jax.lib.xla_extension中的所有 API 现在都已弃用。jax.interpreters.mlir.hlo和jax.interpreters.mlir.func_dialect,这些是意外导出的,已被移除。如果需要,可以从jax.extend.mlir获取。jax.interpreters.mlir.custom_call已被弃用。应使用jax.ffi提供的 API。使用内联参数的
jax.ffi.ffi_call()的已弃用用法不再受支持。ffi_call()现在无条件返回一个可调用对象。jax.lib.xla_client中的以下导出已弃用:get_topology_for_devices、heap_profile、mlir_api_version、Client、CompileOptions、DeviceAssignment、Frame、HloSharding、OpSharding、Traceback。jax.util中的以下内部 API 已弃用:HashableFunction、as_hashable_function、cache、safe_map、safe_zip、split_dict、split_list、split_list_checked、split_merge、subvals、toposort、unzip2、wrap_name和wraps。jax.dlpack.to_dlpack已被弃用。通常可以将 JAXArray直接传递给另一个框架的from_dlpack函数。如果您需要to_dlpack的功能,请使用数组的__dlpack__属性。jax.lax.infeed、jax.lax.infeed_p、jax.lax.outfeed和jax.lax.outfeed_p已被弃用,并将在 JAX v0.7.0 中移除。已移除几个先前已弃用的 API,包括
来自
jax.lib.xla_client:ArrayImpl、FftType、PaddingType、PrimitiveType、XlaBuilder、dtype_to_etype、ops、register_custom_call_target、shape_from_pyval、Shape、XlaComputation。来自
jax.lib.xla_extension:ArrayImpl、XlaRuntimeError。来自
jax:jax.treedef_is_leaf、jax.tree_flatten、jax.tree_map、jax.tree_leaves、jax.tree_structure、jax.tree_transpose和jax.tree_unflatten。替换项可以在jax.tree或jax.tree_util中找到。来自
jax.core:AxisSize、ClosedJaxpr、EvalTrace、InDBIdx、InputType、Jaxpr、JaxprEqn、Literal、MapPrimitive、OpaqueTraceState、OutDBIdx、Primitive、Token、TRACER_LEAK_DEBUGGER_WARNING、Var、concrete_aval、dedup_referents、escaped_tracer_error、extend_axis_env_nd、full_lower、get_referent、jaxpr_as_fun、join_effects、lattice_join、leaked_tracer_error、maybe_find_leaked_tracers、raise_to_shaped、raise_to_shaped_mappings、reset_trace_state、str_eqn_compact、substitute_vars_in_output_ty、typecompat和used_axis_names_jaxpr。大多数没有公共替换项,但少数可以在jax.extend.core中找到。vectorized参数到pure_callback()和ffi_call()。请改用vmap_method参数。
jax 0.5.3 (2025年3月19日)#
新功能
在
jax.lax.dynamic_slice()、jax.lax.dynamic_update_slice()和相关函数中添加了一个allow_negative_indices选项。默认值为 true,与当前行为一致。如果设置为 false,JAX 不需要发出代码来裁剪负索引,这可以减小代码大小。在
jax.random.categorical()中添加了一个replace选项,以启用不放回抽样。
jax 0.5.2 (2025年3月4日)#
0.5.1 的补丁版本
错误修复
修复了 TPU 指标日志记录和
tpu-info的 bug,这些在 0.5.1 中是损坏的
jax 0.5.1 (2025年2月24日)#
重大更改
jit 跟踪缓存现在以输入的 NamedShardings 为键。之前,跟踪缓存根本不包含分片信息(尽管后续的 jit 缓存如降低和编译缓存包含),因此不同类型的两个等效分片不会重新跟踪,但现在它们会。例如
@jax.jit def f(x): return x # inp1.sharding is of type SingleDeviceSharding inp1 = jnp.arange(8) f(inp1) mesh = jax.make_mesh((1,), ('x',)) # inp2.sharding is of type NamedSharding inp2 = jax.device_put(jnp.arange(8), NamedSharding(mesh, P('x'))) f(inp2) # tracing cache miss
在上面的示例中,调用
f(inp1)然后调用f(inp2)将导致跟踪缓存未命中,因为在跟踪过程中抽象值上的分片已更改。
新功能
在
jax.experimental.custom_dce.custom_dce()中添加了一个实验性装饰器,以支持自定义不透明函数在 JAX 级垃圾代码消除 (DCE) 下的行为。有关更多详细信息,请参阅#25956。在
jax.lax中添加了低级归约 API:jax.lax.reduce_sum()、jax.lax.reduce_prod()、jax.lax.reduce_max()、jax.lax.reduce_min()、jax.lax.reduce_and()、jax.lax.reduce_or()和jax.lax.reduce_xor()。jax.lax.linalg.qr()和jax.scipy.linalg.qr()现在支持 CPU 和 GPU 上的列主元。请参阅#20282 和添加了
jax.random.multinomial()。有关更多详细信息,请参阅#25955。
更改
JAX_CPU_COLLECTIVES_IMPLEMENTATION和JAX_NUM_CPU_DEVICES现在可用作环境变量。在此之前,它们只能通过 jax.config 或标志指定。JAX_CPU_COLLECTIVES_IMPLEMENTATION现在默认为'gloo',这意味着多进程 CPU 通信开箱即用。TPU 附加包
jax[tpu]不再依赖于libtpu-nightly包。如果您的机器上存在此包,可以安全地移除;JAX 现在使用libtpu代替。
弃用
内部函数
linear_util.wrap_init和构造函数core.Jaxpr现在必须接受非空的core.DebugInfo关键字参数。在有限的时间内,如果使用jax.extend.linear_util.wrap_init而不带调试信息,则会打印DeprecationWarning。这导致了许多其他内部函数需要调试信息。此更改不影响公共 API。有关更多详细信息,请参阅 https://github.com/jax-ml/jax/issues/26480。在
jax.numpy.ndim()、jax.numpy.shape()和jax.numpy.size()中,现在已弃用非类数组输入(如列表、元组等)。
错误修复
TPU 运行时启动和关闭时间在 TPU v5e 及更新版本上应有显著改善(从约 17 秒到约 8 秒)。如果尚未设置,您可能需要在 VM 映像中启用透明大页(
sudo sh -c 'echo always > /sys/kernel/mm/transparent_hugepage/enabled')。我们希望在未来的版本中进一步改进这一点。如果未设置 JAX_COMPILATION_CACHE_MAX_SIZE 或设置为 -1(即未启用 LRU 逐出策略),则持久化编译缓存不再写入访问时间文件。这应该能提高使用大容量网络存储的缓存的性能。
jax 0.5.0 (2025年1月17日)#
自此版本起,JAX 现在使用基于工作量的版本控制。由于此版本对 PRNG 密钥语义进行了可能需要用户更新其代码的重大更改,因此我们将 JAX 的“meso”版本号升级以表示这一点。
重大更改
默认启用
jax_threefry_partitionable(请参阅更新说明)。此版本放弃对 Mac x86 轮子的支持。Mac ARM 当然仍然受支持。有关最近的讨论,请参阅 https://github.com/jax-ml/jax/discussions/22936。
两个关键因素促成了这一决定:
Mac x86 构建(仅)存在许多测试失败和崩溃。我们宁愿不发布任何版本,也不发布一个损坏的版本。
Mac x86 硬件已停产,目前开发者难以获得。因此,即使我们想修复这类问题,也很难做到。
如果我们对 Mac x86 的支持社区愿意提供帮助,我们愿意重新添加对 Mac x86 的支持:特别是,在我们再次发布版本之前,我们需要 JAX 测试套件在 Mac x86 上干净地通过。
更改
最低 NumPy 版本现在是 1.25。NumPy 1.25 将保持最低支持版本,直到 2025 年 6 月。
最低 SciPy 版本现在是 1.11。SciPy 1.11 将保持最低支持版本,直到 2025 年 6 月。
jax.numpy.einsum()现在默认使用optimize='auto'而不是optimize='optimal'。这避免了在参数很多的情况下(#25214)指数级增长的跟踪时间。jax.numpy.linalg.solve()不再支持右侧的批处理一维参数。要在这些情况下恢复之前的行为,请使用solve(a, b[..., None]).squeeze(-1)。
新功能
jax.numpy.fft.fftn()、jax.numpy.fft.rfftn()、jax.numpy.fft.ifftn()和jax.numpy.fft.irfftn()现在支持超过 3 维的变换,而以前这是限制。有关更多详细信息,请参阅#25606。通过新的
jax.ffi.register_ffi_type_id()函数增加了对 FFI 中用户定义状态的支持。AOT 降低的
.as_text()方法现在支持debug_info选项,以在输出中包含调试信息,例如源位置。
弃用
从
jax.interpreters.xla中,abstractify和pytype_aval_mappings现在已弃用,已被jax.core中同名符号替换。jax.scipy.special.lpmn()和jax.scipy.special.lpmn_values()已被弃用,遵循其在 SciPy v1.15.0 中的弃用。没有计划用新 API 替换这些已弃用的函数。jax.extend.ffi子模块已移至jax.ffi,并且之前的导入路径已被弃用。
删除
jax_enable_memories标志已被删除,并且该标志的行为默认开启。从
jax.lib.xla_client中,先前已弃用的Device和XlaRuntimeError符号已移除;请分别使用jax.Device和jax.errors.JaxRuntimeError。jax.experimental.array_api模块已移除,此前已在 JAX v0.4.32 中弃用。自该版本以来,jax.numpy直接支持 Array API。
jax 0.4.38 (2024年12月17日)#
重大更改
XlaExecutable.cost_analysis现在返回一个dict[str, float](而不是一个单元素的list[dict[str, float]])。
更改
jax.tree.flatten_with_path和jax.tree.map_with_path已添加为相应tree_util函数的快捷方式。
弃用
内部
jax.core命名空间中的一些 API 已被弃用。大多数是无操作,很少使用,或者可以被jax.extend.core中同名 API 替换;有关这些半公开扩展的兼容性保证信息,请参阅jax.extend的文档。已移除几个先前已弃用的 API,包括
来自
jax.core:check_eqn、check_type、check_valid_jaxtype和non_negative_dim。来自
jax.lib.xla_bridge:xla_client和default_backend。来自
jax.lib.xla_client:_xla和bfloat16。来自
jax.numpy:round_。
新功能
jax.export.export()可用于使用jax.sharding.AbstractMesh()构建的分片进行设备多态导出。请参阅jax.export 文档。添加了
jax.lax.split()。这是jax.numpy.split()的一个原始版本,添加它是因为它在自动微分过程中产生更紧凑的转置。
jax 0.4.37 (2024年12月9日)#
这是 jax 0.4.36 的补丁版本。仅发布了“jax”此版本。
错误修复
修复了一个 bug,在该 bug 中
jit会在参数命名为f时报错(#25329)。修复了一个 bug,该 bug 在用户使用不同辅助数据注册 pytree 节点类用于 flatten 和 flatten_with_path 时,会在
jax.lax.while_loop()中抛出index out of range错误。固定了一个新的 libtpu 版本 (0.0.6),该版本修复了 TPU v6e 上的编译器 bug。
jax 0.4.36 (2024年12月5日)#
重大更改
此版本实现了“stackless”,这是 JAX 跟踪机制的一个内部更改。我们将跟踪分派完全改为上下文的函数,而不是上下文和数据的组合。这使我们能够删除大量用于管理数据相关跟踪的机制:级别、子级别、
post_process_call、new_base_main、custom_bind等等。如果您确实使用了 JAX 内部,那么您可能需要更新您的代码(有关如何操作的线索,请参阅 https://github.com/jax-ml/jax/commit/c36e1f7c1ad4782060cbc8e8c596d85dfb83986f)。也可能存在 JAX 库版本不兼容的问题。如果您发现此更改破坏了不使用 JAX 内部的代码,请尝试使用
config.jax_data_dependent_tracing_fallback标志作为解决方法,如果您需要帮助更新代码,请提交 bug。使用
native_serialization=False或enable_xla=False的jax.experimental.jax2tf.convert()自 2024 年 7 月(JAX 版本 0.4.31)起已被弃用。现在我们移除了对这些用例的支持。带有原生序列化的jax2tf仍将受支持。在
jax.interpreters.xla中,xb、xc和xe符号已移除,此前已在 JAX v0.4.31 中弃用。而是使用xb = jax.lib.xla_bridge、xc = jax.lib.xla_client和xe = jax.lib.xla_extension。已弃用的模块
jax.experimental.export已被移除。它已在 JAX v0.4.30 中被jax.export替换。有关迁移到新 API 的信息,请参阅迁移指南。initial参数到jax.nn.softmax()和jax.nn.log_softmax()已被移除,此前已在 v0.4.27 中弃用。对类型化 PRNG 密钥(即
jax.random.key()生成的密钥)调用np.asarray现在会引发错误。之前,这会返回一个标量对象数组。已移除
jax.export中的以下已弃用方法和函数jax.export.DisabledSafetyCheck.shape_assertions:它已经没有效果了。jax.export.Exported.lowering_platforms:使用platforms。jax.export.Exported.mlir_module_serialization_version:使用calling_convention_version。jax.export.Exported.uses_shape_polymorphism:使用uses_global_constants。jax.export.export()的lowering_platforms关键字参数:改用platforms。
jax.export.symbolic_args_specs()中的关键字参数symbolic_scope和symbolic_constraints已被移除。它们已于 2024 年 6 月弃用。请改用scope和constraints。跟踪器的哈希处理,该处理自版本 0.4.30 起已被弃用,现在会导致
TypeError。重构:JAX 构建 CLI (build/build.py) 现在使用子命令结构,并取代了之前的 build.py 用法。运行
python build/build.py --help获取更多详细信息。新子命令选项的简要概述build:构建 JAX 轮子包。例如,python build/build.py build --wheels=jaxlib,jax-cuda-pjrtrequirements_update:更新 requirements_lock.txt 文件。
jax.scipy.linalg.toeplitz()现在对多维输入进行隐式批处理。要恢复之前的行为,您可以在函数输入上调用jax.numpy.ravel()。jax.scipy.special.gamma()和jax.scipy.special.gammasgn()现在对于负整数输入返回 NaN,以匹配 SciPy 的行为,源自 https://github.com/scipy/scipy/pull/21827。jax.clear_backends已被移除,此前已在 v0.4.26 中弃用。我们移除了自定义调用“__gpu$xla.gpu.triton”,将其从我们保证导出稳定性的自定义调用列表中移除。这是因为此自定义调用依赖于 Triton IR,而 Triton IR 不保证稳定。如果您需要导出使用此自定义调用的代码,可以使用
disabled_checks参数。有关更多详细信息,请参阅文档。
新功能
jax.jit()新增了一个compiler_options: dict[str, Any]参数,用于将编译选项传递给 XLA。目前该参数尚未文档化,并且可能会发生变化。jax.tree_util.register_dataclass()现在允许通过dataclasses.field()内联声明元数据字段。有关示例,请参阅函数文档。jax.lax.linalg.eig()和相关的jax.numpy函数(jax.numpy.linalg.eig()和jax.numpy.linalg.eigvals())现在已在 GPU 上支持。有关更多详细信息,请参阅 #24663。添加了两个新的配置标志
jax_exec_time_optimization_effort和jax_memory_fitting_effort,分别用于控制编译器在最小化执行时间和内存使用方面花费的精力。有效值为 -1.0 到 1.0,默认值为 0.0。
错误修复
修复了一个 bug,该 bug 导致 LU 和 QR 分解的 GPU 实现会导致批次大小接近 int32 最大值时发生索引溢出。有关更多详细信息,请参阅 #24843。
弃用
jax.lib.xla_extension.ArrayImpl和jax.lib.xla_client.ArrayImpl已弃用;请改用jax.Array。jax.lib.xla_extension.XlaRuntimeError已弃用;请改用jax.errors.JaxRuntimeError。
jax 0.4.35 (2024 年 10 月 22 日)#
重大更改
jax.numpy.isscalar()现在对于任何零维度的类数组对象都返回 True。以前,它只对具有弱 dtype 的零维类数组对象返回 True。jax.experimental.host_callback自 2024 年 3 月(JAX 版本 0.4.26)起已弃用。现在我们已将其移除。有关替代方案的讨论,请参阅 #20385。
更改
jax.lax.FftType被引入为 FFT 操作枚举的公共名称。半公共 APIjax.lib.xla_client.FftType已弃用。TPU:JAX 现在从
libtpu包而不是libtpu-nightly安装 TPU 支持。在接下来的几个版本中,JAX 将将libtpu-nightly和libtpu的版本固定为空,以简化过渡;此依赖项将在 2025 年第一季度移除。
弃用
半公共 API
jax.lib.xla_client.PaddingType已弃用。没有 JAX API 使用此类型,因此没有替代品。在
vmap下,jax.pure_callback()和jax.extend.ffi.ffi_call()的默认行为已弃用,这些函数的vectorized参数也已弃用。应改用vmap_method参数以获得更明确的行为。有关更多详细信息,请参阅 #23881 中的讨论。半公共 API
jax.lib.xla_client.register_custom_call_target已弃用。请改用 JAX FFI。半公共 API
jax.lib.xla_client.dtype_to_etype、jax.lib.xla_client.ops、jax.lib.xla_client.shape_from_pyval、jax.lib.xla_client.PrimitiveType、jax.lib.xla_client.Shape、jax.lib.xla_client.XlaBuilder和jax.lib.xla_client.XlaComputation已弃用。请改用 StableHLO。
jax 0.4.34 (2024 年 10 月 4 日)#
新功能
此版本包含 Python 3.13 的 wheel。尚不支持 free-threading 模式。
已添加
jax.errors.JaxRuntimeError作为以前私有的XlaRuntimeError类型的公共别名。
重大更改
标志
jax_pmap_no_rank_reduction默认设置为True。pmap 结果上的 array[0] 现在会引入一个 reshape(请改用 array[0:1])。
每个分片的形状(可通过 jax_array.addressable_shards 或 jax_array.addressable_data(0) 访问)现在具有一个前导 (1, …)。相应地更新直接访问分片的代码。每个分片形状的秩现在与全局形状的秩匹配,这与 jit 的行为相同。这避免了在将 pmap 的结果传递到 jit 时进行昂贵的 reshape。
jax.experimental.host_callback自 2024 年 3 月(JAX 版本 0.4.26)起已弃用。现在我们将--jax_host_callback_legacy配置值的默认值设置为True,这意味着如果您的代码使用jax.experimental.host_callbackAPI,那些 API 调用将基于新的jax.experimental.io_callbackAPI 实现。如果这破坏了您的代码,在非常短的时间内,您可以将--jax_host_callback_legacy设置为True。很快我们将移除该配置选项,因此您应该改用新的 JAX 回调 API。有关讨论,请参阅 #20385。
弃用
在
jax.numpy.trim_zeros()中,非类数组参数或ndim != 1的类数组参数现在已弃用,将来会导致错误。内部 pretty-printing 工具
jax.core.pp_*已移除,之前在 JAX v0.4.30 中已弃用。jax.lib.xla_client.Device已弃用;请改用jax.Device。jax.lib.xla_client.XlaRuntimeError已弃用。请改用jax.errors.JaxRuntimeError。
删除
jax.xla_computation已删除。距离其在 0.4.30 JAX 版本中弃用已有 3 个月。请使用 AOT API 来获取与jax.xla_computation相同的功能。jax.xla_computation(fn)(*args, **kwargs)可以被替换为jax.jit(fn).lower(*args, **kwargs).compiler_ir('hlo')。您还可以使用
jax.stages.Lowered的.out_info属性来获取输出信息(如树结构、形状和 dtype)。对于跨后端降低,您可以将
jax.xla_computation(fn, backend='tpu')(*args, **kwargs)替换为jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo')。
jax.ShapeDtypeStruct不再接受named_shape参数。该参数仅由已移除的xmap使用(在 0.4.31 版本中移除)。jax.tree.map(f, None, non-None),之前会发出DeprecationWarning,现在在未来的 jax 版本中会引发错误。None仅是其自身的树前缀。要保留当前行为,您可以要求jax.tree.map将None视为叶子值,方法是编写:jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None)。jax.sharding.XLACompatibleSharding已移除。请使用jax.sharding.Sharding。
错误修复
修复了一个 bug,该 bug 导致
jax.numpy.cumsum()在提供非布尔输入并指定dtype=bool时产生不正确的结果。编辑
jax.numpy.ldexp()的实现以获得正确的梯度。
jax 0.4.33 (2024 年 9 月 16 日)#
这是在 jax 0.4.32 之上的一个补丁版本,修复了该版本中的两个 bug。
在 JAX 0.4.32 固定的 libtpu 版本中发现了一个仅限 TPU 的数据损坏 bug,该 bug 仅在同一作业中有多个 TPU 切片时出现,例如,在多个 v5e 切片上训练时。此版本通过固定一个特定版本的 libtpu 来修复此问题。
此版本修复了 CPU 上 F64 tanh 的不准确结果 (#23590)。
jax 0.4.32 (2024 年 9 月 11 日)#
注意:此版本因 TPU 数据损坏 bug 已从 PyPi 中移除。有关更多详细信息,请参阅 0.4.33 版本说明。
新功能
添加了
jax.extend.ffi.ffi_call()和jax.extend.ffi.ffi_lowering(),以支持使用新的 Foreign function interface (FFI) 从 JAX 与自定义 C++ 和 CUDA 代码进行交互。
更改
标志
jax_enable_memories默认设置为True。jax.numpy现在支持 Python Array API 标准的 v2023.12 版本。有关更多信息,请参阅 Python Array API standard。CPU 后端上的计算现在可以在更多情况下异步分派。以前,非并行计算始终是同步分派的。您可以通过设置
jax.config.update('jax_cpu_enable_async_dispatch', False)来恢复旧行为。添加了新的
jax.process_indices()函数,以替换 JAX v0.2.13 中已弃用的jax.host_ids()函数。为了与
numpy.fabs的行为保持一致,jax.numpy.fabs已修改为不再支持complex dtypes。jax.tree_util.register_dataclass现在会检查data_fields和meta_fields是否包含所有init=True的 dataclass 字段,并且仅包含这些字段,如果nodetype是 dataclass。一些
jax.numpy函数现在具有完整的ufunc接口,包括add、multiply、bitwise_and、bitwise_or、bitwise_xor、logical_and、logical_and和logical_and。在未来的版本中,我们计划将这些扩展到其他 ufunc。添加了
jax.lax.optimization_barrier(),它允许用户阻止编译器优化,例如公共子表达式消除和控制调度。
重大更改
MHLO MLIR 方言 (
jax.extend.mlir.mhlo) 已移除。请改用stablehlo方言。
弃用
不再允许对
jax.numpy.clip()和jax.numpy.hypot()输入复数,之前在 JAX v0.4.27 版本中已弃用。弃用了以下 API
jax.lib.xla_bridge.xla_client:请直接使用jax.lib.xla_client。jax.lib.xla_bridge.get_backend:请使用jax.extend.backend.get_backend()。jax.lib.xla_bridge.default_backend:请使用jax.extend.backend.default_backend()。
jax.experimental.array_api模块已弃用,不再需要导入它即可使用 Array API。jax.numpy直接支持 array API;有关更多信息,请参阅 Python Array API standard。内部实用工具
jax.core.check_eqn、jax.core.check_type和jax.core.check_valid_jaxtype现在已弃用,并将移除。jax.numpy.round_已弃用,遵循 NumPy 2.0 中移除相应 API 的情况。请改用jax.numpy.round()。将 DLPack capsule 传递给
jax.dlpack.from_dlpack()已弃用。jax.dlpack.from_dlpack()的参数应该是实现了__dlpack__协议的另一个框架的数组。
jaxlib 0.4.32 (2024 年 9 月 11 日)#
注意:此版本因 TPU 数据损坏 bug 已从 PyPi 中移除。有关更多详细信息,请参阅 0.4.33 版本说明。
重大更改
此 jaxlib 版本切换到了新的 CPU 后端版本,该版本编译速度更快,并且能更好地利用并行性。如果您遇到由于此更改导致的问题,可以通过设置环境变量
XLA_FLAGS=--xla_cpu_use_thunk_runtime=false来暂时启用旧的 CPU 后端。如果您需要这样做,请提交一个 JAX bug 并提供重现说明。已添加 hermetic CUDA 支持。Hermetic CUDA 使用特定的可下载 CUDA 版本,而不是用户本地安装的 CUDA。Bazel 将下载 CUDA、CUDNN 和 NCCL 发行版,然后将 CUDA 库和工具用作各种 Bazel 目标中的依赖项。这使得 JAX 及其支持的 CUDA 版本的构建更加可复现。
更改
已添加 SparseCore 性能分析。
JAX 现在支持在 TPUv5p 芯片上对 SparseCore 进行性能分析。这些跟踪将在 Tensorboard Profiler 的 TraceViewer 中可见。
jax 0.4.31 (2024 年 7 月 29 日)#
删除
xmap 已删除。请使用
shard_map()作为替代。
更改
最低 CuDNN 版本为 v9.1。这在之前的版本中也适用,但我们现在正式声明此版本约束。
最低 Python 版本现在是 3.10。3.10 将在 2025 年 7 月之前保持最低支持版本。
最低 NumPy 版本现在是 1.24。NumPy 1.24 将在 2024 年 12 月之前保持最低支持版本。
最低 SciPy 版本现在是 1.10。SciPy 1.10 将在 2025 年 1 月之前保持最低支持版本。
jax.numpy.ceil()、jax.numpy.floor()和jax.numpy.trunc()现在返回与输入相同 dtype 的输出,即不再将整数或布尔输入提升为浮点数。libdevice.10.bc不再与 CUDA wheels 一起捆绑。它必须作为本地 CUDA 安装的一部分,或通过 NVIDIA 的 CUDA pip wheels 安装。jax.experimental.pallas.BlockSpec现在期望block_shape在index_map之前传递。旧的参数顺序已弃用,将在未来的版本中移除。更新了 GPU 设备的 repr,使其与 TPU/CPU 更一致。例如,
cuda(id=0)现在将是CudaDevice(id=0)。在
jax.Array中添加了device属性和to_device方法,这是 JAX 的 Array API 支持的一部分。
弃用
移除了许多先前已弃用的与多态形状相关的内部 API。从
jax.core:移除了canonicalize_shape、dimension_as_value、definitely_equal和symbolic_equal_dim。HLO 降低规则不应再将单例 ir.Values 包装在元组中。相反,返回未包装的单例 ir.Values。对包装值的支持将在未来的 JAX 版本中移除。
带有
native_serialization=False或enable_xla=False的jax.experimental.jax2tf.convert()已弃用,并且此支持将在未来的版本中移除。自 JAX 0.4.16(2023 年 9 月)以来,原生序列化一直是默认值。已移除先前已弃用的函数
jax.random.shuffle;请改用jax.random.permutation并设置independent=True。
jaxlib 0.4.31 (2024 年 7 月 29 日)#
错误修复
修复了一个 bug,该 bug 导致 jit 的快速路径错误处理了 jit 的负 static_argnums。
修复了一个 bug,该 bug 导致奇异矩阵批次的三角求解产生无意义的有限值,而不是 inf 或 nan (#3589, #15429)。
jax 0.4.30 (2024 年 6 月 18 日)#
更改
JAX 支持 ml_dtypes >= 0.2。在 0.4.29 版本中,ml_dtypes 版本已升至 0.4.0,但在此版本中已回滚,以便同时使用 TensorFlow 和 JAX 的用户有更多时间迁移到更新的 TensorFlow 版本。
jax.experimental.mesh_utils现在可以为 TPU v5e 创建一个高效的 mesh。jax 现在直接依赖于 jaxlib。此更改由 CUDA 插件切换启用:不再存在多个 jaxlib 变体。您可以安装仅 CPU 的 jax,方法是
pip install jax,无需额外选项。添加了一个用于导出和序列化 JAX 函数的 API。这以前存在于
jax.experimental.export(目前正在弃用),现在将位于jax.export中。请参阅 文档。
弃用
内部 pretty-printing 工具
jax.core.pp_*已弃用,将在未来的版本中移除。tracer 的哈希处理已弃用,并且将在未来的 JAX 版本中导致
TypeError。以前是这种情况,但在最近几个 JAX 版本中出现了意外的回归。jax.experimental.export已弃用。请改用jax.export。请参阅 迁移指南。在大多数情况下,将数组作为 dtype 传递现在已弃用;例如,对于数组
x和y,x.astype(y)将引发警告。要消除警告,请使用x.astype(y.dtype)。jax.xla_computation已弃用,将在未来的版本中移除。请使用 AOT API 来获取与jax.xla_computation相同的功能。jax.xla_computation(fn)(*args, **kwargs)可以被替换为jax.jit(fn).lower(*args, **kwargs).compiler_ir('hlo')。您还可以使用
jax.stages.Lowered的.out_info属性来获取输出信息(如树结构、形状和 dtype)。对于跨后端降低,您可以将
jax.xla_computation(fn, backend='tpu')(*args, **kwargs)替换为jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo')。
jaxlib 0.4.30 (2024 年 6 月 18 日)#
已放弃对单体 CUDA jaxlibs 的支持。您必须使用基于插件的安装(
pip install jax[cuda12]或pip install jax[cuda12_local])。
jax 0.4.29 (2024 年 6 月 10 日)#
更改
我们预计这将是最后一个支持单体 CUDA jaxlib 的 JAX 和 jaxlib 版本。未来的版本将使用 CUDA 插件 jaxlib(例如
pip install jax[cuda12])。JAX 现在需要 ml_dtypes 版本 0.4.0 或更高版本。
移除了对
jax.experimental.exportAPI 旧用法的向后兼容支持。现在无法使用from jax.experimental.export import export,而是应使用from jax.experimental import export。移除的功能自 0.4.24 版本起已弃用。在
jax.tree.all()&jax.tree_util.tree_all()中添加了is_leaf参数。
弃用
jax.sharding.XLACompatibleSharding已弃用。请使用jax.sharding.Sharding。jax.experimental.Exported.in_shardings已重命名为jax.experimental.Exported.in_shardings_hlo。对于out_shardings也是如此。旧名称将在 3 个月后移除。移除了许多先前已弃用的 API
jax.numpy.linalg.matrix_rank()的tol参数正在弃用,并将很快移除。请改用rtol。jax.numpy.linalg.pinv()的rcond参数正在弃用,并将很快移除。请改用rtol。已移除已弃用的
jax.config子模块。要配置 JAX,请使用import jax,然后通过jax.config引用配置对象。jax.randomAPI 不再接受批次密钥,以前有些是意外接受的。今后,我们建议在这种情况下显式使用jax.vmap()。在
jax.scipy.special.beta()中,x和y参数已重命名为a和b,以与其他betaAPI 保持一致。
新功能
添加了
jax.experimental.Exported.in_shardings_jax(),用于构造可与 JAX API 一起使用的分片,这些分片来自存储在Exported对象中的 HloShardings。
jaxlib 0.4.29 (2024 年 6 月 10 日)#
错误修复
修复了一个 bug,该 bug 导致 XLA 错误地对某些连接操作进行分片,表现为累积约简的输出不正确 (#21403)。
修复了一个 bug,该 bug 导致 XLA:CPU 错误地编译了某些 matmul fusion(https://github.com/openxla/xla/pull/13301)。
修复了 GPU 上的编译器崩溃问题(https://github.com/jax-ml/jax/issues/21396)。
弃用
jax.tree.map(f, None, non-None)现在会发出DeprecationWarning,并在未来的 jax 版本中引发错误。None仅是其自身的树前缀。要保留当前行为,您可以要求jax.tree.map将None视为叶子值,方法是编写:jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None)。
jax 0.4.28 (2024 年 5 月 9 日)#
错误修复
恢复了
make_jaxpr的一个更改,该更改曾破坏 Equinox (#21116)。
弃用和移除
jax.numpy.sort()和jax.numpy.argsort()的kind参数已被移除。请改用stable=True或stable=False。从
jax.experimental.pallas.gpu模块中移除了get_compute_capability。请改用jax.devices()或jax.local_devices()返回的 GPU 设备的compute_capability属性。jax.numpy.reshape()的newshape参数正在弃用,并将很快移除。请改用shape。
更改
此版本最低 jaxlib 版本为 0.4.27。
jaxlib 0.4.28 (2024 年 5 月 9 日)#
错误修复
修复了 Python 3.10 或更早版本中 Array 和 JIT Python 对象类型名称中的内存损坏 bug。
修复了 CUDA 12.4 下的警告
'+ptx84' is not a recognized feature for this target。修复了 CPU 上的慢速编译问题。
更改
Windows 构建现在使用 Clang 而不是 MSVC 构建。
jax 0.4.27 (2024 年 5 月 7 日)#
新功能
添加了
jax.numpy.unstack()和jax.numpy.cumulative_sum(),继它们在 array API 2023 标准中添加之后,该标准将很快被 NumPy 采纳。添加了一个新的配置选项
jax_cpu_collectives_implementation,用于选择 CPU 后端使用的跨进程集合操作的实现。可用的选项有'none'(默认)、'gloo'和'mpi'(需要 jaxlib 0.4.26)。如果设置为'none',则禁用跨进程集合操作。
更改
jax.pure_callback()、jax.experimental.io_callback()和jax.debug.callback()现在使用jax.Array而不是np.ndarray。您可以通过在将参数传递给回调之前使用jax.tree.map(np.asarray, args)来恢复旧行为。complex_arr.astype(bool)现在遵循与 NumPy 相同的语义,对于complex_arr等于0 + 0j的情况返回 False,否则返回 True。core.Token现在是一个非平凡类,它封装了一个jax.Array。它可以被创建并穿梭于计算之间以建立依赖关系。单例对象core.token已被移除,用户现在应该创建并使用新的core.Token对象。在 GPU 上,Threefry PRNG 实现默认不再降低到内核调用。这个选择可以提高运行时内存使用率,但会增加编译时间。可以恢复到旧行为(生成内核调用)的选项是
jax.config.update('jax_threefry_gpu_kernel_lowering', True)。如果新默认设置导致问题,请提交 bug。否则,我们打算在未来的版本中移除此标志。
弃用和移除
Pallas 现在仅使用 XLA 来编译 GPU 上的 Triton 内核。旧的通过 Triton Python API 进行降低的通道已移除,并且
JAX_TRITON_COMPILE_VIA_XLA环境变量不再有任何作用。jax.numpy.clip()的参数签名已更改:a、a_min和a_max已弃用,取而代之的是x(仅位置参数)、min和max(#20550)。JAX 数组的
device()方法已被移除,自 JAX v0.4.21 起已弃用。请改用arr.devices()。jax.nn.softmax()和jax.nn.log_softmax()的initial参数已弃用;空输入现在无需设置此参数即可使用 softmax。在
jax.jit()中,传递无效的static_argnums或static_argnames现在会导致错误而不是警告。最低 jaxlib 版本现在是 0.4.23。
jax.numpy.hypot()函数现在在将复数值输入传递给它时会发出弃用警告。当弃用完成后,这将引发一个错误。jax.numpy.nonzero()、jax.numpy.where()和相关函数的标量参数现在会引发错误,这遵循了 NumPy 中类似的更改。配置选项
jax_cpu_enable_gloo_collectives已弃用。请改用jax.config.update('jax_cpu_collectives_implementation', 'gloo')。jax.Array.device_buffer和jax.Array.device_buffers方法已被移除,之前在 JAX v0.4.22 中已弃用。请改用jax.Array.addressable_shards和jax.Array.addressable_data()。jax.numpy.where的condition、x和y参数现在是仅位置参数,遵循 JAX v0.4.21 中对关键字的弃用。jax.lax.linalg模块中的函数现在必须通过关键字指定非类数组参数。以前,这会引发 DeprecationWarning。多个
jax.numpy()API 现在需要类数组参数,包括apply_along_axis()、apply_over_axes()、inner()、outer()、cross()、kron()和lexsort()。
错误修复
jax.numpy.astype()现在当copy=True时将始终返回副本。以前,当输出数组的 dtype 与输入数组相同的情况下,不会进行复制。这可能会导致内存使用量增加。默认值为copy=False,以保持向后兼容。
jaxlib 0.4.27 (2024 年 5 月 7 日)#
jax 0.4.26 (2024 年 4 月 3 日)#
新功能
添加了
jax.numpy.trapezoid(),继 NumPy 2.0 中添加此函数之后。
更改
复数值
jax.numpy.geomspace()现在选择与 NumPy 2.0 一致的对数螺旋分支。lax.rng_bit_generator的行为,以及'rbg'和'unsafe_rbg'PRNG 实现的行为,在jax.vmap下发生了变化,导致映射密钥只从批次中的第一个密钥生成随机数(https://github.com/jax-ml/jax/issues/19085)。文档现在使用
jax.random.key来构建 PRNG 密钥数组,而不是jax.random.PRNGKey。
弃用和移除
jax.tree_map()已弃用;请改用jax.tree.map,或者为了与旧版 JAX 兼容,请使用jax.tree_util.tree_map()。jax.clear_backends()已弃用,因为它不一定与其名称所暗示的一样,并且可能导致意外后果,例如,它不会销毁现有后端并释放相应的拥有资源。如果您只想清理编译缓存,请使用jax.clear_caches()。为了向后兼容或您确实需要切换/重新初始化默认后端,请使用jax.extend.backend.clear_backends()。jax.experimental.maps模块和jax.experimental.maps.xmap已弃用。请使用jax.experimental.shard_map或jax.vmap和spmd_axis_name参数来表达 SPMD 设备并行计算。jax.experimental.host_callback模块已弃用。请改用 新的 JAX 外部回调。已添加JAX_HOST_CALLBACK_LEGACY标志以协助迁移到新的回调。有关讨论,请参阅 #20385。将无法转换为 JAX 数组的参数传递给
jax.numpy.array_equal()和jax.numpy.array_equiv()现在会导致异常。已移除已弃用的标志
jax_parallel_functions_output_gda。此标志已长期弃用且不起作用;其使用是无操作。已移除先前已弃用的导入
jax.interpreters.ad.config和jax.interpreters.ad.source_info_util。请改用jax.config和jax.extend.source_info_util。JAX 不再支持旧的序列化版本。版本 9 自 2023 年 10 月 27 日起受支持,并自 2024 年 2 月 1 日起成为默认版本。请参阅 版本描述。此更改可能会破坏设置了低于 9 的特定 JAX 序列化版本的客户端。
jaxlib 0.4.26 (2024 年 4 月 3 日)#
更改
JAX 现在仅支持 CUDA 12.1 或更高版本。已停止支持 CUDA 11.8。
JAX 现在支持 NumPy 2.0。
jax 0.4.25 (2024 年 2 月 26 日)#
新功能
添加了 CUDA Array Interface 导入支持(需要 jaxlib 0.4.24)。
JAX 数组现在支持 NumPy 风格的标量布尔索引,例如
x[True]或x[False]。添加了
jax.tree模块,提供了更方便的接口来引用jax.tree_util中的函数。jax.tree.transpose()(即jax.tree_util.tree_transpose())现在接受inner_treedef=None,在这种情况下,内部 treedef 将自动推断。
更改
Pallas 现在使用 XLA 而不是 Triton Python API 来编译 Triton 内核。您可以通过将
JAX_TRITON_COMPILE_VIA_XLA环境变量设置为"0"来恢复旧行为。在
jax.interpreters.xla中,一些已弃用的 API(已在 v0.4.24 中移除)已在 v0.4.25 中重新添加,包括backend_specific_translations、translations、register_translation、xla_destructure、TranslationRule、TranslationContext和XLAOp。这些 API 仍被视为已弃用,将在未来提供更好的替代品时再次移除。有关讨论,请参阅 #19816。
弃用和移除
jax.numpy.linalg.solve()现在对批次 1D 解(b.ndim > 1)显示弃用警告。将来,这些将被视为批次 2D 解。将非标量数组转换为 Python 标量现在会引发错误,无论数组大小如何。以前,对于大小为 1 的非标量数组会引发弃用警告。这遵循了 NumPy 中类似的弃用。
已移除先前已弃用的配置 API,遵循标准的 3 个月弃用周期(请参阅 API compatibility)。这些包括
对象
jax.config.config和的
define_*_state和DEFINE_*方法jax.config。
通过
import jax.config导入jax.config子模块已弃用。要配置 JAX,请使用import jax,然后通过jax.config引用配置对象。最低 jaxlib 版本现在是 0.4.20。
jaxlib 0.4.25 (2024 年 2 月 26 日)#
jax 0.4.24 (2024 年 2 月 6 日)#
更改
JAX 降低到 StableHLO 不再依赖于物理设备。如果您的原语在降低规则中包装了 custom_partitioning 或 JAX 回调(即传递给
rule参数的函数,mlir.register_lowering),则将您的原语添加到jax._src.dispatch.prim_requires_devices_during_lowering集中。这是必需的,因为 custom_partitioning 和 JAX 回调在降低期间需要物理设备来创建Shardings。这是一个临时状态,直到我们可以创建不带物理设备的Shardings。jax.numpy.argsort()和jax.numpy.sort()现在支持stable和descending参数。形状多态处理的几项更改(用于
jax.experimental.jax2tf和jax.experimental.export)更清晰的符号表达式 pretty-printing(#19227)
添加了指定维度变量符号约束的能力。这使得形状多态更具表现力,并提供了一种规避不等式推理限制的方法。请参阅 https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints。
通过添加符号约束(#19235),我们现在认为来自不同作用域的维度变量是不同的,即使它们具有相同的名称。来自不同作用域的符号表达式不能交互,例如,在算术运算中。作用域由
jax.experimental.jax2tf.convert()、jax.experimental.export.symbolic_shape()、jax.experimental.export.symbolic_args_specs()引入。作用域e的符号表达式可以用e.scope读取,并传递给上述函数以指示它们在给定作用域中构建符号表达式。请参阅 https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints。简化且更快的等式比较,我们认为两个符号维度相等,如果它们的差值经规范化后还原为 0(#19231;请注意,这可能会导致用户可见的行为更改)
改进了不确定不等式比较的错误消息(#19235)。
已弃用
core.non_negative_dimAPI(最近引入),并引入了core.max_dim和core.min_dim(#18953)来表示符号维度的max和min。您可以使用core.max_dim(d, 0)代替core.non_negative_dim(d)。shape_poly.is_poly_dim已弃用,改为使用export.is_symbolic_dim(#19282)。export.args_specs已弃用,改为使用export.symbolic_args_specs({jax-issue}#19283)。shape_poly.PolyShape和jax2tf.PolyShape已弃用,请使用字符串指定多态形状(#19284)。JAX 的默认原生序列化版本现为 9。这与
jax.experimental.jax2tf和jax.experimental.export相关。请参阅 版本号描述。
重构了
jax.experimental.export的 API。现在应使用from jax.experimental import export而不是from jax.experimental.export import export。旧的导入方式将在 3 个月的弃用期内继续有效。jax.numpy.unique()带有return_inverse = True时,返回的逆序索引会重塑为输入的维度,这与 NumPy 2.0 中的numpy.unique()的更改类似。jax.numpy.sign()现在对于非零复数输入返回x / abs(x)。这与 NumPy 2.0 版本中numpy.sign()的行为一致。jax.scipy.special.logsumexp()带有return_sign=True时,现在使用 NumPy 2.0 的复数符号约定x / abs(x)。这与 SciPy v1.13 中scipy.special.logsumexp()的行为一致。JAX 现在支持 DLPack 的 bool 类型进行导入和导出。以前 bool 值无法导入,并被导出为整数。
弃用和移除
许多先前已弃用的函数已被移除,遵循标准的 3 个月弃用周期(请参阅 API compatibility)。这包括
来自
jax.core:TracerArrayConversionError、TracerIntegerConversionError、UnexpectedTracerError、as_hashable_function、collections、dtypes、lu、map、namedtuple、partial、pp、ref、safe_zip、safe_map、source_info_util、total_ordering、traceback_util、tuple_delete、tuple_insert和zip。来自
jax.lax:dtypes、itertools、naryop、naryop_dtype_rule、standard_abstract_eval、standard_naryop、standard_primitive、standard_unop、unop和unop_dtype_rule。子模块
jax.linear_util及其所有内容。子模块
jax.prng及其所有内容。来自
jax.random:PRNGKeyArray、KeyArray、default_prng_impl、threefry_2x32、threefry2x32_key、threefry2x32_p、rbg_key和unsafe_rbg_key。来自
jax.tree_util:register_keypaths、AttributeKeyPathEntry和GetItemKeyPathEntry。来自
jax.interpreters.xla:backend_specific_translations,translations,register_translation,xla_destructure,TranslationRule,TranslationContext,axis_groups,ShapedArray,ConcreteArray,AxisEnv,backend_compile,以及XLAOp。来自
jax.numpy:NINF,NZERO,PZERO,row_stack,issubsctype,trapz,以及in1d。来自
jax.scipy.linalg:tril和triu。
先前已弃用的方法
PRNGKeyArray.unsafe_raw_array已被移除。请改用jax.random.key_data()。bool(empty_array)现在会引发错误而不是返回False。这先前会引发弃用警告,并且与 NumPy 中的类似更改一致。对 mhlo MLIR 方言的支持已弃用。JAX 不再使用 mhlo 方言,而是使用 stablehlo。引用“mhlo”的 API 将在未来被移除。请改用“stablehlo”方言。
jax.random:直接将批处理密钥传递给随机数生成函数,例如bits(),gamma()等,已被弃用并将发出FutureWarning。请使用jax.vmap进行显式批处理。jax.lax.tie_in()已弃用:自 JAX v0.2.0 起它一直是无操作。
jaxlib 0.4.24 (2024年2月6日)#
更改
JAX 现在支持 CUDA 12.3 和 CUDA 11.8。已停止支持 CUDA 12.2。
cost_analysis现在可以与交叉编译的Compiled对象一起工作(即在使用.lower().compile()并带有一个拓扑对象时,例如,在非 TPU 计算机上为 Cloud TPU 编译)。添加了 CUDA 数组接口 导入支持(需要 jax 0.4.25)。
jax 0.4.23 (2023年12月13日)#
jaxlib 0.4.23 (2023年12月13日)#
修复了一个导致 GPU 编译器在编译期间产生详细日志的错误。
jax 0.4.22 (2023年12月13日)#
弃用
JAX 数组的
device_buffer和device_buffers属性已弃用。显式缓冲区已被更灵活的数组分片接口取代,但仍可以通过这种方式恢复以前的输出。arr.device_buffer变为arr.addressable_data(0)arr.device_buffers变为[x.data for x in arr.addressable_shards]
jaxlib 0.4.22 (2023年12月13日)#
jax 0.4.21 (2023年12月4日)#
新功能
添加了
jax.nn.squareplus。
更改
最低 jaxlib 版本为 0.4.19。
发布的 wheel 现在使用 clang 而非 gcc 构建。
强制要求在调用
jax.distributed.initialize()之前设备后端尚未初始化。在 Cloud TPU 环境中自动配置
jax.distributed.initialize()的参数。
弃用
从
jax.scipy.linalg.solve()中移除了先前已弃用的sym_pos参数。请改用assume_a='pos'。将
None传递给jax.array()或jax.asarray()(直接或在列表或元组中)已被弃用,现在会引发FutureWarning。目前它会被转换为 NaN,将来会引发TypeError。为了与
numpy.where保持一致,将condition、x和y参数以关键字参数形式传递给jax.numpy.where已被弃用。将无法转换为 JAX 数组的参数传递给
jax.numpy.array_equal()和jax.numpy.array_equiv()已被弃用,现在会引发DeprecationWaning。目前这些函数返回 False,将来会引发异常。JAX 数组的
device()方法已弃用。根据上下文,它可能被以下任一方法替换:jax.Array.devices()返回数组使用的所有设备集合。jax.Array.sharding提供数组使用的分片配置。
jaxlib 0.4.21 (2023年12月4日)#
更改
为了增加对分布式 CPU 的支持,JAX 现在将 CPU 设备与 GPU 和 TPU 设备同等对待,即:
jax.devices()包括分布式作业中存在的所有设备,即使不是当前进程本地的设备。jax.local_devices()仍然只包括当前进程本地的设备,因此如果jax.devices()的更改对您造成了影响,您最可能想要使用jax.local_devices()。CPU 设备现在会在分布式作业中获得一个全局唯一的 ID 号;先前 CPU 设备会获得一个进程本地的 ID 号。
每个 CPU 设备的
process_index现在将与同一进程中的任何 GPU 或 TPU 设备匹配;先前 CPU 设备的process_index始终是 0。
在 NVIDIA GPU 上,JAX 现在优先为最大 1024x1024 的矩阵选择 Jacobi SVD 求解器。Jacobi 求解器似乎比非 Jacobi 版本更快。
错误修复
修复了当非有限值的数组传递给非对称特征值分解时发生的错误/挂起(#18226)。具有非有限值的数组现在会产生全为 NaN 的数组作为输出。
jax 0.4.20 (2023年11月2日)#
jaxlib 0.4.20 (2023年11月2日)#
错误修复
修复了 E4M3 和 E5M2 浮点类型之间的某些类型混淆问题。
jax 0.4.19 (2023年10月19日)#
新功能
添加了
jax.typing.DTypeLike,可用于注解可转换为 JAX 数据类型的对象。添加了
jax.numpy.fill_diagonal。
更改
JAX 现在要求 SciPy 1.9 或更高版本。
错误修复
在多控制器分布式 JAX 程序中,只有 0 号进程会写入持久编译缓存条目。这修复了当缓存位于 GCS 等网络文件系统时发生的写入争用问题。
版本检查 for cusolver 和 cufft 不再考虑补丁版本,以确定安装版本是否不低于 JAX 构建所用的版本。
jaxlib 0.4.19 (2023年10月19日)#
更改
如果已安装 pip 安装的 NVIDIA CUDA 库(nvidia-... 包),jaxlib 将始终优先使用它们,而不是 LD_LIBRARY_PATH 中命名的任何其他 CUDA 安装。如果这引起问题,并且意图是使用系统安装的 CUDA,则解决方案是移除 pip 安装的 CUDA 库包。
jax 0.4.18 (2023年10月6日)#
jaxlib 0.4.18 (2023年10月6日)#
更改
CUDA jaxlibs 现在要求用户安装兼容的 NCCL 版本。如果使用推荐的
cuda12_pip安装,NCCL 应该会自动安装。目前需要 NCCL 2.16 或更高版本。我们现在提供 Linux aarch64 wheel,包括带 NVIDIA GPU 支持和不带 NVIDIA GPU 支持的。
jax.Array.item()现在支持可选的索引参数。
弃用
jax.lax中的一些内部实用程序和意外导出的内容已弃用,将在未来版本中移除。jax.lax.dtypes:请改用jax.dtypes。jax.lax.itertools:请改用itertools。naryop,naryop_dtype_rule,standard_abstract_eval,standard_naryop,standard_primitive,standard_unop,unop,以及unop_dtype_rule是内部实用程序,现已弃用且无替代。
错误修复
修复了 Cloud TPU 回归问题,该问题导致编译因 smem 而 OOM。
jax 0.4.17 (2023年10月3日)#
新功能
添加了新的
jax.numpy.bitwise_count()函数,其 API 与最近添加到 NumPy 的类似函数一致。
弃用
移除了已弃用的模块
jax.abstract_arrays及其所有内容。jax.random中的命名密钥构造函数已弃用。请改用向jax.random.PRNGKey()或jax.random.key()传递impl参数。random.threefry2x32_key(seed)变为random.PRNGKey(seed, impl='threefry2x32')random.rbg_key(seed)变为random.PRNGKey(seed, impl='rbg')random.unsafe_rbg_key(seed)变为random.PRNGKey(seed, impl='unsafe_rbg')
更改
CUDA:JAX 现在验证找到的 CUDA 库不旧于 JAX 构建所用的 CUDA 库。如果找到较旧的库,JAX 会引发异常,这比神秘的故障和崩溃更好。
移除了“未找到 GPU/TPU”警告。现在,如果在 Linux 上找到 NVIDIA GPU 或 Google TPU 但未使用,并且未指定
--jax_platforms,则会发出警告。jax.scipy.stats.mode()现在如果是在大小为 0 的轴上取众数,则返回 0 计数,这与 SciPy 1.11 中scipy.stats.mode的行为一致。大多数
jax.numpy函数和属性现在具有完全定义的类型存根。以前,许多这些都被静态类型检查器(如mypy和pytype)视为Any。
jaxlib 0.4.17 (2023年10月3日)#
更改
此版本添加了 Python 3.12 wheel。
CUDA 12 wheel 现在需要 CUDA 12.2 或更高版本以及 cuDNN 8.9.4 或更高版本。
错误修复
修复了初始化 JAX CPU 后端时 ABSL 的日志垃圾问题。
jax 0.4.16 (2023年9月18日)#
更改
添加了
jax.numpy.ufunc,以及jax.numpy.frompyfunc(),可以将任何标量值函数转换为numpy.ufunc()风格的对象,具有outer(),reduce(),accumulate(),at(),和reduceat()等方法(#17054)。在非 IPython 环境下运行时:当引发异常时,JAX 现在会从堆栈跟踪中过滤掉其内部帧。 (之前出现的“未过滤堆栈跟踪”已消失。)这应该会产生更友好的堆栈跟踪。有关示例,请参阅 此处。此行为可以通过设置
JAX_TRACEBACK_FILTERING=remove_frames(对于两个独立的未过滤/过滤的堆栈跟踪,这是旧的行为)或JAX_TRACEBACK_FILTERING=off(对于一个未过滤的堆栈跟踪)来更改。jax2tf 默认序列化版本现在是 7,这引入了新的形状 安全断言。
传递给
jax.sharding.Mesh的设备应该是可哈希的。这特别适用于模拟设备或用户创建的设备。jax.devices()已经是可哈希的。
重大更改
jax2tf 现在默认使用原生序列化。有关详细信息和覆盖默认值的机制,请参阅 jax2tf 文档。
选项
--jax_coordination_service已被移除。它现在始终为True。jax.jaxpr_util已从公共 JAX 命名空间中移除。JAX_USE_PJRT_C_API_ON_TPU不再生效(即,它始终默认为 true)。向后兼容标志
--jax_host_callback_ad_transforms(于 2021 年 12 月引入)已移除。
弃用
一些
jax.numpyAPI 已根据 NumPy NEP-52 进行弃用。jax.numpy.NINF已弃用。请改用-jax.numpy.inf。jax.numpy.PZERO已弃用。请改用0.0。jax.numpy.NZERO已弃用。请改用-0.0。jax.numpy.issubsctype(x, t)已弃用。请改用jax.numpy.issubdtype(x.dtype, t)。jax.numpy.row_stack已弃用。请改用jax.numpy.vstack。jax.numpy.in1d已弃用。请改用jax.numpy.isin。jax.numpy.trapz已弃用。请改用jax.scipy.integrate.trapezoid。
jax.scipy.linalg.tril和jax.scipy.linalg.triu已弃用,与 SciPy 一致。请改用jax.numpy.tril和jax.numpy.triu。jax.lax.prod已移除,此前在 JAX v0.4.11 中已弃用。请改用内置的math.prod。来自
jax.interpreters.xla中与为自定义 JAX 原始类型定义 HLO 降低规则相关的许多导出内容已弃用。自定义原始类型应使用jax.interpreters.mlir中的 StableHLO 降低实用程序来定义。以下先前已弃用的函数已在三个月后移除:
jax.abstract_arrays.ShapedArray:请改用jax.core.ShapedArray。jax.abstract_arrays.raise_to_shaped:请改用jax.core.raise_to_shaped。jax.numpy.alltrue:请改用jax.numpy.all。jax.numpy.sometrue:请改用jax.numpy.any。jax.numpy.product:请改用jax.numpy.prod。jax.numpy.cumproduct:请改用jax.numpy.cumprod。
弃用/移除
内部子模块
jax.prng已弃用。其内容可在jax.extend.random中找到。内部子模块路径
jax.linear_util已弃用。请改用jax.extend.linear_util。(属于 jax.extend:一个用于扩展的模块)jax.random.PRNGKeyArray和jax.random.KeyArray已弃用。请使用jax.Array进行类型注解,并使用jax.dtypes.issubdtype(arr.dtype, jax.dtypes.prng_key)进行运行时检测 prng 密钥类型。方法
PRNGKeyArray.unsafe_raw_array已弃用。请改用jax.random.key_data()。jax.experimental.pjit.with_sharding_constraint已弃用。请改用jax.lax.with_sharding_constraint。内部实用程序
jax.core.is_opaque_dtype和jax.core.has_opaque_dtype已移除。不透明数据类型已重命名为扩展数据类型;请改用jnp.issubdtype(dtype, jax.dtypes.extended)(自 jax v0.4.14 起可用)。实用程序
jax.interpreters.xla.register_collective_primitive已移除。该实用程序在最近的 JAX 版本中没有实际作用,可以安全地移除对它的调用。内部子模块路径
jax.linear_util已弃用。请改用jax.extend.linear_util。(属于 jax.extend:一个用于扩展的模块)
jaxlib 0.4.16 (2023年9月18日)#
更改
通过实验性的 jax sparse API 进行的稀疏 CSR 矩阵乘法在 NVIDIA GPU 上不再使用确定性算法。此更改是为了提高与 CUDA 12.2.1 的兼容性。
错误修复
修复了 Windows 上由于 LLVM 致命错误(与乱序节和 IMAGE_REL_AMD64_ADDR32NB 重定位有关)导致的崩溃(https://github.com/openxla/xla/commit/cb732a921f0c4184995cbed82394931011d12bd4)。
jax 0.4.14 (2023年7月27日)#
更改
jax.jit接受donate_argnames作为参数。其语义类似于static_argnames。如果未提供 donate_argnums 或 donate_argnames,则不捐赠任何参数。如果未提供 donate_argnums 但提供了 donate_argnames,或反之,JAX 将使用inspect.signature(fun)来查找与 donate_argnames 对应的任何位置参数(或反之)。如果同时提供了 donate_argnums 和 donate_argnames,则不使用 inspect.signature,并且只捐赠在 donate_argnums 或 donate_argnames 中列出的实际参数。jax.random.gamma()已重构为更高效的算法,并具有更健壮的端点行为(#16779)。这意味着对于给定的key,返回值的序列在 JAX v0.4.13 和 v0.4.14 之间对于gamma和相关采样器(包括jax.random.ball()、jax.random.beta()、jax.random.chisquare()、jax.random.dirichlet()、jax.random.generalized_normal()、jax.random.loggamma()、jax.random.t())会有所改变。
删除
自 3 个月前弃用以来,
in_axis_resources和out_axis_resources已从 pjit 中删除。请使用in_shardings和out_shardings作为替代。这是一个安全且简单的名称替换。它不会改变任何当前的 pjit 语义,也不会破坏任何代码。您仍然可以将PartitionSpecs传递给 in_shardings 和 out_shardings。
弃用
根据 https://jax.net.cn/en/latest/deprecation.html,已停止支持 Python 3.8。
根据 https://jax.net.cn/en/latest/deprecation.html,JAX 现在要求 NumPy 1.22 或更高版本。
在
jax.numpy.ndarray.at中按位置传递可选参数不再支持,此前已在 JAX 版本 0.4.7 中弃用。例如,应使用x.at[i].get(indices_are_sorted=True)而不是x.at[i].get(True)。以下
jax.Array方法已移除,此前已在 JAX v0.4.5 中弃用:jax.Array.broadcast:请改用jax.lax.broadcast()。jax.Array.broadcast_in_dim:请改用jax.lax.broadcast_in_dim()。jax.Array.split:请改用jax.numpy.split()。
以下 API 在先前弃用后已移除:
jax.ad:请改用jax.interpreters.ad。jax.curry:请改用curry = lambda f: partial(partial, f)。jax.partial_eval:请改用jax.interpreters.partial_eval。jax.pxla:请改用jax.interpreters.pxla。jax.xla:请改用jax.interpreters.xla。jax.ShapedArray:请改用jax.core.ShapedArray。jax.interpreters.pxla.device_put:请改用jax.device_put()。jax.interpreters.pxla.make_sharded_device_array:请改用jax.make_array_from_single_device_arrays()。jax.interpreters.pxla.ShardedDeviceArray:请改用jax.Array。jax.numpy.DeviceArray:请改用jax.Array。jax.stages.Compiled.compiler_ir:请改用jax.stages.Compiled.as_text()。
重大更改
JAX 现在要求 ml_dtypes 版本 0.2.0 或更高版本。
为了修复一个边缘情况,带有五个参数的
jax.lax.cond()调用将始终解析为“公共操作数”cond行为(如文档所述),即使其他操作数也为可调用。请参阅 #16413。已移除已弃用的配置选项
jax_array和jax_jit_pjit_api_merge,它们已无效。这些选项在许多版本中已默认为 true。
新功能
JAX 现在支持配置标志 –jax_serialization_version 和 JAX_SERIALIZATION_VERSION 环境变量来控制序列化版本(#16746)。
在存在形状多态性的情况下,jax2tf 现在生成检查某些形状约束的代码,前提是序列化版本至少为 7。请参阅 https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism。
jaxlib 0.4.14 (2023年7月27日)#
弃用
根据 https://jax.net.cn/en/latest/deprecation.html,已停止支持 Python 3.8。
jax 0.4.13 (2023年6月22日)#
更改
jax.jit现在允许将None传递给in_shardings和out_shardings。语义如下:对于 in_shardings,JAX 将将其标记为复制,但此行为将来可能会更改。
对于 out_shardings,我们将依赖 XLA GSPMD 分区器来确定输出分片。
jax.experimental.pjit.pjit也允许将None传递给in_shardings和out_shardings。语义如下:如果未提供 mesh 上下文管理器,JAX 可以自由选择任何分片。
对于 in_shardings,JAX 将将其标记为复制,但此行为将来可能会更改。
对于 out_shardings,我们将依赖 XLA GSPMD 分区器来确定输出分片。
如果提供了 mesh 上下文管理器,None 将意味着该值将在 mesh 的所有设备上复制。
Executable.cost_analysis() 在 Cloud TPU 上可用
添加了一个警告,如果使用了非白名单的
jaxlib插件。添加了
jax.tree_util.tree_leaves_with_path。None 不能作为
jax.experimental.multihost_utils.host_local_array_to_global_array或jax.experimental.multihost_utils.global_array_to_host_local_array的有效输入。如果您想复制输入,请使用jax.sharding.PartitionSpec()。
错误修复
修复了 CUDA 12 版本中的 wheel 名称不正确的问题(#16362);正确的 wheel 名称是
cudnn89而不是cudnn88。
弃用
对于
jax.experimental.jax2tf.convert(),参数native_serialization_strict_checks已弃用,取而代之的是新的native_serializaation_disabled_checks(#16347)。
jaxlib 0.4.13 (2023年6月22日)#
更改
在
jaxlibPypi 发布中添加了 Windows CPU-only wheel。
错误修复
__cuda_array_interface__在之前的 jaxlib 版本中存在问题,现已修复(#16440)。在 NVIDIA GPU 上,并发 CUDA 内核跟踪现已默认启用。
jax 0.4.12 (2023年6月8日)#
更改
添加了
scipy.spatial.transform.Rotation和scipy.spatial.transform.Slerp
弃用
jax.abstract_arrays及其内容已弃用。请参阅jax.core中的相关功能。jax.numpy.alltrue:请改用jax.numpy.all。这遵循 NumPy 版本 1.25.0 中numpy.alltrue的弃用。jax.numpy.sometrue:请改用jax.numpy.any。这遵循 NumPy 版本 1.25.0 中numpy.sometrue的弃用。jax.numpy.product:请改用jax.numpy.prod。这遵循 NumPy 版本 1.25.0 中numpy.product的弃用。jax.numpy.cumproduct:请改用jax.numpy.cumprod。这遵循 NumPy 版本 1.25.0 中numpy.cumproduct的弃用。jax.sharding.OpShardingSharding已移除,因其已弃用 3 个月。
jaxlib 0.4.12 (2023年6月8日)#
更改
包含 Hopper (SM version 9.0+) GPU 的 PTX/SASS。早期版本的 jaxlib 应该可以在 Hopper 上运行,但第一次执行 JAX 操作时会出现长时间的 JIT 编译延迟。
错误修复
修复了 Python 3.11 下 JAX 生成的 Python 堆栈跟踪中的源行信息不正确的问题。
修复了打印 JAX 生成的 Python 堆栈跟踪中的帧局部变量时发生的崩溃问题(#16027)。
jax 0.4.11 (2023年5月31日)#
弃用
根据 API 兼容性 策略,以下 API 在 3 个月后移除,遵循了弃用期:
jax.experimental.PartitionSpec:请改用jax.sharding.PartitionSpec。jax.experimental.maps.Mesh:请改用jax.sharding.Meshjax.experimental.pjit.NamedSharding:请改用jax.sharding.NamedSharding。jax.experimental.pjit.PartitionSpec:请改用jax.sharding.PartitionSpec。jax.experimental.pjit.FROM_GDA。请改用将分片jax.Array对象作为输入,并移除pjit的可选in_shardings参数。jax.interpreters.pxla.PartitionSpec:请改用jax.sharding.PartitionSpec。jax.interpreters.pxla.Mesh:请改用jax.sharding.Meshjax.interpreters.xla.Buffer:请改用jax.Array。jax.interpreters.xla.Device:请改用jax.Device。jax.interpreters.xla.DeviceArray:请改用jax.Array。jax.interpreters.xla.device_put:请改用jax.device_put。jax.interpreters.xla.xla_call_p:请改用jax.experimental.pjit.pjit_p。已移除
with_sharding_constraint的axis_resources参数。请改用shardings。
jaxlib 0.4.11 (2023年5月31日)#
更改
向
Devices 添加了memory_stats()方法。如果支持,它将返回一个字典,其中包含字符串统计名称和整数值,例如"bytes_in_use",或者如果平台不支持内存统计信息,则返回 None。确切的统计信息返回值可能因平台而异。目前仅在 Cloud TPU 上实现。重新添加了对 CPU 设备上的 Python 缓冲区协议(
memoryview)的支持。
jax 0.4.10 (2023年5月11日)#
jaxlib 0.4.10 (2023年5月11日)#
更改
修复了导致先前版本无法在 Mac M1 上运行的
'apple-m1' is not a recognized processor for this target (ignoring processor)问题。
jax 0.4.9 (2023年5月9日)#
更改
实验性标志 experimental_cpp_jit, experimental_cpp_pjit 和 experimental_cpp_pmap 已移除。它们现在始终开启。
TPU 上的奇异值分解 (SVD) 的精度已提高(需要 jaxlib 0.4.9)。
弃用
jax.experimental.gda_serialization已弃用,并已重命名为jax.experimental.array_serialization。请更改您的导入以使用jax.experimental.array_serialization。pjit 的
in_axis_resources和out_axis_resources参数已弃用。请分别使用in_shardings和out_shardings。函数
jax.numpy.msort已移除。它自 JAX v0.4.1 起已弃用。请改用jnp.sort(a, axis=0)。因其仅与 sharded_jit 一起使用,
in_parts和out_parts参数已从jax.xla_computation中移除,而 sharded_jit 已不存在。因长期未使用,
instantiate_const_outputs参数已从jax.xla_computation中移除。
jaxlib 0.4.9 (2023年5月9日)#
jax 0.4.8 (2023年3月29日)#
重大更改
Cloud TPU 运行时的一个主要组件已升级。这在 Cloud TPU 上启用了以下新功能:
jax.debug.print(),jax.debug.callback(), 和jax.debug.breakpoint()现在可以在 Cloud TPU 上工作。自动 TPU 内存碎片整理
使用新运行时组件时,不再支持
jax.experimental.host_callback()。如果新的jax.debugAPI 不足以满足您的用例,请在 JAX issue tracker 上提交问题。旧运行时组件至少在未来三个月内可用,方法是设置环境变量
JAX_USE_PJRT_C_API_ON_TPU=false。如果您发现出于任何原因需要禁用新运行时,请在 JAX issue tracker 上告知我们。
更改
最低 jaxlib 版本已从 0.4.6 提高到 0.4.7。
弃用
已停止支持 CUDA 11.4。JAX GPU wheel 仅支持 CUDA 11.8 和 CUDA 12。旧版 CUDA 可能可用,前提是 jaxlib 是从源代码构建的。
global_arg_shapes参数的 pmap 仅适用于 sharded_jit,已从 pmap 中移除。请迁移到 pjit 并从 pmap 中移除 global_arg_shapes。
jax 0.4.7 (2023年3月27日)#
更改
根据 https://jax.net.cn/en/latest/jax_array_migration.html#jax-array-migration,
jax.config.jax_array已无法禁用。jax.config.jax_jit_pjit_api_merge已无法禁用。jax.experimental.jax2tf.convert()现在支持native_serialization参数,以使用 JAX 的原生降低到 StableHLO 来获取整个 JAX 函数的 StableHLO 模块,而不是将每个 JAX 原始类型降低到 TensorFlow op。这简化了内部结构,并增加了对所序列化内容与 JAX 原生语义匹配的信心。请参阅 文档。作为此更改的一部分,配置标志--jax2tf_default_experimental_native_lowering已重命名为--jax2tf_native_serialization。JAX 现在依赖
ml_dtypes,其中包含 bfloat16 等 NumPy 类型的定义。这些定义以前是 JAX 的内部内容,但已拆分到一个单独的包中,以便与其他项目共享。JAX 现在要求 NumPy 1.21 或更新版本,以及 SciPy 1.7 或更新版本。
弃用
类型
jax.numpy.DeviceArray已弃用。请改用jax.Array,它是其别名。类型
jax.interpreters.pxla.ShardedDeviceArray已弃用。请改用jax.Array。在
jax.numpy.ndarray.at中按位置传递附加参数已弃用。例如,应使用x.at[i].get(indices_are_sorted=True)而不是x.at[i].get(True)。jax.interpreters.xla.device_put已弃用。请使用jax.device_put。jax.interpreters.pxla.device_put已弃用。请使用jax.device_put。jax.experimental.pjit.FROM_GDA已弃用。请将分片 jax.Arrays 作为输入传递,并移除 pjit 的in_shardings参数,因为它不是必需的。
jaxlib 0.4.7 (2023年3月27日)#
更改
jaxlib 现在依赖
ml_dtypes,其中包含 bfloat16 等 NumPy 类型的定义。这些定义以前是 JAX 的内部内容,但已拆分到一个单独的包中,以便与其他项目共享。
jax 0.4.6 (2023年3月9日)#
更改
jax.tree_util现在包含一组 API,允许用户定义自定义 Pytree 节点的键。这包括:tree_flatten_with_path,它扁平化一个树,并返回每个叶子及其键路径。tree_map_with_path,它可以映射一个接受键路径作为参数的函数。register_pytree_with_keys,用于注册自定义 Pytree 节点的键路径和叶子应如何显示。keystr,用于美观地打印键路径。
jax2tf.call_tf()有一个新参数output_shape_dtype(默认为None),可用于声明结果的输出形状和类型。这使得jax2tf.call_tf()能够在形状多态性的情况下工作。(#14734)。
弃用
jax.tree_util中的旧键路径 API 已弃用,并将于 2023 年 3 月 10 日起三个月内移除。register_keypaths:请改用jax.tree_util.register_pytree_with_keys()。AttributeKeyPathEntry:请改用GetAttrKey。GetitemKeyPathEntry:请改用SequenceKey或DictKey。
jaxlib 0.4.6 (2023年3月9日)#
jax 0.4.5 (2023年3月2日)#
弃用
jax.sharding.OpShardingSharding已重命名为jax.sharding.GSPMDSharding。jax.sharding.OpShardingSharding将在 2023 年 2 月 17 日起三个月后移除。以下
jax.Array方法已弃用,并将于 2023 年 2 月 23 日起三个月后移除:jax.Array.broadcast:请改用jax.lax.broadcast()。jax.Array.broadcast_in_dim:请改用jax.lax.broadcast_in_dim()。jax.Array.split:请改用jax.numpy.split()。
jax 0.4.4 (2023年2月16日)#
更改
jit 和 pjit 的实现已合并。合并 jit 和 pjit 更改了 JAX 的内部结构,而不影响 JAX 的公共 API。以前,jit 是一个最终风格的原始类型。最终风格意味着 jaxpr 的创建被尽可能延迟,转换堆叠在一起。随着 jit-pjit 实现的合并,jit 成为一个初始风格的原始类型,这意味着我们尽早将转换到 jaxpr。有关更多信息,请参阅 autodidax 中的 此部分。切换到初始风格应该会简化 JAX 的内部结构,并使动态形状等功能的开发更加容易。您只能通过环境变量禁用它,即
os.environ['JAX_JIT_PJIT_API_MERGE'] = '0'。由于合并在导入时影响 JAX,因此必须在导入 jax 之前通过环境变量禁用它。已弃用
with_sharding_constraint的axis_resources参数。请改用shardings。如果您将axis_resources用作参数,则无需更改。如果您将其用作关键字参数,请改用shardings。axis_resources将在 2023 年 2 月 13 日起三个月后移除。添加了
jax.typing模块,其中包含用于 JAX 函数类型注解的工具。以下名称已弃用:
jax.xla.Device和jax.interpreters.xla.Device:请改用jax.Device。jax.experimental.maps.Mesh。请改用jax.sharding.Mesh。jax.experimental.pjit.NamedSharding:请改用jax.sharding.NamedSharding。jax.experimental.pjit.PartitionSpec:请改用jax.sharding.PartitionSpec。jax.interpreters.pxla.Mesh:请改用jax.sharding.Mesh。jax.interpreters.pxla.PartitionSpec:请改用jax.sharding.PartitionSpec。
重大更改
与
jax.numpy.sum()等约简函数中的initial参数现在被要求为标量,这与相应的 NumPy API 一致。先前将非标量initial值广播到输出的行为是一个非故意的实现细节(#14446)。
jaxlib 0.4.4 (2023年2月16日)#
重大更改
默认
jaxlib构建已移除对 NVIDIA Kepler 系列 GPU 的支持。如果需要 Kepler 支持,仍然可以从源代码构建jaxlib(通过build.py的--cuda_compute_capabilities=sm_35选项),但请注意,CUDA 12 已完全停止支持 Kepler GPU。
jax 0.4.3 (2023年2月8日)#
重大更改
已删除
jax.scipy.linalg.polar_unitary(),这是一个已弃用的 JAX 对 scipy API 的扩展。请改用jax.scipy.linalg.polar()。
更改
jaxlib 0.4.3 (2023年2月8日)#
jax.Array现在具有非阻塞is_ready()方法,该方法返回True如果数组已准备就绪(另请参见jax.block_until_ready())。
jax 0.4.2 (2023年1月24日)#
重大更改
删除了
jax.experimental.callback。涉及 jax2tf 形状多态性的维度操作已推广到更多场景,方法是将符号维度转换为 JAX 数组。涉及符号维度和
np.ndarray的操作现在可能在将结果用作形状值时引发错误(#14106)。jaxpr 对象现在对属性设置引发错误,以避免问题性突变(#14102)。
更改
jax2tf.call_tf()有一个新参数has_side_effects(默认为True),可用于声明实例是否可以被 JAX 优化(如死代码消除)移除或复制(#13980)。为 jax2tf 形状多态性添加了对 floordiv 和 mod 的更多支持。先前,某些除法操作在存在符号维度时会导致错误(#14108)。
jaxlib 0.4.2 (2023年1月24日)#
更改
将 JAX_USE_PJRT_C_API_ON_TPU=1 设置为启用新的 Cloud TPU 运行时,该运行时具有自动设备内存碎片整理功能。
jax 0.4.1 (2022年12月13日)#
更改
根据 JAX 的 Python 和 NumPy 版本支持策略,已停止对 Python 3.7 的支持。
我们引入了
jax.Array,它是一种统一的数组类型,包含了 JAX 中的DeviceArray、ShardedDeviceArray和GlobalDeviceArray类型。jax.Array类型有助于使并行化成为 JAX 的核心功能,简化和统一 JAX 内部结构,并允许我们统一jit和pjit。jax.Array在 JAX 0.4 中已默认启用,并对pjitAPI 造成了一些破坏性更改。jax.Array 迁移指南 可以帮助您将代码库迁移到jax.Array。您还可以参考 分布式数组和自动并行化 教程来理解新概念。PartitionSpec和Mesh现在已不再是实验性的。新的 API 端点是jax.sharding.PartitionSpec和jax.sharding.Mesh。jax.experimental.maps.Mesh和jax.experimental.PartitionSpec已弃用,将在 3 个月后移除。with_sharding_constraint的新公共端点是jax.lax.with_sharding_constraint。如果将 ABSL 标志与
jax.config一起使用,则在 JAX 配置选项首次从 ABSL 标志加载后,ABSL 标志的值不再读取或写入。此更改提高了读取jax.config选项的性能,因为jax.config选项在 JAX 中使用非常广泛。jax2tf 现在使用与嵌入式 JAX 计算相同平台的第一个 TF 设备进行 TF 降低。之前,它为 JAX 默认后端使用第 0 个设备。
许多
jax.numpy函数现在已将其参数标记为仅限位置参数,与 NumPy 一致。jnp.msort已弃用,遵循 numpy 1.24 中np.msort的弃用。它将在未来的版本中移除,遵循 API 兼容性 策略。它可以替换为jnp.sort(a, axis=0)。
jaxlib 0.4.1 (2022年12月13日)#
更改
根据 JAX 的 Python 和 NumPy 版本支持策略,已停止对 Python 3.7 的支持。
更改了
XLA_PYTHON_CLIENT_MEM_FRACTION=.XX的行为,以分配总 GPU 内存的 XX% 来代替先前使用当前可用 GPU 内存计算预分配的行为。有关详细信息,请参阅 GPU 内存分配。已移除已弃用的方法
.block_host_until_ready()。请改用.block_until_ready()。
jax 0.4.0 (2022年12月12日)#
此版本已被撤回。
jaxlib 0.4.0 (2022年12月12日)#
此版本已被撤回。
jax 0.3.25 (2022年11月15日)#
更改
jax.numpy.linalg.pinv()现在支持hermitian选项。jax.scipy.linalg.hessenberg()现在仅在 CPU 上支持。需要 jaxlib > 0.3.24。添加了新函数
jax.lax.linalg.hessenberg()、jax.lax.linalg.tridiagonal()和jax.lax.linalg.householder_product()。Householder 约简目前仅限 CPU,而三对角线约简仅支持 CPU 和 GPU。非方阵的
svd和jax.numpy.linalg.pinv的梯度现在计算更经济。
重大更改
删除了
jax_experimental_name_stack配置选项。将
jax.experimental.maps.Mesh构造函数的字符串axis_names参数转换为单例元组,而不是将字符串解包为字符轴名称序列。
jaxlib 0.3.25 (2022年11月15日)#
更改
增加了对 CPU 和 GPU 上三对角线约简的支持。
增加了对 CPU 上上三对角线约简的支持。
错误
修复了一个错误,该错误导致 JAX 捕获的堆栈跟踪中的帧在 Python 3.10+ 下与源行映射不正确。
jax 0.3.24 (2022年11月4日)#
更改
JAX 的导入速度应该更快了。我们现在惰性导入 scipy,这占了 JAX 导入时间的很大一部分。
设置环境变量
JAX_PERSISTENT_CACHE_MIN_COMPILE_TIME_SECS=$N可用于限制写入持久缓存的缓存条目数量。默认情况下,编译时间为 1 秒或更长的计算将被缓存。默认设备顺序由
pmap在 TPU 上使用(如果未指定顺序)现在匹配单进程作业的jax.devices()。先前这两个顺序不同,可能导致不必要的复制或内存不足错误。要求顺序一致简化了问题。
重大更改
jax.numpy.gradient()现在与jax.numpy中的大多数其他函数行为一致,禁止在需要数组的地方传入列表或元组(#12958)。jax.numpy.linalg和jax.numpy.fft中的函数现在统一要求输入为类数组(array-like):即不能用列表和元组代替数组。这是 #7737 的一部分。
弃用
jax.sharding.MeshPspecSharding已重命名为jax.sharding.NamedSharding。jax.sharding.MeshPspecSharding名称将在 3 个月后移除。
jaxlib 0.3.24 (2022 年 11 月 4 日)#
更改
Buffer donation 现在可以在 CPU 上使用。这可能会破坏已将 buffer 标记为 donation 但依赖于 donation 未实现的代码。
jax 0.3.23 (2022 年 10 月 12 日)#
更改
为新的 jaxlib 发布更新 Colab TPU 驱动程序版本。
jax 0.3.22 (2022 年 10 月 11 日)#
更改
在 TPU 初始化时,默认设置
JAX_PLATFORMS=tpu,cpu,因此如果 TPU 初始化失败,JAX 将引发错误而不是回退到 CPU。设置JAX_PLATFORMS=''可覆盖此行为并自动选择可用后端(原始默认值),或设置JAX_PLATFORMS=cpu以始终使用 CPU,无论 TPU 是否可用。
弃用
JAX v0.3.8 中弃用的几个测试工具已从
jax.test_util中移除。
jaxlib 0.3.22 (2022 年 10 月 11 日)#
jax 0.3.21 (2022 年 9 月 30 日)#
更改
持久化编译缓存现在会在出错时发出警告而不是引发异常(#12582),因此如果缓存出现问题,程序可以继续执行。设置
JAX_RAISE_PERSISTENT_CACHE_ERRORS=true可恢复此行为。
jax 0.3.20 (2022 年 9 月 28 日)#
jaxlib 0.3.20 (2022 年 9 月 28 日)#
错误修复
修复了在分布式作业中通过
jax_cuda_visible_devices限制可见 CUDA 设备的支持。此功能对于 GPU 上的 JAX/SLURM 集成是必需的(#12533)。
jax 0.3.19 (2022 年 9 月 27 日)#
修复了必需的 jaxlib 版本。
jax 0.3.18 (2022 年 9 月 26 日)#
更改
提前(Ahead-of-time)降低和编译功能(在 #7733 中跟踪)现在稳定且公开。请参阅 概述 和
jax.stages的 API 文档。引入了
jax.Array,用于isinstance检查和 JAX 中数组类型的类型注解。请注意,这包括了jax.numpy.ndarray对于 jax 内部对象的 `isinstance` 工作方式的一些细微变化,因为jax.numpy.ndarray现在是jax.Array的简单别名。
重大更改
jax._src不再导入到公共jax命名空间。这可能会破坏使用 JAX 内部功能的用户的代码。jax.soft_pmap已删除。请使用pjit或xmap代替。jax.soft_pmap未被文档化。如果它被文档化,会提供一个弃用期。
jax 0.3.17 (2022 年 8 月 31 日)#
错误
修复了
lax.pow指数为零时梯度的边缘情况问题(#12041)。
重大更改
jax.checkpoint()(也称为jax.remat())不再支持concrete选项,遵循先前版本的弃用;请参阅 JEP 11830。
更改
添加了
jax.pure_callback(),它允许从编译后的函数(例如,用jax.jit或jax.pmap装饰的函数)调用纯 Python 函数。
弃用
已弃用的
DeviceArray.tile()方法已被移除。请使用jax.numpy.tile()(#11944)。DeviceArray.to_py()已被弃用。请使用np.asarray(x)代替。
jax 0.3.16#
重大更改
根据 弃用策略,已停止支持 NumPy 1.19。请升级到 NumPy 1.20 或更高版本。
更改
添加了
jax.debug,其中包含用于运行时值调试的实用程序,如jax.debug.print()和jax.debug.breakpoint()。添加了关于 运行时值调试 的新文档。
弃用
已移除
jax.mask()和jax.shapecheck()API。请参阅 #11557。jax.experimental.loops已移除。请参阅 #10278 获取替代 API。jax.tree_util.tree_multimap()已移除。它自 JAX release 0.3.5 起已被弃用,而jax.tree_util.tree_map()是其直接替代品。移除了
jax.experimental.stax;它长期以来一直是jax.example_libraries.stax的一个已弃用别名。移除了
jax.experimental.optimizers;它长期以来一直是jax.example_libraries.optimizers的一个已弃用别名。jax.checkpoint()(也称为jax.remat())有一个新的默认启用的实现,这意味着旧的实现已被弃用;请参阅 JEP 11830。
jax 0.3.15 (2022 年 7 月 22 日)#
更改
JaxTestCase和JaxTestLoader已从jax.test_util中移除。这些类自 v0.3.1 起已被弃用(#11248)。添加了
jax.scipy.gaussian_kde(#11237)。JAX 数组与内置集合(
dict、list、set、tuple)之间的二进制操作现在在所有情况下都会引发TypeError。之前有些情况(特别是相等和不等)会返回与 NumPy 中类似操作不一致的布尔标量(#11234)。通过顶层 JAX 包导入访问的几个
jax.tree_util例程已被弃用,并将根据 API 兼容性 策略在未来的 JAX 版本中移除。jax.treedef_is_leaf()已弃用,建议使用jax.tree_util.treedef_is_leaf()。jax.tree_flatten()已弃用,建议使用jax.tree_util.tree_flatten()。jax.tree_leaves()已弃用,建议使用jax.tree_util.tree_leaves()。jax.tree_structure()已弃用,建议使用jax.tree_util.tree_structure()。jax.tree_transpose()已弃用,建议使用jax.tree_util.tree_transpose()。jax.tree_unflatten()已弃用,建议使用jax.tree_util.tree_unflatten()。
jax.scipy.linalg.solve()的sym_pos参数已弃用,建议使用assume_a='pos',这与scipy.linalg.solve()的弃用类似。
jaxlib 0.3.15 (2022 年 7 月 22 日)#
jax 0.3.14 (2022 年 6 月 27 日)#
重大更改
jax.experimental.compilation_cache.initialize_cache()不再支持max_cache_size_ bytes,也不会将其作为输入。JAX_PLATFORMS现在在平台初始化失败时会引发异常。
更改
修复了与 NumPy 1.23 的兼容性问题。
jax.numpy.linalg.slogdet()现在接受一个可选的method参数,允许在基于 LU 分解的实现和基于 QR 分解的实现之间进行选择。jax.numpy.linalg.qr()现在支持mode="raw"。pickle、copy.copy和copy.deepcopy现在对 JAX 数组的使用提供了更完整的支持(#10659)。具体来说:pickle和deepcopy之前在使用DeviceArray时会返回np.ndarray对象;现在会返回DeviceArray对象。对于deepcopy,复制的数组在与原始数组相同的设备上。对于pickle,反序列化的数组将位于默认设备上。在函数转换(即跟踪代码)中,
deepcopy和copy之前是空操作。现在它们使用与DeviceArray.copy()相同的机制。对跟踪数组调用
pickle现在会导致显式的ConcretizationTypeError。
在 TPU 上,奇异值分解(SVD)和对称/厄米特征值分解的实现应该会显著加快,特别是对于大于 1000x1000 的矩阵。两者现在都使用谱分治算法进行特征值分解(QDWH-eig)。
jax.numpy.ldexp()不再默默地将所有输入提升为 float64,而是将 int32 或更小大小的整数输入提升为 float32(#10921)。向
jax.profiler.start_trace()和jax.profiler.start_trace()添加了create_perfetto_link选项。使用时,profiler 将生成一个指向 Perfetto UI 的链接以查看跟踪。更改了
jax.profiler.start_server(...)()的语义,以将 keepalive 全局存储,而不是要求用户保留其引用。添加了
jax.random.ball()。添加了
jax.default_device()。添加了一个
python -m jax.collect_profile脚本,用于手动捕获程序跟踪,作为 TensorBoard UI 的替代方法。添加了一个
jax.named_scope上下文管理器,用于将 profiler 元数据添加到 Python 程序(类似于jax.named_call)。在 scatter-update 操作(即
jax.numpy.ndarray.at)中,不安全的隐式 dtype 转换已被弃用,现在会导致FutureWarning。在未来的版本中,这将成为一个错误。一个不安全隐式转换的例子是jnp.zeros(4, dtype=int).at[0].set(1.5),其中1.5之前会被静默截断为1。jax.experimental.compilation_cache.initialize_cache()现在支持 gcs bucket 路径作为输入。jax.numpy.roots()在strip_zeros=False且系数前导零时行为更好(#11215)。
jaxlib 0.3.14 (2022 年 6 月 27 日)#
-
x86-64 Mac wheel 现在要求 Mac OS 10.14 (Mojave) 或更新版本。Mac OS 10.14 发布于 2018 年,因此这应该不是一个很苛刻的要求。
更新了 NCCL 的捆绑版本到 2.12.12,修复了一些死锁问题。
Python flatbuffers 包不再是 jaxlib 的依赖项。
jax 0.3.13 (2022 年 5 月 16 日)#
jax 0.3.12 (2022 年 5 月 15 日)#
更改
修复了 #10717。
jax 0.3.11 (2022 年 5 月 15 日)#
更改
jax.lax.eigh()现在接受一个可选的sort_eigenvalues参数,允许用户选择在 TPU 上不进行特征值排序。
弃用
jax.lax.linalg中的函数现在将非数组参数标记为仅关键字参数。作为向后兼容步骤,通过位置传递关键字参数会产生警告,但在未来的 JAX 版本中,通过位置传递关键字参数将失败。然而,大多数用户应该更倾向于使用jax.numpy.linalg。JAX 扩展到 scipy API 的
jax.scipy.linalg.polar_unitary()已被弃用。请使用jax.scipy.linalg.polar()代替。
jax 0.3.10 (2022 年 5 月 3 日)#
jaxlib 0.3.10 (2022 年 5 月 3 日)#
更改
TF commit 修复了 MHLO 规范化器中的一个问题,该问题导致某些程序的常量折叠花费很长时间或崩溃。
jax 0.3.9 (2022 年 5 月 2 日)#
更改
为 GlobalDeviceArray 添加了对完全异步 checkpointing 的支持。
jax 0.3.8 (2022 年 4 月 29 日)#
更改
TPU 上的
jax.numpy.linalg.svd()使用 qdwh-svd 求解器。TPU 上的
jax.numpy.linalg.cond()现在接受复数输入。TPU 上的
jax.numpy.linalg.pinv()现在接受复数输入。TPU 上的
jax.numpy.linalg.matrix_rank()现在接受复数输入。jax.experimental.maps.mesh已删除。请使用jax.experimental.maps.Mesh。请参阅 https://jax.net.cn/en/latest/_autosummary/jax.experimental.maps.Mesh.html#jax.experimental.maps.Mesh 获取更多信息。jax.scipy.linalg.qr()现在当mode='r'时返回一个长度为 1 的元组而不是原始数组,以匹配scipy.linalg.qr的行为(#10452)。jax.numpy.take_along_axis()现在接受一个可选的mode参数,该参数指定了越界索引的行为。默认情况下,对于越界索引将返回无效值(例如,NaN)。在早期版本的 JAX 中,无效索引会被裁剪到范围内。通过传递mode="clip"可以恢复之前的行为。jax.numpy.take()现在默认使用mode="fill",它为越界索引返回无效值(例如,NaN)。Scatter 操作,如
x.at[...].set(...),现在具有"drop"语义。这对 scatter 操作本身没有影响,但这意味着当进行微分时,scatter 的梯度将为越界索引返回零余切。之前越界索引会被裁剪到范围内以进行梯度计算,这在数学上是不正确的。jax.numpy.take_along_axis()现在如果其索引不是整数类型,则会引发TypeError,这与numpy.take_along_axis()的行为一致。之前非整数索引会被静默转换为整数。jax.numpy.ravel_multi_index()现在如果其dims参数不是整数类型,则会引发TypeError,这与numpy.ravel_multi_index()的行为一致。之前非整数dims会被静默转换为整数。jax.numpy.split()现在如果其axis参数不是整数类型,则会引发TypeError,这与numpy.split()的行为一致。之前非整数axis会被静默转换为整数。jax.numpy.indices()现在如果其维度不是整数类型,则会引发TypeError,这与numpy.indices()的行为一致。之前非整数维度会被静默转换为整数。jax.numpy.diag()现在如果其k参数不是整数类型,则会引发TypeError,这与numpy.diag()的行为一致。之前非整数k会被静默转换为整数。
弃用
在
jax.test_util中可用的许多函数和对象现在已被弃用,并在导入时发出警告。这包括cases_from_list、check_close、check_eq、device_under_test、format_shape_dtype_string、rand_uniform、skip_on_devices、with_config、xla_bridge和_default_tolerance(#10389)。这些类,连同先前已弃用的JaxTestCase、JaxTestLoader和BufferDonationTestCase,将在未来的 JAX 版本中移除。这些实用程序中的大多数可以被标准 Python 和 NumPy 测试实用程序替换,例如unittest、absl.testing、numpy.testing等。 JAX 特定的功能,如设备检查,可以通过使用公共 API(如jax.devices())来替换。许多已弃用的实用程序仍将存在于jax._src.test_util中,但这些不是公共 API,因此在未来的版本中可能会被更改或移除,恕不另行通知。
jax 0.3.7 (2022 年 4 月 15 日)#
更改
修复了当传递给
jax.numpy.take_along_axis()的索引被广播时出现的性能问题(#10281)。jax.scipy.special.expit()和jax.scipy.special.logit()现在要求其参数是标量或 JAX 数组。它们现在也将整数参数提升为浮点数。已弃用
DeviceArray.tile()方法,因为 numpy 数组没有tile()方法。作为此方法的替代,请使用jax.numpy.tile()(#10266)。
jaxlib 0.3.7 (2022 年 4 月 15 日)#
更改
Linux wheels 现在按照
manylinux2014标准构建,而不是manylinux2010。
jax 0.3.6 (2022 年 4 月 12 日)#
jax 0.3.5 (2022 年 4 月 7 日)#
更改
添加了
jax.random.loggamma()并改进了jax.random.beta()和jax.random.dirichlet()对于小参数值的行为(#9906)。私有的
lax_numpy子模块不再暴露在jax.numpy命名空间中(#10029)。添加了数组创建例程
jax.numpy.frombuffer()、jax.numpy.fromfunction()和jax.numpy.fromstring()(#10049)。DeviceArray.copy()现在返回DeviceArray而不是np.ndarray(#10069)。jax.experimental.sharded_jit已被弃用,并将很快被移除。
弃用
jax.nn.normalize()已被弃用。请使用jax.nn.standardize()代替(#9899)。jax.tree_util.tree_multimap()已被弃用。请使用jax.tree_util.tree_map()代替(#5746)。jax.experimental.sharded_jit已被弃用。请使用pjit代替。
jaxlib 0.3.5 (2022 年 4 月 7 日)#
jax 0.3.4 (2022 年 3 月 18 日)#
jax 0.3.3 (2022 年 3 月 17 日)#
jax 0.3.2 (2022 年 3 月 16 日)#
更改
已弃用的函数
jax.ops.index_update、jax.ops.index_add(在 0.2.22 中弃用)已被移除。请使用 JAX 数组上的.at属性 代替,例如x.at[idx].set(y)。将
jax.experimental.ann.approx_*_k移至jax.lax。这些函数是jax.lax.top_k的优化替代品。jax.numpy.broadcast_arrays()和jax.numpy.broadcast_to()现在要求标量或类数组输入,并且在传入列表时会失败(#7737 的一部分)。标准的 jax[tpu] 安装现在可以与 Cloud TPU v4 VM 一起使用。
pjit现在可以在 CPU 上工作(除了之前的 TPU 和 GPU 支持)。
jaxlib 0.3.2 (2022 年 3 月 16 日)#
更改
XlaComputation.as_hlo_text()现在可以通过传递布尔标志print_large_constants=True来支持打印大型常量。
弃用
JAX 数组上的
.block_host_until_ready()方法已被弃用。请使用.block_until_ready()代替。
jax 0.3.1 (2022 年 2 月 18 日)#
更改
jax.test_util.JaxTestCase和jax.test_util.JaxTestLoader现在已被弃用。建议的替代方法是直接使用parametrized.TestCase。对于依赖自定义断言(如JaxTestCase.assertAllClose())的测试,建议的替代方法是使用标准的 NumPy 测试实用程序,如numpy.testing.assert_allclose(),它可以直接与 JAX 数组一起使用(#9620)。jax.test_util.JaxTestCase现在默认设置jax_numpy_rank_promotion='raise'(#9562)。要恢复之前的行为,请使用新的jax.test_util.with_config装饰器。@jtu.with_config(jax_numpy_rank_promotion='allow') class MyTestCase(jtu.JaxTestCase): ...
添加了
jax.scipy.linalg.schur()、jax.scipy.linalg.sqrtm()、jax.scipy.signal.csd()、jax.scipy.signal.stft()、jax.scipy.signal.welch()。
jax 0.3.0 (2022 年 2 月 10 日)#
更改
jax 版本已升级到 0.3.0。请参阅 设计文档 了解解释。
jaxlib 0.3.0 (2022 年 2 月 10 日)#
更改
构建 jaxlib 现在需要 Bazel 5.0.0。
jaxlib 版本已升级到 0.3.0。请参阅 设计文档 了解解释。
jax 0.2.28 (2022 年 2 月 1 日)#
-
默认情况下,
jax.jit(f).lower(...).compiler_ir()的输出为 MHLO 方言,如果未指定dialect=。现在,
jax.jit(f).lower(...).compiler_ir(dialect='mhlo')返回的是 MLIRir.Module对象,而不是其字符串表示。
jaxlib 0.1.76 (2022 年 1 月 27 日)#
新功能
包含 NVIDIA 计算能力 8.0 GPU(例如 A100)的预编译 SASS。移除了计算能力 6.1 的预编译 SASS,以免增加计算能力的数量:计算能力 6.1 的 GPU 可以使用 6.0 SASS。
使用 jaxlib 0.1.76 时,JAX 默认使用 MHLO MLIR 方言作为其主要的后端编译器 IR。
重大更改
根据 弃用策略,已停止支持 NumPy 1.18。请升级到支持的 NumPy 版本。
错误修复
修复了由不同路径构建的、看似相同的 pytreedef 对象不相等的 bug (#9066)。
JAX jit 缓存需要两个静态参数具有相同的类型才能实现缓存命中 (#9311)。
jax 0.2.27 (2022 年 1 月 18 日)#
重大更改
根据 弃用策略,已停止支持 NumPy 1.18。请升级到支持的 NumPy 版本。
host_callback 原语已简化,移除了对 hcb.id_tap 和 id_print 的特殊 autodiff 处理。从现在起,只 tap primal。旧行为(在有限时间内)可以通过设置
JAX_HOST_CALLBACK_AD_TRANSFORMS环境变量或--jax_host_callback_ad_transforms标志获得。此外,还添加了关于如何使用 JAX 自定义 AD API 实现旧行为的文档(#8678)。排序现在与 NumPy 在
0.0和NaN上的行为一致,无论比特表示如何。特别是,0.0和-0.0现在被视为等效,而之前-0.0被视为小于0.0。此外,所有NaN表示现在都被视为等效并排序到数组的末尾。之前负的NaN值排序到数组的前面,并且具有不同内部比特表示的NaN值不被视为等效,并根据这些比特模式进行排序(#9178)。jax.numpy.unique()现在像 NumPy 1.21 及更新版本中的np.unique一样处理NaN值:去重输出中最多会出现一个NaN值(#9184)。
错误修复
host_callback 现在支持 ad_checkpoint.checkpoint(#8907)。
新功能
添加
jax.block_until_ready({jax-issue}`#8941)添加了一个新的调试标志/环境变量
JAX_DUMP_IR_TO=/path。如果设置,JAX 将为每个计算生成的 MHLO/HLO IR 转储到给定路径下的一个文件。将
jax.ensure_compile_time_eval添加到公共 API(#7987)。jax2tf 现在支持 jax2tf_associative_scan_reductions 标志,以更改对关联约简(例如,jnp.cumsum)的降低,使其行为类似于 CPU 和 GPU 上的 JAX(使用关联扫描)。有关详细信息,请参阅 jax2tf README(#9189)。
jaxlib 0.1.75 (2021 年 12 月 8 日)#
新功能
支持 python 3.10。
jax 0.2.26 (2021 年 12 月 8 日)#
错误修复
对
jax.ops.segment_sum的越界索引将使用FILL_OR_DROP语义进行处理,如文档所述。这主要影响反向模式导数,其中对应于越界索引的梯度现在将返回 0。(#8634)。jax2tf 将强制转换后的代码在
jax.jit的代码片段下使用 XLA,例如,大多数 jax.numpy 函数(#7839)。
jaxlib 0.1.74 (2021 年 11 月 17 日)#
启用了 GPU 之间的点对点(peer-to-peer)复制。之前,GPU 复制会通过主机进行中转,这通常较慢。
添加了供 JAX 使用的实验性 MLIR Python 绑定。
jax 0.2.25 (2021 年 11 月 10 日)#
新功能
(实验性)
jax.distributed.initialize暴露了多主机 GPU 后端。jax.random.permutation支持新的independent关键字参数(#8430)。
重大更改
将
jax.experimental.stax移动到jax.example_libraries.stax。将
jax.experimental.optimizers移动到jax.example_libraries.optimizers。
新功能
添加了
jax.lax.linalg.qdwh。
jax 0.2.24 (2021 年 10 月 19 日)#
jaxlib 0.1.73 (2021 年 10 月 18 日)#
Jaxlib GPU
cuda11wheel 现在支持多个 cuDNN 版本。cuDNN 8.2 或更新版本。建议使用 cuDNN 8.2 wheel(如果您的 cuDNN 安装版本足够新),因为它支持其他功能。
cuDNN 8.0.5 或更新版本。
重大更改
GPU jaxlib 的安装命令如下:
pip install --upgrade pip # Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer. pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html # Installs the wheel compatible with Cuda 11 and cudnn 8.2 or newer. pip install jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_releases.html # Installs the wheel compatible with Cuda 11 and cudnn 8.0.5 or newer. pip install jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_releases.html
jax 0.2.22 (2021 年 10 月 12 日)#
重大更改
jax.pmap的静态参数现在必须是可哈希的。不可哈希的静态参数长期以来在
jax.jit中被禁止,但在jax.pmap中仍然允许;jax.pmap使用对象标识符比较不可哈希的静态参数。此行为是一个陷阱,因为使用对象标识符比较参数会导致每次对象标识符更改时都重新编译。相反,我们现在禁止不可哈希的参数:如果
jax.pmap的用户想要通过对象标识符比较静态参数,他们可以为其对象定义具有该行为的__hash__和__eq__方法,或者将他们的对象包装在一个具有对象标识符语义的运算的对象中。另一个选择是使用functools.partial将不可哈希的静态参数封装到函数对象中。jax.util.partial是一个意外的导出,现已移除。请使用 Python 标准库中的functools.partial代替。
弃用
函数
jax.ops.index_update、jax.ops.index_add等已被弃用,将在未来的 JAX 版本中移除。请使用 JAX 数组上的.at属性 代替,例如x.at[idx].set(y)。目前,这些函数会产生一个DeprecationWarning。
新功能
使用 jaxlib 0.1.72 或更新版本时,优化的 C++ 代码路径可以提高
pmap的调度时间。该功能可以通过--experimental_cpp_pmap标志(或JAX_CPP_PMAP环境变量)禁用。jax.numpy.unique现在支持可选的fill_value参数(#8121)。
jaxlib 0.1.72 (2021 年 10 月 12 日)#
重大更改
已停止支持 CUDA 10.2 和 CUDA 10.1。Jaxlib 现在支持 CUDA 11.1+。
错误修复
修复了 https://github.com/jax-ml/jax/issues/7461 中的 bug,该 bug 由于 XLA 编译器内部的缓冲区别名错误导致在所有平台上产生错误输出。
jax 0.2.21 (2021 年 9 月 23 日)#
重大更改
jax.api已移除。原本可作为jax.api.*访问的函数是jax.*中函数的别名;请使用jax.*中的函数。jax.partial和jax.lax.partial是意外的导出,现已移除。请使用 Python 标准库中的functools.partial代替。布尔标量索引现在会引发
TypeError;之前这会静默返回错误的结果(#7925)。许多
jax.numpy函数现在要求类数组输入,并且在传入列表时会报错(#7747 #7802 #7907)。请参阅 #7737 讨论此更改背后的原因。在转换(如
jax.jit)内部时,jax.numpy.array始终会将它生成的数组暂存到跟踪的计算中。之前,即使在jax.jit装饰器下,jax.numpy.array有时也会生成一个设备上的数组。此更改可能会破坏使用 JAX 数组进行形状或索引计算的代码,这些计算必须是静态已知的;解决方法是改用经典的 NumPy 数组执行此类计算。jnp.ndarray现在是 JAX 数组的真实基类。特别是,这意味着对于标准的 NumPy 数组x,isinstance(x, jnp.ndarray)现在将返回False(#7927)。
新功能
添加了
jax.numpy.insert()的实现(#7936)。
jax 0.2.20 (2021 年 9 月 2 日)#
重大更改
jaxlib 0.1.71 (2021 年 9 月 1 日)#
重大更改
已停止支持 CUDA 11.0 和 CUDA 10.1。Jaxlib 现在支持 CUDA 10.2 和 CUDA 11.1+。
jax 0.2.19 (2021 年 8 月 12 日)#
重大更改
根据 弃用策略,已停止支持 NumPy 1.17。请升级到支持的 NumPy 版本。
在 JAX 数组的许多运算符的实现周围添加了
jit装饰器。这加快了+等常用运算符的调度时间。此更改对大多数用户来说应该基本透明。但是,有一个已知的行为变化,即大型整数常量现在可能在直接传递给 JAX 运算符时(例如,
x + 2**40)引发错误。解决方法是将常量强制转换为显式类型(例如,np.float64(2**40))。
新功能
改进了 jax2tf 对形状多态的支持,用于需要使用尺寸大小的数组计算的操作,例如
jnp.mean。(#7317)。
错误修复
上一个版本中一些泄漏的跟踪错误(#7613)。
jaxlib 0.1.70 (2021 年 8 月 9 日)#
jax 0.2.18 (2021 年 7 月 21 日)#
重大更改
根据 弃用策略,已停止支持 Python 3.6。请升级到支持的 Python 版本。
最低 jaxlib 版本现在是 0.1.69。
已移除
jax.dlpack.from_dlpack()的backend参数。
新功能
添加了极分解(
jax.scipy.linalg.polar())。
错误修复
收紧了对 lax.argmin 和 lax.argmax 的检查,以确保它们不会与无效的
axis值或空的归约维度一起使用。(#7196)。
jaxlib 0.1.69 (2021 年 7 月 9 日)#
修复了 TFRT CPU 后端导致结果错误的 bug。
jax 0.2.17 (2021 年 7 月 9 日)#
错误修复
默认使用旧的“stream_executor”CPU 运行时,以使 jaxlib <= 0.1.68 正常工作,以解决 #7229,该问题导致 CPU 上由于并发问题而产生错误输出。
新功能
新的 SciPy 函数
jax.scipy.special.sph_harm()。反向模式自动微分函数(
jax.grad()、jax.value_and_grad()、jax.vjp()和jax.linear_transpose())支持一个参数,该参数指示在反向传播中应将哪些命名轴进行求和(如果它们在正向传播中被广播)。这使得这些 API 可以在映射(最初仅jax.experimental.maps.xmap())内部以非逐例(non-per-example)方式使用(#6950)。
jax 0.2.16 (2021 年 6 月 23 日)#
jax 0.2.15 (2021 年 6 月 23 日)#
新功能
#7042 开启了 TFRT CPU 后端,在 CPU 上实现了显著的调度性能改进。
jax2tf.convert()支持布尔值的不等式和 min/max(#6956)。新的 SciPy 函数
jax.scipy.special.lpmn_values()。
重大更改
根据 弃用策略,已停止支持 NumPy 1.16。请升级到支持的 NumPy 版本。
错误修复
修复了阻止 JAX 到 TF 再回到 JAX 的往返的 bug:
jax2tf.call_tf(jax2tf.convert)(#6947)。
jaxlib 0.1.68 (2021 年 6 月 23 日)#
错误修复
修复了 TFRT CPU 后端中将 TPU buffer 传输到 CPU 时导致 NaN 的 bug。
jax 0.2.14 (2021 年 6 月 10 日)#
新功能
jax2tf.convert()现在支持pjit和sharded_jit。一个新的配置选项 JAX_TRACEBACK_FILTERING 控制 JAX 如何过滤回溯。
在足够新的 IPython 版本中,默认启用了使用
__tracebackhide__的新回溯过滤模式。jax2tf.convert()支持形状多态,即使未知维度用于算术运算(例如,jnp.reshape(-1))(#6827)。jax2tf.convert()在 TF op 中生成具有位置信息的自定义属性。jax2tf 转换后 XLA 生成的代码与 JAX/XLA 具有相同的位置信息。新的 SciPy 函数
jax.scipy.special.lpmn()。
错误修复
jaxlib 0.1.67 (2021 年 5 月 17 日)#
jaxlib 0.1.66 (2021 年 5 月 11 日)#
新功能
现在所有 CUDA 11 版本 11.1 或更高版本都支持 CUDA 11.1 wheel。
Nvidia 现在承诺从 CUDA 11.1 开始,CUDA 次要版本之间兼容。这意味着 JAX 可以发布一个 CUDA 11.1 wheel,兼容 CUDA 11.2 和 11.3。
不再为 CUDA 11.2(或更高版本)单独发布 jaxlib;请为这些版本使用 CUDA 11.1 wheel (cuda111)。
Jaxlib 现在在 CUDA wheel 中捆绑了
libdevice.10.bc。不再需要将 JAX 指向 CUDA 安装来查找此文件。在
jit()实现中添加了对静态关键字参数的自动支持。添加了对预转换异常跟踪的支持。
初步支持从
jit()转换后的计算中修剪未使用的参数。修剪仍在进行中。改进了
PyTreeDef对象的字符串表示。添加了对 XLA 的可变 ReduceWindow 的支持。
错误修复
修复了当大量参数传递给计算时,远程云 TPU 支持中的一个 bug。
修复了一个 bug,该 bug 导致
jit()转换的函数不会触发 JAX 的垃圾回收。
jax 0.2.13 (2021 年 5 月 3 日)#
新功能
与 jaxlib 0.1.66 结合使用时,
jax.jit()现在支持静态关键字参数。添加了一个新的static_argnames选项来指定关键字参数为静态。jax.nonzero()有一个新的可选size参数,允许它在jit中使用(#6501)jax.numpy.unique()现在支持axis参数(#6532)。jax.experimental.host_callback.call()现在支持pjit.pjit(#6569)。添加了
jax.scipy.linalg.eigh_tridiagonal(),它计算三对角矩阵的特征值。目前仅支持特征值。异常中过滤和未过滤的堆栈跟踪的顺序已更改。从 JAX 转换的代码中抛出的异常附加的跟踪现在是过滤后的,一个
UnfilteredStackTrace异常作为过滤后异常的__cause__包含原始跟踪。过滤后的堆栈跟踪现在也支持 Python 3.6。如果异常是由反向模式自动微分转换的代码引起的,JAX 现在会尝试将一个
JaxStackTraceBeforeTransformation对象作为异常的__cause__附加,该对象包含创建前向传递中原始操作的堆栈跟踪。需要 jaxlib 0.1.66。
重大更改
以下函数名已更改。仍保留别名,因此不应破坏现有代码,但别名最终将被删除,请更改您的代码。
host_id–>process_index()host_count–>process_count()host_ids–>range(jax.process_count())
类似地,
local_devices()的参数已从host_id重命名为process_index。jax.jit()的参数(除了函数本身)现在被标记为仅关键字参数。此更改是为了防止在向jit添加参数时发生意外中断。
错误修复
jaxlib 0.1.65 (2021 年 4 月 7 日)#
jax 0.2.12 (2021 年 4 月 1 日)#
新功能
新的分析 API:
jax.profiler.start_trace()、jax.profiler.stop_trace()和jax.profiler.trace()jax.lax.reduce()现在可微分。
重大更改
最低 jaxlib 版本现为 0.1.64。
一些分析 API 的名称已更改。仍保留别名,因此不应破坏现有代码,但别名最终将被删除,请更改您的代码。
TraceContext–>TraceAnnotation()StepTraceContext–>StepTraceAnnotation()trace_function–>annotate_function()
Omnistaging 已无法禁用。有关更多信息,请参阅 omnistaging。
大于最大
int64值的 Python 整数现在将在所有情况下导致溢出,而不是在某些情况下被静默转换为uint64(#6047)。在 X64 模式之外,超出
int32可表示范围的 Python 整数现在将导致OverflowError,而不是静默截断其值。
错误修复
jax 0.2.11 (2021 年 3 月 23 日)#
新功能
错误修复
重大更改
最低 jaxlib 版本现为 0.1.62。
jaxlib 0.1.64 (2021 年 3 月 18 日)#
jaxlib 0.1.63 (2021 年 3 月 17 日)#
jax 0.2.10 (2021 年 3 月 5 日)#
新功能
错误修复
jax.numpy.take()正确处理负索引(#5768)
重大更改
JAX 的提升规则已调整,以使提升更加一致且对 JIT 不变。特别是,二元运算现在可以在适当的情况下产生弱类型值。此更改的主要用户可见效果是某些运算产生的输出精度与之前不同;例如,表达式
jnp.bfloat16(1) + 0.1 * jnp.arange(10)以前返回一个float64数组,现在返回一个bfloat16数组。JAX 的类型提升行为在 类型提升语义 中进行描述。jax.numpy.linspace()现在计算整数值的 floor(即向 -inf 舍入而不是向 0 舍入)。此更改是为了匹配 NumPy 1.20.0。jax.numpy.i0()不再接受复数。以前,该函数计算复数参数的绝对值。此更改是为了匹配 NumPy 1.20.0 的语义。几个
jax.numpy函数不再接受元组或列表作为数组参数:jax.numpy.pad(), :funcjax.numpy.ravel,jax.numpy.repeat(),jax.numpy.reshape()。总的来说,jax.numpy函数应该与标量或数组参数一起使用。
jaxlib 0.1.62 (2021 年 3 月 9 日)#
新功能
jaxlib wheel 现在默认构建为需要 x86-64 机器上的 AVX 指令。如果您想在不支持 AVX 的机器上使用 JAX,可以使用
build.py的--target_cpu_features标志从源构建 jaxlib。--target_cpu_features也取代了--enable_march_native。
jaxlib 0.1.61 (2021 年 2 月 12 日)#
jaxlib 0.1.60 (2021 年 2 月 3 日)#
错误修复
修复了将 CPU DeviceArrays 转换为 NumPy 数组时的内存泄漏。该内存泄漏存在于 jaxlib 版本 0.1.58 和 0.1.59 中。
bool、int8和uint8现在被认为是安全地转换为bfloat16NumPy 扩展类型的。
jax 0.2.9 (2021 年 1 月 26 日)#
新功能
使用支持 pytrees 扩展了
jax.experimental.loops模块。改进了错误检查和错误消息。添加了
jax.experimental.enable_x64()和jax.experimental.disable_x64()。这些是上下文管理器,允许在会话中临时启用/禁用 X64 模式。
重大更改
jax.ops.segment_sum()现在会丢弃超出范围的 segment IDs,而不是将它们包装到 segment ID 空间中。这是出于性能原因。
jaxlib 0.1.59 (2021 年 1 月 15 日)#
jax 0.2.8 (2021 年 1 月 12 日)#
新功能
为与高阶自定义导数函数一起使用添加了
jax.closure_convert()。(#5244)添加了
jax.experimental.host_callback.call()以调用自定义 Python 函数并在主机上将结果返回到设备计算。(#5243)
错误修复
重大更改
jax.numpy.pad现在接受关键字参数。已删除位置参数constant_values。此外,传递不支持的关键字参数会引发错误。jax.experimental.host_callback.id_tap()的更改(#5243)删除了对
jax.experimental.host_callback.id_tap()的kwargs的支持。(此支持已弃用几个月。)更改了
jax.experimental.host_callback.id_print()的元组打印方式,使用 ‘(‘ 而不是 ‘[‘。在 JVP 存在的情况下,更改了
jax.experimental.host_callback.id_print()以打印原始值和切线值对。以前,原始值和切线值有两个单独的打印操作。host_callback.outfeed_receiver已删除(它不是必需的,并且几个月前已被弃用)。
新功能
用于调试
inf的新标志,类似于NaN的标志(#5224)。
jax 0.2.7 (2020 年 12 月 4 日)#
新功能
添加
jax.device_put_replicated为
jax.experimental.sharded_jit添加了多主机支持添加了对
jax.numpy.linalg.eig计算的特征值进行微分的支持添加了对 Windows 平台构建的支持
为
jax.pmap添加了in_axes和out_axes的通用支持为
jax.numpy.linalg.slogdet添加了复数支持
错误修复
修复了在零点处
jax.numpy.sinc的高阶(高于二阶)导数修复了转置规则中与符号零相关的难以触及的一些 bug
重大更改
jax.experimental.optix已删除,取而代之的是独立的optaxPython 包。使用非元组序列对 JAX 数组进行索引现在会引发
TypeError。此类型的索引在 Numpy 中自 v1.16 起已弃用,在 JAX 中自 v0.2.4 起已弃用。请参阅 #4564。
jax 0.2.6 (2020 年 11 月 18 日)#
新功能
为 jax.experimental.jax2tf 转换器添加了对形状多态跟踪的支持。请参阅 README.md。
破坏性更改清理
对 jax.jit 和 xla_computation 的不可哈希静态参数引发错误。请参阅 cb48f42。
提高类型提升行为的一致性(#4744)
将复杂的 Python 标量添加到 JAX 浮点数会尊重 JAX 浮点数的精度。例如,
jnp.float32(1) + 1j现在返回complex64,而以前返回complex128。涉及 uint64、有符号整数和第三种类型的 3 个或更多项的类型提升结果现在与参数顺序无关。例如:
jnp.result_type(jnp.uint64, jnp.int64, jnp.float16)和jnp.result_type(jnp.float16, jnp.uint64, jnp.int64)都返回float16,而以前第一个返回float64,第二个返回float16。
(未记录的)
jax.lax_linalg线性代数模块的内容现在公开为jax.lax.linalg。jax.random.PRNGKey在 JIT 编译内外现在产生相同的结果(#4877)。这需要更改在某些特定情况下的结果使用
jax_enable_x64=False时,负数种子(作为 Python 整数传递)在 JIT 模式外现在返回不同的结果。例如,jax.random.PRNGKey(-1)以前返回[4294967295, 4294967295],现在返回[0, 4294967295]。这与 JIT 中的行为匹配。在 JIT 外,超出
int64可表示范围的种子现在会导致OverflowError而不是TypeError。这与 JIT 中的行为匹配。
要恢复 JIT 外
jax_enable_x64=False时以前为负整数返回的键,可以使用key = random.PRNGKey(-1).at[0].set(0xFFFFFFFF)
DeviceArray 现在在尝试访问已被删除的值时引发
RuntimeError而不是ValueError。
jaxlib 0.1.58 (2021 年 1 月 12 日左右)#
修复了一个 bug,该 bug 导致 JAX 有时会返回特定于平台的类型(例如
np.cint)而不是标准类型(例如np.int32)。(#4903)修复了在常量折叠某些 int16 操作时崩溃的问题。(#4971)
在
pytree.flatten()中添加了一个is_leafpredicate。
jaxlib 0.1.57 (2020 年 11 月 12 日)#
修复了 GPU wheel 中的 manylinux2010 合规性问题。
将 CPU FFT 实现从 Eigen 切换到 PocketFFT。
修复了一个 bug,该 bug 导致 bfloat16 值的哈希未正确初始化并可能改变(#4651)。
添加了在将数组传递给 DLPack 时保留所有权的支持(#4636)。
修复了大小大于 128 但不是 128 倍数的批处理三角解的 bug。
修复了在多个 GPU 上并发执行 FFT 时出现的 bug(#3518)。
修复了分析器中工具缺失的 bug(#4427)。
放弃了对 CUDA 10.0 的支持。
jax 0.2.5 (2020 年 10 月 27 日)#
改进
确保
check_jaxpr不执行 FLOPS。请参阅 #4650。扩展了 jax2tf 转换的 JAX 原语集。请参阅 primitives_with_limited_support.md。
jax 0.2.4 (2020 年 10 月 19 日)#
jaxlib 0.1.56 (2020 年 10 月 14 日)#
jax 0.2.3 (2020 年 10 月 14 日)#
之所以这么快发布另一个版本,是因为我们需要暂时回滚一个新的 jit 快速路径,同时我们正在调查性能下降问题。
jax 0.2.2 (2020 年 10 月 13 日)#
jax 0.2.1 (2020 年 10 月 6 日)#
改进
作为 omnistaging 的一个好处,即使
jax.experimental.host_callback.id_print()/jax.experimental.host_callback.id_tap()的结果未在计算中使用,host_callback 函数也会(按程序顺序)执行。
jax (0.2.0) (2020 年 9 月 23 日)#
改进
Omnistaging 默认开启。请参阅 #3370 和 omnistaging
jax (0.1.77) (2020 年 9 月 15 日)#
重大更改
jax.experimental.host_callback.id_tap()的新简化接口(#4101)
jaxlib 0.1.55 (2020 年 9 月 8 日)#
更新 XLA
修复了 DLPackManagedTensorToBuffer 中的 bug(#4196)
jax 0.1.76 (2020 年 9 月 8 日)#
jax 0.1.75 (2020 年 7 月 30 日)#
Bug 修复
使 jnp.abs() 可用于无符号输入(#3914)
改进
“Omnistaging”行为通过标志启用,默认禁用(#3370)
jax 0.1.74 (2020 年 7 月 29 日)#
新功能
BFGS(#3101)
TPU 支持半精度算术(#3878)
Bug 修复
防止一些意外的 dtype 警告(#3874)
修复了自定义导数中的多线程 bug(#3845,#3869)
改进
更快的 searchsorted 实现(#3873)
改进了 jax.numpy 排序算法的测试覆盖率(#3836)
jaxlib 0.1.52 (2020 年 7 月 22 日)#
更新 XLA。
jax 0.1.73 (2020 年 7 月 22 日)#
最低 jaxlib 版本现为 0.1.51。
新功能
jax.image.resize。(#3703)
hfft 和 ihfft(#3664)
jax.numpy.intersect1d(#3726)
jax.numpy.lexsort(#3812)
lax.scan和scan原语支持unroll参数,用于在降低到 XLA 时进行循环展开(#3738)。
Bug 修复
修复了 reduction 重复轴的错误(#3618)
修复了 lax.pad 对于尺寸为 0 的输入维度的形状规则。(#3608)
使 psum 转置处理零余切(#3653)
修复了在对大小为 0 的轴进行 reduce-prod 的 JVP 时出现的形状错误。(#3729)
支持对 jax.lax.all_to_all 进行微分(#3733)
解决了 jax.scipy.special.zeta 中的 nan 问题(#3777)
改进
对 jax2tf 进行大量改进
使用单个传递的可变规约重新实现 argmin/argmax。(#3611)
默认启用 XLA SPMD 分区。(#3151)
添加了对 0d 转置卷积的支持(#3643)
使 LU 梯度适用于低秩矩阵(#3610)
支持 multiple_results 和 jet 中的自定义 JVP(#3657)
将 reduce-window 填充泛化为支持 (lo, hi) 对。(#3728)
在 CPU 和 GPU 上实现复数卷积。(#3735)
使 jnp.take 在处理空数组的空切片时工作。(#3751)
放宽了 dot_general 的维度顺序规则。(#3778)
为 GPU 启用缓冲区捐赠。(#3800)
将基值膨胀和窗口膨胀添加到 reduce window op…(#3803)
jaxlib 0.1.51 (2020 年 7 月 2 日)#
更新 XLA。
为 host_callback 添加了新的运行时支持。
jax 0.1.72 (2020 年 6 月 28 日)#
错误修复
修复了上一版本中引入的 odeint bug,请参阅 #3587。
jax 0.1.71 (2020 年 6 月 25 日)#
最低 jaxlib 版本现为 0.1.48。
错误修复
允许
jax.experimental.ode.odeint动态函数在区分关于时间时关闭具有时间依赖性的值 #3562。
jaxlib 0.1.50 (2020 年 6 月 25 日)#
添加了对 CUDA 11.0 的支持。
放弃了对 CUDA 9.2 的支持(我们只维护对最后四个 CUDA 版本的支持)。
更新 XLA。
jaxlib 0.1.49 (2020 年 6 月 19 日)#
错误修复
修复了可能导致编译缓慢的构建问题(tensorflow/tensorflow)
jaxlib 0.1.48 (2020 年 6 月 12 日)#
新功能
添加了对快速堆栈跟踪收集的支持。
添加了对设备上堆分析的初步支持。
为
bfloat16类型实现了np.nextafter。CPU 和 GPU 上 FFT 的 Complex128 支持。
错误修复
提高了 GPU 上 float64
tanh的精度。GPU 上的 float64 散点图速度更快。
CPU 上的复数矩阵乘法速度应该快得多。
CPU 上的稳定排序现在应该是真正稳定的。
CPU 后端的并发 bug 修复。
jax 0.1.70 (2020 年 6 月 8 日)#
新功能
lax.switch引入了带有多个分支的索引条件,并对cond原语进行了泛化 #3318。
jax 0.1.69 (2020 年 6 月 3 日)#
jax 0.1.68 (2020 年 5 月 21 日)#
jax 0.1.67 (2020 年 5 月 12 日)#
新功能
使用
axis_index_groups对 pmapped 轴的子集进行规约的支持 #2382。实验性支持从编译代码打印和调用主机端 Python 函数。请参阅 id_print 和 id_tap(#3006)。
值得注意的更改
从
jax.numpy导出的名称的可访问性已得到加强。这可能会破坏使用了之前意外导出的名称的代码。
jaxlib 0.1.47 (2020 年 5 月 8 日)#
修复了 outfeed 崩溃。
jax 0.1.66 (2020 年 5 月 5 日)#
新功能
支持
pmap()上的in_axes=None#2896。
jaxlib 0.1.46 (2020 年 5 月 5 日)#
修复了在存在不同型号的多个 GPU 时,JAX 只编译适用于第一个 GPU 的程序的崩溃问题(#432)。
修复了在使用操作系统或虚拟机管理程序禁用了 AVX512 指令时,由使用 AVX512 指令引起的非法指令崩溃。(#2906)
jax 0.1.65 (2020 年 4 月 30 日)#
新功能
奇数矩阵行列式的微分 #2809。
错误修复
jaxlib 0.1.45 (2020 年 4 月 21 日)#
修复了段错误:#2755
将
is_stable选项从 Sort HLO 传递到 Python。
jax 0.1.64 (2020 年 4 月 21 日)#
新功能
为函数式索引更新添加了语法糖 #2684。
添加了
jax.numpy.unique()#2760。添加了
jax.numpy.rint()#2724。添加了
jax.numpy.rint()#2724。为
jax.experimental.jet()添加了更多原始规则。
错误修复
更好的错误
改进了
lax.while_loop()反向模式微分的错误消息 #2129。
jaxlib 0.1.44 (2020 年 4 月 16 日)#
修复了一个 bug,该 bug 导致如果存在多个不同型号的 GPU,JAX 只会编译适用于第一个 GPU 的程序。
batch_group_count卷积的 bug 修复。为更多 GPU 版本添加了预编译的 SASS,以避免启动 PTX 编译挂起。
jax 0.1.63 (2020 年 4 月 12 日)#
添加了
jax.custom_jvp和jax.custom_vjp(#2026),请参阅 教程 notebook。弃用了jax.custom_transforms并从文档中删除了它(尽管它仍然有效)。添加了
scipy.sparse.linalg.cg#2566。更改了 Tracers 的打印方式,以显示更有用的调试信息 #2591。
使
jax.numpy.isclose正确处理nan和inf#2501。为
jax.experimental.jet添加了几个新的规则 #2537。修复了
jax.experimental.stax.BatchNorm在未提供scale/center时的问题。修复了
jax.numpy.einsum中一些缺失的广播情况 #2512。使用并行前缀扫描实现了
jax.numpy.cumsum和jax.numpy.cumprod#2596,并使reduce_prod可任意阶微分 #2597。将
batch_group_count添加到conv_general_dilated#2635。为
test_util.check_grads添加了文档字符串 #2656。添加了
callback_transform#2665。实现了
rollaxis、convolve/correlate1d & 2d、copysign、trunc、roots以及quantile/percentile插值选项。
jaxlib 0.1.43 (2020 年 3 月 31 日)#
修复了 GPU 上 Resnet-50 的性能回归。
jax 0.1.62 (2020 年 3 月 21 日)#
JAX 已放弃对 Python 3.5 的支持。请升级到 Python 3.6 或更高版本。
删除了内部函数
lax._safe_mul,该函数实现了0. * nan == 0.的约定。此更改意味着某些程序在微分时会产生 nan,而以前会产生正确的值,但它确保了其他程序不会产生静默不正确的结果,而是会产生 nan。有关详细信息,请参阅 #2447 和 #1052。添加了一个
all_gather并行便利函数。核心代码中的更多类型注解。
jaxlib 0.1.42 (2020 年 3 月 19 日)#
jaxlib 0.1.41 由于 API 不兼容而破坏了云 TPU 支持。此版本已重新修复。
JAX 已放弃对 Python 3.5 的支持。请升级到 Python 3.6 或更高版本。
jax 0.1.61 (2020 年 3 月 17 日)#
修复了 Python 3.5 支持。这将是最后一个支持 Python 3.5 的 JAX 或 jaxlib 版本。
jax 0.1.60 (2020 年 3 月 17 日)#
新功能
jax.pmap()具有static_broadcast_argnums参数,该参数允许用户指定应视为编译时常量并应广播到所有设备的参数。它的工作方式类似于jax.jit()中的static_argnums。改进了当 Tracers 错误地保存在全局状态时出现的错误消息。
添加了
jax.nn.one_hot()实用函数。添加了
jax.experimental.jet以实现指数级更快的二阶及更高阶自动微分。为
jax.lax.broadcast_in_dim()的参数添加了更多正确性检查。
最低 jaxlib 版本现为 0.1.41。
jaxlib 0.1.40 (2020 年 3 月 4 日)#
在 Jaxlib 中添加了对 TensorFlow 分析器的实验性支持,该支持允许从 TensorBoard 跟踪 CPU 和 GPU 计算。
包括了多主机 GPU 计算(通过 NCCL 通信)的原型支持。
提高了 GPU 上 NCCL 集合通信的性能。
添加了 TopK、CustomCallWithoutLayout、CustomCallWithLayout、IGammaGradA 和 RandomGamma 实现。
支持在 XLA 编译时已知的设备分配。
jax 0.1.59 (2020 年 2 月 11 日)#
重大更改
最低 jaxlib 版本现为 0.1.38。
简化了
Jaxpr,移除了Jaxpr.freevars和Jaxpr.bound_subjaxprs。调用原语(xla_call、xla_pmap、sharded_call和remat_call)获得了一个新参数call_jaxpr,其中包含一个完全闭合的(无constvars)jaxpr。此外,在原语中添加了一个新字段call_primitive。
新功能
反向模式自动微分(例如
grad)对lax.cond的微分,使其现在可以双向微分(#2091)JAX 现在支持 DLPack,这允许与 PyTorch 等其他库以零拷贝方式共享 CPU 和 GPU 数组。
JAX GPU DeviceArrays 现在支持
__cuda_array_interface__,这是另一种与 CuPy 和 Numba 等其他库共享 GPU 数组的零拷贝协议。JAX CPU 设备缓冲区现在实现了 Python 缓冲区协议,这允许 JAX 和 NumPy 之间进行零拷贝缓冲区共享。
添加了 JAX_SKIP_SLOW_TESTS 环境变量以跳过已知较慢的测试。
jaxlib 0.1.39 (2020 年 2 月 11 日)#
更新 XLA。
jaxlib 0.1.38 (2020 年 1 月 29 日)#
不再支持 CUDA 9.0。
默认情况下现在构建 CUDA 10.2 wheel。
jax 0.1.58 (2020 年 1 月 28 日)#
重大更改
JAX 已放弃对 Python 2 的支持,因为 Python 2 已于 2020 年 1 月 1 日结束生命周期。请更新到 Python 3.5 或更高版本。
新功能
while 循环的前向模式自动微分(
jvp)(#1980)
新的 NumPy 和 SciPy 函数
GPU 上的批处理 Cholesky 分解现在使用更高效的批处理内核。
值得注意的 bug 修复#
通过升级到 Python 3,JAX 不再依赖于
fastcache,这应该有助于安装。