jax.lax 模块#

jax.lax 是一个原始操作库,支持 jax.numpy 等库。转换规则(如 JVP 和批处理规则)通常定义为 jax.lax 原始操作的转换。

许多原始操作是等效 XLA 操作的薄封装,这些操作由 XLA 操作语义 文档描述。在少数情况下,JAX 与 XLA 不同,通常是为了确保操作集在 JVP 和转置规则的操作下是封闭的。

在可能的情况下,优先使用 jax.numpy 等库,而不是直接使用 jax.laxjax.numpy API 遵循 NumPy,因此比 jax.lax API 更稳定,更不容易改变。

运算符#

abs(x)

逐元素的绝对值: \(|x|\)

acos(x)

逐元素的反余弦: \(\mathrm{acos}(x)\)

acosh(x)

逐元素的反双曲余弦: \(\mathrm{acosh}(x)\)

add(x, y)

逐元素的加法: \(x + y\)

after_all(*operands)

合并一个或多个 XLA 令牌值。

approx_max_k(operand, k[, ...])

近似返回 operand 的最大 k 值及其索引。

approx_min_k(operand, k[, ...])

近似返回 operand 的最小 k 值及其索引。

argmax(operand, axis, index_dtype)

计算沿 axis 的最大元素的索引。

argmin(operand, axis, index_dtype)

计算沿 axis 的最小元素的索引。

asin(x)

逐元素的反正弦: \(\mathrm{asin}(x)\)

asinh(x)

逐元素的反双曲正弦: \(\mathrm{asinh}(x)\)

atan(x)

逐元素的反正切: \(\mathrm{atan}(x)\)

atan2(x, y)

逐元素的双参数反正切: \(\mathrm{atan}({x \over y})\)

atanh(x)

逐元素的反双曲正切: \(\mathrm{atanh}(x)\)

batch_matmul(lhs, rhs[, precision])

批量矩阵乘法。

bessel_i0e(x)

0阶指数缩放修正贝塞尔函数: \(\mathrm{i0e}(x) = e^{-|x|} \mathrm{i0}(x)\)

bessel_i1e(x)

1阶指数缩放修正贝塞尔函数: \(\mathrm{i1e}(x) = e^{-|x|} \mathrm{i1}(x)\)

betainc(a, b, x)

逐元素的正则不完全贝塔积分。

bitcast_convert_type(operand, new_dtype)

逐位转换。

bitwise_and(x, y)

逐位与: \(x \wedge y\)

bitwise_not(x)

逐位非: \(\neg x\)

bitwise_or(x, y)

逐位或: \(x \vee y\)

bitwise_xor(x, y)

逐位异或: \(x \oplus y\)

population_count(x)

逐元素的位计数,计算每个元素中设置的位数。

broadcast(operand, sizes, *[, out_sharding])

广播一个数组,添加新的前导维度

broadcast_in_dim(operand, shape, ...[, ...])

封装了 XLA 的 BroadcastInDim 运算符。

broadcast_shapes(*shapes)

返回 NumPy 广播 shapes 后的形状。

broadcast_to_rank(x, rank)

添加前导维度 1 以使 x 的秩为 rank

broadcasted_iota(dtype, shape, dimension, *)

iota 的便捷封装。

cbrt(x[, accuracy])

逐元素的立方根: \(\sqrt[3]{x}\)

ceil(x)

逐元素的上取整: \(\left\lceil x \right\rceil\)

clamp(min, x, max)

逐元素的钳位。

clz(x)

逐元素的零前缀计数。

collapse(operand, start_dimension[, ...])

将数组的维度折叠成一个维度。

complex(x, y)

逐元素创建复数: \(x + jy\)

composite(decomposition, name[, version])

具有分解函数定义的语义的复合函数。

concatenate(operands, dimension)

沿 dimension 拼接一系列数组。

conj(x)

逐元素的复共轭函数: \(\overline{x}\)

conv(lhs, rhs, window_strides, padding[, ...])

