公共 API:jax#

子包#

配置#

config

check_tracer_leaks

jax_check_tracer_leaks 配置选项的上下文管理器。

checking_leaks

jax_check_tracer_leaks 配置选项的上下文管理器。

debug_nans

jax_debug_nans 配置选项的上下文管理器。

debug_infs

jax_debug_infs 配置选项的上下文管理器。

default_device

jax_default_device 配置选项的上下文管理器。

default_matmul_precision

jax_default_matmul_precision 配置选项的上下文管理器。

default_prng_impl

jax_default_prng_impl 配置选项的上下文管理器。

enable_checks

jax_enable_checks 配置选项的上下文管理器。

enable_custom_prng

jax_enable_custom_prng 配置选项的上下文管理器 (瞬态)。

enable_custom_vjp_by_custom_transpose

jax_enable_custom_vjp_by_custom_transpose 配置选项的上下文管理器 (瞬态)。

log_compiles

jax_log_compiles 配置选项的上下文管理器。

no_tracing

jax_no_tracing 配置选项的上下文管理器。

numpy_rank_promotion

jax_numpy_rank_promotion 配置选项的上下文管理器。

transfer_guard(new_val)

一个上下文管理器,用于控制所有传输的传输守卫级别。

即时编译 (jit)#

jit(fun, /, *[, in_shardings, ...])

fun 设置使用 XLA 进行即时编译。

disable_jit([disable])

在其动态上下文中禁用 jit() 行为的上下文管理器。

ensure_compile_time_eval()

上下文管理器,用于确保在跟踪/编译时进行评估(或错误)。

make_jaxpr([axis_env, return_shape, ...])

创建一个函数,该函数返回给定示例参数的 fun 的 jaxpr。

eval_shape(fun, *args, **kwargs)

计算 fun 的形状/dtype,而无需任何 FLOPs。

ShapeDtypeStruct(shape, dtype, *[, ...])

用于数组的形状、dtype 和其他静态属性的容器。

device_put(x[, device, src, donate, may_alias])

x 传输到 device

device_get(x)

x 传输到主机。

default_backend()

返回默认 XLA 后端的平台名称。

named_call(fun, *[, name])

在暂存 JAX 计算时,向函数添加用户指定的名称。

named_scope(name)

一个上下文管理器,用于将用户指定的名称添加到 JAX 名称堆栈。

block_until_ready(x)

尝试在 pytree 叶子上调用 block_until_ready 方法。

copy_to_host_async(x)

尝试在 pytree 叶子上调用 copy_to_host_async 方法。

make_mesh(axis_shapes, axis_names, *[, ...])

创建具有指定形状和轴名称的高效网格。

自动微分#

grad(fun[, argnums, has_aux, holomorphic, ...])

创建一个函数,该函数评估 fun 的梯度。

value_and_grad(fun[, argnums, has_aux, ...])

创建一个函数,该函数同时评估 funfun 的梯度。

jacobian(fun[, argnums, has_aux, ...])

jax.jacrev() 的别名。

jacfwd(fun[, argnums, has_aux, holomorphic])

使用前向模式 AD 逐列评估 fun 的 Jacobian 矩阵。

jacrev(fun[, argnums, has_aux, holomorphic, ...])

使用反向模式 AD 逐行评估 fun 的 Jacobian 矩阵。

hessian(fun[, argnums, has_aux, holomorphic])

fun 的 Hessian 矩阵,以密集数组形式表示。

jvp(fun, primals, tangents[, has_aux])

计算 fun 的(前向模式) Jacobian 矩阵-向量积。

linearize()

使用 jvp() 和部分评估生成 fun 的线性近似。

linear_transpose(fun, *primals[, reduce_axes])

转置一个承诺为线性的函数。

vjp() ))

计算 fun 的(反向模式)向量-Jacobian 矩阵积。

custom_gradient(fun)

用于定义自定义 VJP 规则(又名自定义梯度)的便捷函数。

closure_convert(fun, *example_args)

闭包转换实用程序,用于高阶自定义导数。

checkpoint(fun, *[, prevent_cse, policy, ...])

使 fun 在微分时重新计算内部线性化点。

自定义#

custom_jvp#

custom_jvp(fun[, nondiff_argnums])

为自定义 JVP 规则定义设置 JAX 可转换函数。

custom_jvp.defjvp(jvp[, symbolic_zeros])

为此实例表示的函数定义自定义 JVP 规则。

custom_jvp.defjvps(*jvps)

用于为每个参数单独定义 JVP 的便捷封装器。

custom_vjp#

custom_vjp(fun[, nondiff_argnums])

为自定义 VJP 规则定义设置 JAX 可转换函数。

custom_vjp.defvjp(fwd, bwd[, ...])

为此实例表示的函数定义自定义 VJP 规则。

custom_batching#

custom_batching.custom_vmap(fun)

自定义 JAX 可转换函数的 vmap 行为。

custom_batching.custom_vmap.def_vmap(vmap_rule)

为此 custom_vmap 函数定义 vmap 规则。

custom_batching.sequential_vmap(f)

使用循环的 custom_vmap 的特殊情况。

jax.Array (jax.Array)#

Array()

JAX 的 Array 基类

make_array_from_callback(shape, sharding, ...)

通过从 data_callback 获取的数据返回 jax.Array

make_array_from_single_device_arrays(shape, ...)

从一系列 jax.Array 返回 jax.Array,每个 jax.Array 都在单个设备上。

make_array_from_process_local_data(sharding, ...)

使用进程中可用的数据创建分布式张量。

Array 属性和方法#

Array.addressable_shards

可寻址分片列表。

