jax.lax.reduce_sum#
- jax.lax.reduce_sum(operand, axes, *, out_sharding=None)[源代码]#
计算一个或多个数组轴上元素的总和。
- 参数:
operand (ArrayLike) – 求和的数组。必须具有数值 dtype。
axes (Sequence[int]) – 指定求和轴的零个或多个唯一整数的序列。每个条目必须满足
0 <= axis < operand.ndim
。
- 返回:
一个与
operand
具有相同 dtype 的数组,其形状对应于operand.shape
中移除axes
后的维度。- 返回类型:
注意事项
与
jax.numpy.sum()
不同,jax.lax.reduce_sum()
不会上溯窄宽度类型进行累加,因此 8 位或 16 位类型的总和可能会出现舍入误差。另请参阅
jax.numpy.sum()
:更灵活的 NumPy 风格求和 API,围绕jax.lax.reduce_sum()
构建。其他低级
jax.lax
归约运算符:jax.lax.reduce_prod()
、jax.lax.reduce_max()
、jax.lax.reduce_min()
、jax.lax.reduce_and()
、jax.lax.reduce_or()
、jax.lax.reduce_xor()
。