围绕 conv_general_dilated 的便捷包装器。

convert_element_type(operand, new_dtype)

逐元素的类型转换。

conv_dimension_numbers(lhs_shape, rhs_shape, ...)

将卷积 dimension_numbers 转换为 ConvDimensionNumbers

conv_general_dilated(lhs, rhs, ...[, ...])

通用 n 维卷积运算符,带可选的膨胀。

conv_general_dilated_local(lhs, rhs, ...[, ...])

通用 n 维非共享卷积运算符,带可选的膨胀。

conv_general_dilated_patches(lhs, ...[, ...])

根据 conv_general_dilated 的感受野提取补丁。

conv_transpose(lhs, rhs, strides, padding[, ...])

用于计算 N 维卷积“转置”的便捷封装。

conv_with_general_padding(lhs, rhs, ...[, ...])

围绕 conv_general_dilated 的便捷包装器。

cos(x[, accuracy])

逐元素的余弦: \(\mathrm{cos}(x)\)

cosh(x)

逐元素的双曲余弦: \(\mathrm{cosh}(x)\)

cumlogsumexp(operand[, axis, reverse])

计算沿 axis 的累积 logsumexp。

cummax(operand[, axis, reverse])

计算沿 axis 的累积最大值。

cummin(operand[, axis, reverse])

计算沿 axis 的累积最小值。

cumprod(operand[, axis, reverse])

计算沿 axis 的累积积。

cumsum(operand[, axis, reverse])

计算沿 axis 的累积和。

digamma(x)

逐元素的双伽马函数: \(\psi(x)\)

div(x, y)

逐元素的除法: \(x \over y\)

dot(lhs, rhs[, precision, ...])

向量/向量、矩阵/向量和矩阵/矩阵乘法。

dot_general(lhs, rhs, dimension_numbers[, ...])

通用点积/收缩运算符。

dynamic_index_in_dim(operand, index[, axis, ...])

围绕 dynamic_slice 执行整数索引的便捷封装。

dynamic_slice(operand, start_indices, ...[, ...])

封装了 XLA 的 DynamicSlice 运算符。

dynamic_slice_in_dim(operand, start_index, ...)

应用于单个维度的 lax.dynamic_slice() 的便捷封装。

dynamic_update_index_in_dim(operand, update, ...)

围绕 dynamic_update_slice() 的便捷封装,用于在单个 axis 中更新大小为 1 的切片。

dynamic_update_slice(operand, update, ...[, ...])

封装了 XLA 的 DynamicUpdateSlice 运算符。

dynamic_update_slice_in_dim(operand, update, ...)

围绕 dynamic_update_slice() 的便捷封装,用于在单个 axis 中更新切片。

eq(x, y)

逐元素的相等: \(x = y\)

erf(x)

逐元素的误差函数: \(\mathrm{erf}(x)\)

erfc(x)

逐元素的余误差函数: \(\mathrm{erfc}(x) = 1 - \mathrm{erf}(x)\)

erf_inv(x)

逐元素的反误差函数: \(\mathrm{erf}^{-1}(x)\)

exp(x[, accuracy])

逐元素的指数: \(e^x\)

exp2(x[, accuracy])

逐元素的以 2 为底的指数: \(2^x\)

expand_dims(array, dimensions)

将任意数量的大小为 1 的维度插入到数组中。

expm1(x[, accuracy])

逐元素的 \(e^{x} - 1\)

fft(x, fft_type, fft_lengths)

floor(x)

逐元素的向下取整: \(\left\lfloor x \right\rfloor\)

full(shape, fill_value[, dtype, sharding])

返回一个填充了 fill_valueshape 数组。

full_like(x, fill_value[, dtype, shape, ...])

基于示例数组 x 创建一个类似于 np.full 的完整数组。

gather(operand, start_indices, ...[, ...])

收集运算符。

ge(x, y)

逐元素的大于或等于: \(x \geq y\)

gt(x, y)

逐元素的大于: \(x > y\)

