jax.lax.reduce_sum#

jax.lax.reduce_sum(operand, axes, *, out_sharding=None)[源代码]#

计算一个或多个数组轴上元素的总和。

参数:
  • operand (ArrayLike) – 求和的数组。必须具有数值 dtype。

  • axes (Sequence[int]) – 指定求和轴的零个或多个唯一整数的序列。每个条目必须满足 0 <= axis < operand.ndim

返回:

一个与 operand 具有相同 dtype 的数组,其形状对应于 operand.shape 中移除 axes 后的维度。

返回类型:

Array

注意事项

jax.numpy.sum() 不同,jax.lax.reduce_sum() 不会上溯窄宽度类型进行累加,因此 8 位或 16 位类型的总和可能会出现舍入误差。

另请参阅