Array.all([axis, out, keepdims, where])

测试给定轴上的所有数组元素是否都评估为 True。

Array.any([axis, out, keepdims, where])

测试给定轴上的任何数组元素是否评估为 True。

Array.argmax([axis, out, keepdims])

返回最大值的索引。

Array.argmin([axis, out, keepdims])

返回最小值的索引。

Array.argpartition(kth[, axis])

返回部分排序数组的索引。

Array.argsort([axis, kind, order, stable, ...])

返回排序数组的索引。

Array.astype(dtype[, copy, device])

复制数组并转换为指定的 dtype。

Array.at

用于索引更新功能的辅助属性。

Array.choose(choices[, out, mode])

构造一个从多个数组的元素中选择的数组。

Array.clip([min, max])

返回一个值限制在指定范围内的数组。

Array.compress(condition[, axis, out, size, ...])

返回沿给定轴的此数组的选定切片。

Array.committed

数组是否已提交。

Array.conj()

返回数组的复共轭。

Array.conjugate()

返回数组的复共轭。

Array.copy()

返回数组的副本。

Array.copy_to_host_async()

异步将 Array 复制到主机。

Array.cumprod([axis, dtype, out])

返回数组的累积积。

Array.cumsum([axis, dtype, out])

返回数组的累积和。

Array.device

Array API 兼容的设备属性。

Array.diagonal([offset, axis1, axis2])

从数组返回指定的对角线。

Array.dot(b, *[, precision, ...])

计算两个数组的点积。

Array.dtype

数组的数据类型 (numpy.dtype)。

Array.flat

请改用 flatten()

Array.flatten([order])

将数组展平为一维形状。

Array.global_shards

全局分片列表。

Array.imag

返回数组的虚部。

Array.is_fully_addressable

此 Array 是否完全可寻址?

Array.is_fully_replicated

此 Array 是否完全复制?

Array.item(*args)

将数组的元素复制到标准 Python 标量并返回。

Array.itemsize

一个数组元素以字节为单位的长度。

Array.max([axis, out, keepdims, initial, where])

返回给定轴上数组元素的最大值。

Array.mean([axis, dtype, out, keepdims, where])

返回给定轴上数组元素的平均值。

Array.min([axis, out, keepdims, initial, where])

返回给定轴上数组元素的最小值。

Array.nbytes

数组元素消耗的总字节数。

Array.ndim

数组中的维度数。

Array.nonzero(*[, fill_value, size])

返回数组的非零元素的索引。

Array.prod([axis, dtype, out, keepdims, ...])

返回给定轴上数组元素的乘积。

Array.ptp([axis, out, keepdims])

返回沿给定轴的峰峰值范围。

Array.ravel([order])

将数组展平为一维形状。

Array.real

返回数组的实部。

Array.repeat(repeats[, axis, ...])

从重复元素构造数组。

Array.reshape(*args[, order])

返回包含相同数据但具有新形状的数组。

Array.round([decimals, out])

将数组元素四舍五入到给定的小数位。

Array.searchsorted(v[, side, sorter, method])

在排序数组中执行二进制搜索。

Array.shape

数组的形状。

Array.sharding

数组的分片。

Array.size

数组中元素的总数。

Array.sort([axis, kind, order, stable, ...])

返回数组的排序副本。

Array.squeeze([axis])

从数组中删除一个或多个长度为 1 的轴。

Array.std([axis, dtype, out, ddof, ...])

计算沿给定轴的标准差。

Array.sum([axis, dtype, out, keepdims, ...])

给定轴上数组元素的总和。

Array.swapaxes(axis1, axis2)

交换数组的两个轴。

Array.take(indices[, axis, out, mode, ...])

从数组中获取元素。

Array.to_device(device, *[, stream])

返回指定设备上数组的副本

Array.trace([offset, axis1, axis2, dtype, out])

返回沿对角线元素的和。

Array.transpose(*args)

返回轴转置的数组副本。

Array.var([axis, dtype, out, ddof, ...])

计算沿给定轴的方差。

Array.view([dtype, type])

返回数组的按位副本,并将其视为新的 dtype。

Array.T

计算全轴数组转置。

Array.mT

计算(批量)矩阵转置。

向量化 (vmap)#

vmap(fun[, in_axes, out_axes, axis_name, ...])

向量化映射。

numpy.vectorize(pyfunc, *[, excluded, signature])

定义一个带有广播机制的向量化函数。

并行化 (pmap)#

pmap(fun[, axis_name, in_axes, out_axes, ...])

支持集合运算的并行映射。

devices([backend])

返回给定后端的所有设备列表。

local_devices([process_index, backend, host_id])

类似于 jax.devices(),但仅返回给定进程本地的设备。

process_index([backend])

返回此进程的整数进程索引。

device_count([backend])

返回设备总数。

local_device_count([backend])

返回此进程可寻址的设备数量。

process_count([backend])

返回与后端关联的 JAX 进程数。

process_indices([backend])

返回与后端关联的所有 JAX 进程索引的列表。

回调#

pure_callback(callback, result_shape_dtypes, ...)

调用纯 Python 回调。

experimental.io_callback(callback, ...[, ...])

调用不纯 Python 回调。

debug.callback(callback, *args[, ordered, ...])

调用可分阶段的 Python 回调。

debug.print(fmt, *args[, ordered, partitioned])

打印值并在分阶段输出的 JAX 函数中工作。

其他#

设备

可用设备的描述符。

print_environment_info([return_string])

返回包含本地环境和 JAX 安装信息的字符串。

live_arrays([platform])

返回 platform 后端中的所有活动数组。

clear_caches()

清除所有编译和暂存缓存。