igamma(a, x)

逐元素的正则不完全伽马函数。

igammac(a, x)

逐元素的余正则不完全伽马函数。

imag(x)

逐元素提取虚部: \(\mathrm{Im}(x)\)

index_in_dim(operand, index[, axis, keepdims])

围绕 lax.slice() 执行整数索引的便捷封装。

index_take(src, idxs, axes)

integer_pow(x, y)

逐元素的幂: \(x^y\),其中 \(y\) 是静态整数。

iota(dtype, size)

封装了 XLA 的 Iota 运算符。

is_finite(x)

逐元素的 \(\mathrm{isfinite}\)

le(x, y)

逐元素的小于或等于: \(x \leq y\)

lgamma(x)

逐元素的对数伽马: \(\mathrm{log}(\Gamma(x))\)

log(x[, accuracy])

逐元素的自然对数: \(\mathrm{log}(x)\)

log1p(x[, accuracy])

逐元素的 \(\mathrm{log}(1 + x)\)

logistic(x[, accuracy])

逐元素的逻辑(sigmoid)函数: \(\frac{1}{1 + e^{-x}}\)

lt(x, y)

逐元素的小于: \(x < y\)

max(x, y)

逐元素的最大值: \(\mathrm{max}(x, y)\)

min(x, y)

逐元素的最小值: \(\mathrm{min}(x, y)\)

mul(x, y)

逐元素的乘法: \(x \times y\)

ne(x, y)

逐元素的不等于: \(x \neq y\)

neg(x)

逐元素的取反: \(-x\)

nextafter(x1, x2)

返回 x1 之后朝向 x2 方向的下一个可表示值。

optimization_barrier(operand, /)

阻止编译器在障碍物之间移动操作。

pad(operand, padding_value, padding_config)

对数组应用低、高和/或内部填充。

platform_dependent(*args[, default])

分阶段地输出平台相关的代码。

polygamma(m, x)

逐多元伽马函数: \(\psi^{(m)}(x)\)

population_count(x)

逐元素的位计数,计算每个元素中设置的位数。

pow(x, y)

逐元素的幂: \(x^y\)

random_gamma_grad(*args)

real(x)

逐元素提取实部: \(\mathrm{Re}(x)\)

reciprocal(x)

逐元素的倒数: \(1 \over x\)

reduce(operands, init_values, computation, ...)

封装了 XLA 的 Reduce 运算符。

reduce_and(operand, axes)

计算一个或多个数组轴上元素的按位与。

reduce_max(operand, axes)

计算一个或多个数组轴上元素的最大值。

reduce_min(operand, axes)

计算一个或多个数组轴上元素的最小值。

reduce_or(operand, axes)

计算一个或多个数组轴上元素的按位或。

reduce_precision(operand, exponent_bits, ...)

封装了 XLA 的 ReducePrecision 运算符。

reduce_prod(operand, axes)

计算一个或多个数组轴上元素的积。

reduce_sum(operand, axes)

计算一个或多个数组轴上元素的和。

reduce_window(operand, init_value, ...[, ...])

填充窗口上的归约。

reduce_xor(operand, axes)

计算一个或多个数组轴上元素的按位异或。

rem(x, y)

逐元素的余数: \(x \bmod y\)

reshape(operand, new_sizes[, dimensions, ...])

封装了 XLA 的 Reshape 运算符。

rev(operand, dimensions)

封装了 XLA 的 Rev 运算符。

rng_bit_generator(key, shape[, dtype, ...])

无状态 PRNG 位生成器。

rng_uniform(a, b, shape)

有状态 PRNG 生成器。

round(x[, rounding_method])

逐元素的四舍五入。

rsqrt(x[, accuracy])

逐元素的平方根倒数: \(1 \over \sqrt{x}\)

scatter(operand, scatter_indices, updates, ...)

散布更新运算符。

scatter_add(operand, scatter_indices, ...[, ...])

散布加法运算符。

scatter_apply(operand, scatter_indices, ...)

