jax.numpy.trace#

jax.numpy.trace(a, offset=0, axis1=0, axis2=1, dtype=None, out=None)[源码]#

计算输入数组沿给定轴对角的元素之和。

JAX 对 numpy.trace() 的实现。

参数:
  • a (ArrayLike) – 输入数组。必须满足 a.ndim >= 2

  • offset (int | ArrayLike) – 可选,整数,默认为 0。主对角线之外的对角线偏移量。可以是正数或负数。

  • axis1 (int) – 可选,默认为 0。用于计算对角线之和的第一个轴。必须是静态整数值。

  • axis2 (int) – 可选,默认为 1。用于计算对角线之和的第二个轴。必须是静态整数值。

  • dtype (DTypeLike | None) – 可选。输出数组的数据类型。在 JIT 编译时应作为静态参数提供。

  • out (None) – JAX 不使用。

返回:

一个维度为 x.ndim-2 的数组,其中包含沿轴 (axis1, axis2) 的对角线元素之和。

返回类型:

Array

另请参阅

示例

>>> x = jnp.arange(1, 9).reshape(2, 2, 2)
>>> x
Array([[[1, 2],
        [3, 4]],

       [[5, 6],
        [7, 8]]], dtype=int32)
>>> jnp.trace(x)
Array([ 8, 10], dtype=int32)
>>> jnp.trace(x, offset=1)
Array([3, 4], dtype=int32)
>>> jnp.trace(x, axis1=1, axis2=2)
Array([ 5, 13], dtype=int32)
>>> jnp.trace(x, offset=1, axis1=1, axis2=2)
Array([2, 6], dtype=int32)