散布应用运算符。

scatter_max(operand, scatter_indices, ...[, ...])

散布最大值运算符。

scatter_min(operand, scatter_indices, ...[, ...])

散布最小值运算符。

scatter_mul(operand, scatter_indices, ...[, ...])

散布乘法运算符。

shift_left(x, y)

逐元素的左移: \(x \ll y\)

shift_right_arithmetic(x, y)

逐元素的算术右移: \(x \gg y\)

shift_right_logical(x, y)

逐元素的逻辑右移: \(x \gg y\)

sign(x)

逐元素的符号。

sin(x[, accuracy])

逐元素的正弦: \(\mathrm{sin}(x)\)

sinh(x)

逐元素的双曲正弦: \(\mathrm{sinh}(x)\)

slice(operand, start_indices, limit_indices)

封装了 XLA 的 Slice 运算符。

slice_in_dim(operand, start_index, limit_index)

应用于单个维度的 lax.slice() 的便捷封装。

sort()

封装了 XLA 的 Sort 运算符。

sort_key_val(keys, values[, dimension, ...])

沿 dimensionkeys 进行排序,并对 values 应用相同的置换。

split(operand, sizes[, axis])

沿 axis 拆分数组。

sqrt(x[, accuracy])

逐元素的平方根: \(\sqrt{x}\)

square(x)

逐元素的平方: \(x^2\)

squeeze(array, dimensions)

从数组中挤压掉任意数量的大小为 1 的维度。

sub(x, y)

逐元素的减法: \(x - y\)

tan(x[, accuracy])

逐元素的正切: \(\mathrm{tan}(x)\)

tanh(x[, accuracy])

逐元素的双曲正切: \(\mathrm{tanh}(x)\)

top_k(operand, k)

返回 operand 沿最后一个轴的顶部 k 个值及其索引。

transpose(operand, permutation)

封装了 XLA 的 Transpose 运算符。

zeros_like_array(x)

zeta(x, q)

逐元素的赫尔维茨 zeta 函数: \(\zeta(x, q)\)

控制流运算符#

associative_scan(fn, elems[, reverse, axis])

并行地使用结合二元操作执行扫描。

cond(pred, true_fun, false_fun, *operands[, ...])

有条件地应用 true_funfalse_fun

fori_loop(lower, upper, body_fun, init_val, *)

通过归约为 jax.lax.while_loop(),从 lower 循环到 upper

map(f, xs, *[, batch_size])

将函数映射到前导数组轴上。

scan(f, init[, xs, length, reverse, unroll, ...])

在携带状态的同时,在前导数组轴上扫描函数。

select(pred, on_true, on_false)

根据布尔谓词在两个分支之间进行选择。

select_n(which, *cases)

从多个案例中选择数组值。

switch(index, branches, *operands[, operand])

应用由 index 给出的 branches 中的一个。

while_loop(cond_fun, body_fun, init_val)

cond_fun 为 True 时,循环重复调用 body_fun

自定义梯度运算符#

stop_gradient(x)

停止梯度计算。

custom_linear_solve(matvec, b, solve[, ...])

执行带有隐式定义梯度的无矩阵线性求解。

custom_root(f, initial_guess, solve, ...[, ...])

可微分地求解函数的根。

并行运算符#

all_gather(x, axis_name, *[, ...])

在所有副本中收集 x 的值。

all_to_all(x, axis_name, split_axis, ...[, ...])

具体化映射轴并映射不同的轴。

psum(x, axis_name, *[, axis_index_groups])

在 pmap 轴 axis_name 上计算 x 的全归约和。

psum_scatter(x, axis_name, *[, ...])

类似于 psum(x, axis_name),但每个设备只保留部分结果。

pmax(x, axis_name, *[, axis_index_groups])

在 pmap 轴 axis_name 上计算 x 的全归约最大值。

pmin(x, axis_name, *[, axis_index_groups])

在 pmap 轴 axis_name 上计算 x 的全归约最小值。

pmean(x, axis_name, *[, axis_index_groups])

在 pmap 轴 axis_name 上计算 x 的全归约平均值。

ppermute(x, axis_name, perm)

根据置换 perm 执行集体置换。

pshuffle(x, axis_name, perm)

jax.lax.ppermute 的便捷封装,采用替代置换编码。

pswapaxes(x, axis_name, axis, *[, ...])

将 pmap 轴 axis_name 与未映射轴 axis 交换。

axis_index(axis_name)

返回映射轴 axis_name 上的索引。

axis_size(axis_name)

返回映射轴 axis_name 的大小。

psend(x, axis_name, perm)

根据置换 perm 执行集体发送。

precv(token, out_shape, axis_name, perm)

根据置换 perm 执行集体接收。

线性代数运算符 (jax.lax.linalg)#

cholesky(x, *[, symmetrize_input])

Cholesky 分解。

cholesky_update(r_matrix, w_vector)

Cholesky 秩 1 更新。

eig(x, *[, compute_left_eigenvectors, ...])

一般矩阵的特征分解。

eigh(x, *[, lower, symmetrize_input, ...])

Hermitian 矩阵的特征分解。

hessenberg(a)

将方阵约化为上 Hessenberg 形式。

householder_product(a, taus)

基本 Householder 反射的乘积。

lu(x)

带部分选主元的 LU 分解。

lu_pivots_to_permutation(pivots, ...)

将 LU 返回的枢轴(行交换)转换为置换。

qdwh(x, *[, is_hermitian, max_iterations, ...])

用于极分解的基于 QR 的动态加权 Halley 迭代。

qr()

QR 分解。

schur(x, *[, compute_schur_vectors, ...])

Schur 分解。

svd()

奇异值分解。

SvdAlgorithm(value[, names, module, ...])

SVD 算法的枚举。

symmetric_product(a_matrix, c_matrix, *[, ...])

对称积。

triangular_solve(a, b, *[, left_side, ...])

三角求解。

tridiagonal(a, *[, lower])

将对称/ Hermitian 矩阵约化为三对角形式。

tridiagonal_solve(dl, d, du, b)

计算三对角线性系统的解。

参数类#

class jax.lax.ConvDimensionNumbers(lhs_spec, rhs_spec, out_spec)[source]#

描述卷积的批次、空间和特征维度。

参数:
  • lhs_spec (Sequence[int]) – 一个包含 (批次维度, 特征维度, 空间维度…) 的非负整数维度编号元组。

  • rhs_spec (Sequence[int]) – 一个包含 (输出特征维度, 输入特征维度, 空间维度…) 的非负整数维度编号元组。

  • out_spec (Sequence[int]) – 一个包含 (批次维度, 特征维度, 空间维度…) 的非负整数维度编号元组。

jax.lax.ConvGeneralDilatedDimensionNumbers#

alias of tuple[str, str, str] | ConvDimensionNumbers | None

class jax.lax.DotAlgorithm(lhs_precision_type, rhs_precision_type, accumulation_type, lhs_component_count=1, rhs_component_count=1, num_primitive_operations=1, allow_imprecise_accumulation=False)[source]#

指定用于计算点积的算法。

当用于指定 precision 输入到 dot()dot_general() 和其他点积函数时,此数据结构用于控制计算点积所用算法的属性。此 API 控制计算中使用的精度,并允许用户访问硬件特定的加速功能。

对这些算法的支持依赖于平台,使用不支持的算法在计算编译时会引发 Python 异常。已知在至少某些平台上支持的算法列在 DotAlgorithmPreset 枚举中,这些是尝试使用此 API 的良好起点。

“点算法”由以下参数指定:

  • lhs_precision_typerhs_precision_type,操作的左侧(LHS)和右侧(RHS)四舍五入到的数据类型。

  • accumulation_type 用于累积的数据类型。

  • lhs_component_countrhs_component_countnum_primitive_operations 适用于将操作的左侧(LHS)和/或右侧(RHS)分解为多个组件,并对这些值执行多次操作的算法,通常是为了模拟更高的精度。对于没有分解的算法,这些值应设置为 1

  • allow_imprecise_accumulation 用于指定某些步骤是否允许使用较低精度进行累积(例如 CUBLASLT_MATMUL_DESC_FAST_ACCUM)。

点操作的 StableHLO 规范 不要求精度类型与输入或输出的存储类型相同,但某些平台可能要求这些类型匹配。此外,如果指定了输入算法的 accumulation_type 参数,dot_general() 的返回类型总是由该参数定义。

示例

使用 32 位浮点累加器累积两个 16 位浮点数

>>> algorithm = DotAlgorithm(
...     lhs_precision_type=np.float16,
...     rhs_precision_type=np.float16,
...     accumulation_type=np.float32,
... )
>>> lhs = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=np.float16)
>>> rhs = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=np.float16)
>>> dot(lhs, rhs, precision=algorithm)  
array([ 1.,  4.,  9., 16.], dtype=float16)

或者,等效地,使用预设

>>> algorithm = DotAlgorithmPreset.F16_F16_F32
>>> dot(lhs, rhs, precision=algorithm)  
array([ 1.,  4.,  9., 16.], dtype=float16)

预设也可以通过名称指定

>>> dot(lhs, rhs, precision="F16_F16_F32")  
array([ 1.,  4.,  9., 16.], dtype=float16)

参数 preferred_element_type 可用于在不降级累积类型的情况下返回输出。

>>> dot(lhs, rhs, precision="F16_F16_F32", preferred_element_type=np.float32)  
array([ 1.,  4.,  9., 16.], dtype=float32)
参数:
  • lhs_precision_type (类DType类型)

  • rhs_precision_type (类DType类型)

  • accumulation_type (类DType类型)

  • lhs_component_count (int)

  • rhs_component_count (int)

  • num_primitive_operations (int)

  • allow_imprecise_accumulation (bool)

class jax.lax.DotAlgorithmPreset(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)[source]#

一个用于计算点积的已知算法的枚举。

Enum 提供了一组命名的 DotAlgorithm 对象,已知这些对象至少在某个平台上受支持。有关这些算法行为的更多详细信息,请参阅 DotAlgorithm 文档。

通过将此 Enum 的成员或其名称字符串作为 precision 参数传递,可以在调用 dot()dot_general() 或大多数其他 JAX 点积函数时,从此列表中选择一个算法。

例如,用户可以直接使用此 Enum 指定预设

>>> lhs = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=np.float16)
>>> rhs = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=np.float16)
>>> algorithm = DotAlgorithmPreset.F16_F16_F32
>>> dot(lhs, rhs, precision=algorithm)  
array([ 1.,  4.,  9., 16.], dtype=float16)

或者,等效地,它们也可以通过名称指定

>>> dot(lhs, rhs, precision="F16_F16_F32")  
array([ 1.,  4.,  9., 16.], dtype=float16)

预设的名称通常是 LHS_RHS_ACCUM,其中 LHSRHS 分别是 lhsrhs 输入的元素类型,ACCUM 是累加器的元素类型。一些预设有一个额外的后缀,它们的含义在下面有说明。支持的预设包括:

DEFAULT = 1#

将根据输入和输出类型选择算法。

ANY_F8_ANY_F8_F32 = 2#

接受任意 float8 输入类型并累积到 float32。

ANY_F8_ANY_F8_F32_FAST_ACCUM = 3#

ANY_F8_ANY_F8_F32 类似,但使用更快的累积,代价是降低精度。

ANY_F8_ANY_F8_ANY = 4#

ANY_F8_ANY_F8_F32 类似,但累积类型由 preferred_element_type 控制。

ANY_F8_ANY_F8_ANY_FAST_ACCUM = 5#

ANY_F8_ANY_F8_F32_FAST_ACCUM 类似,但累积类型由 preferred_element_type 控制。

F16_F16_F16 = 6#
F16_F16_F32 = 7#
BF16_BF16_BF16 = 8#
BF16_BF16_F32 = 9#
BF16_BF16_F32_X3 = 10#

后缀 _X3 表示算法使用 3 次操作来模拟更高的精度。

BF16_BF16_F32_X6 = 11#

BF16_BF16_F32_X3 类似,但使用 6 次操作而非 3 次。

BF16_BF16_F32_X9 = 12#

BF16_BF16_F32_X3 类似,但使用 9 次操作而非 3 次。

TF32_TF32_F32 = 13#
TF32_TF32_F32_X3 = 14#

后缀 _X3 表示算法使用 3 次操作来模拟更高的精度。

F32_F32_F32 = 15#
F64_F64_F64 = 16#
property supported_lhs_types: tuple[类DType类型, ...] | None[source]#
property supported_rhs_types: tuple[类DType类型, ...] | None[source]#
property accumulation_type: 类DType类型 | None[source]#
supported_output_types(lhs_dtype, rhs_dtype)[source]#
参数:
  • lhs_dtype (类DType类型)

  • rhs_dtype (类DType类型)

返回类型:

tuple[类DType类型, …] | None

class jax.lax.FftType(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)[source]#

描述要执行的 FFT 操作类型。

FFT = 0#

正向复数到复数 FFT。

IFFT = 1#

逆向复数到复数 FFT。

IRFFT = 3#

逆向实数到复数 FFT。

RFFT = 2#

正向实数到复数 FFT。

class jax.lax.GatherDimensionNumbers(offset_dims, collapsed_slice_dims, start_index_map, operand_batching_dims=(), start_indices_batching_dims=())[source]#

描述 XLA Gather 算子的维度编号参数。有关维度编号含义的更多详细信息,请参阅 XLA 文档。

参数:
  • offset_dims (tuple[int, ...]) – “gather”输出中偏移到从“operand”切片数组中的维度集合。必须是按升序排列的整数元组,每个整数代表输出的一个维度编号。

  • collapsed_slice_dims (tuple[int, ...]) – “operand”中维度 i 的集合,其中 slice_sizes[i] == 1 并且在 gather 的输出中不应有对应的维度。必须是按升序排列的整数元组。

  • start_index_map (tuple[int, ...]) – 对于 start_indices 中的每个维度,给出 operand 中要进行切片的相应维度。必须是一个整数元组,其大小等于 start_indices.shape[-1]

  • operand_batching_dims (tuple[int, ...]) – “operand”中批处理维度 i 的集合,其中 slice_sizes[i] == 1,并且在 start_indices(在 start_indices_batching_dims 中的相同索引处)和 gather 的输出中都应具有对应的维度。必须是按升序排列的整数元组。

  • start_indices_batching_dims (tuple[int, ...]) – “start_indices”中批处理维度 i 的集合,该集合在 operand(在 operand_batching_dims 中的相同索引处)和 gather 的输出中都应具有对应的维度。必须是整数元组(顺序根据与 operand_batching_dims 的对应关系固定)。

与 XLA 的 GatherDimensionNumbers 结构不同,index_vector_dim 是隐式的;总是存在一个索引向量维度,并且它必须始终是最后一个维度。要收集标量索引,请添加一个大小为 1 的末尾维度。

class jax.lax.GatherScatterMode(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)[source]#

描述如何在 gather 或 scatter 操作中处理越界索引。

可能的值包括:

CLIP

索引将被限制到最近的在范围内值,即,使得要收集的整个窗口都在范围内。

FILL_OR_DROP

如果收集窗口的任何部分越界,则返回的整个窗口(即使是那些原本在范围内的元素)都将用常量填充。如果散射窗口的任何部分越界,则整个窗口将被丢弃。

PROMISE_IN_BOUNDS

用户承诺索引在界限内。将不会执行额外检查。实际上,在当前的 XLA 实现中,这意味着越界 gather 操作将被钳制,但越界 scatter 操作将被丢弃。如果索引越界,梯度将不正确。

class jax.lax.Precision(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)[source]#

lax 矩阵乘法相关函数的精度枚举。

JAX 函数中与设备相关的 precision 参数通常控制加速器后端(即 TPU 和 GPU)上数组计算的速度和精度之间的权衡。对 CPU 后端没有影响。这仅对 float32 计算有效,并且不影响输入/输出数据类型。成员包括:

DEFAULT

最快模式,但精度最低。在 TPU 上:以 bfloat16 执行 float32 计算。在 GPU 上:如果可用,则使用 tensorfloat32(例如在 A100 和 H100 GPU 上),否则使用标准 float32(例如在 V100 GPU 上)。别名:'default''fastest'

HIGH

较慢但更精确。在 TPU 上:以 3 次 bfloat16 通道执行 float32 计算。在 GPU 上:如果可用,则使用 tensorfloat32,否则使用 float32。别名:'high'

HIGHEST

最慢但最精确。在 TPU 上:以 6 次 bfloat16 执行 float32 计算。别名:'highest'。在 GPU 上:使用 float32。

jax.lax.PrecisionLike#

别名 None | str | Precision | tuple[str, str] | tuple[Precision, Precision] | DotAlgorithm | DotAlgorithmPreset

class jax.lax.RandomAlgorithm(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)[source]#

描述用于 rng_bit_generator 的 PRNG 算法。

RNG_DEFAULT = 0#

平台的默认算法。

RNG_THREE_FRY = 1#

Threefry-2x32 PRNG 算法。

RNG_PHILOX = 2#

Philox-4x32 PRNG 算法。

class jax.lax.RoundingMethod(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)[source]#

处理 jax.lax.round() 中半值(例如 0.5)的舍入策略。

AWAY_FROM_ZERO = 0#

将半值远离零舍入(例如,0.5 -> 1,-0.5 -> -1)。

TO_NEAREST_EVEN = 1#

将半值舍入到最近的偶数整数。这也被称为“银行家舍入法”(例如,0.5 -> 0,1.5 -> 2)。

class jax.lax.ScatterDimensionNumbers(update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims, operand_batching_dims=(), scatter_indices_batching_dims=())[source]#

描述 XLA Scatter 算子的维度编号参数。有关维度编号含义的更多详细信息,请参阅 XLA 文档。

参数:
  • update_window_dims (序列[int]) – “updates”中作为窗口维度的维度集合。必须是按升序排列的整数元组,每个整数代表一个维度编号。

  • inserted_window_dims (序列[int]) – 必须插入到 updates 形状中的大小为 1 的窗口维度集合。必须是按升序排列的整数元组,每个整数代表输出的一个维度编号。它们是 gather 操作中 collapsed_slice_dims 的镜像。

  • scatter_dims_to_operand_dims (序列[int]) – 对于 scatter_indices 中的每个维度,给出 operand 中对应的维度。必须是整数序列,其大小等于 scatter_indices.shape[-1]

  • operand_batching_dims (序列[int]) – “operand”中批处理维度 i 的集合,该集合在 scatter_indices(在 scatter_indices_batching_dims 中的相同索引处)和 updates 中都应具有对应的维度。必须是按升序排列的整数元组。它们是 gather 操作中 operand_batching_dims 的镜像。

  • scatter_indices_batching_dims (序列[int]) – “scatter_indices”中批处理维度 i 的集合,该集合在 operand(在 operand_batching_dims 中的相同索引处)和 gather 的输出中都应具有对应的维度。必须是整数元组(顺序根据与 input_batching_dims 的对应关系固定)。它们是 gather 操作中 start_indices_batching_dims 的镜像。

与 XLA 的 ScatterDimensionNumbers 结构不同,index_vector_dim 是隐式的;总是存在一个索引向量维度,并且它必须始终是最后一个维度。要散射标量索引,请添加一个大小为 1 的末尾维度。