jax.numpy.trace#
- jax.numpy.trace(a, offset=0, axis1=0, axis2=1, dtype=None, out=None)[source]#
计算给定轴上输入的对角线之和。
numpy.trace()
的 JAX 实现。- 参数:
a (ArrayLike) – 输入数组。必须满足
a.ndim >= 2
。offset (int | ArrayLike) – 可选,int,默认值=0。 距离主对角线的对角线偏移量。 可以是正数或负数。
axis1 (int) – 可选,默认值=0。要沿其取对角线之和的第一个轴。 必须是静态整数值。
axis2 (int) – 可选,默认值=1。要沿其取对角线之和的第二个轴。 必须是静态整数值。
dtype (DTypeLike | None) – 可选。输出数组的 dtype。 应该在 JIT 编译中作为静态参数提供。
out (None) – JAX 不使用。
- 返回:
维度为 x.ndim-2 的数组,包含沿轴 (axis1, axis2) 的对角线元素之和
- 返回类型:
另请参阅
jax.numpy.diag()
: 返回指定的对角线或构造对角线数组jax.numpy.diagonal()
: 返回数组的指定对角线。jax.numpy.diagflat()
: 返回一个二维数组,其中扁平化的输入数组排列在对角线上。
示例
>>> x = jnp.arange(1, 9).reshape(2, 2, 2) >>> x Array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=int32) >>> jnp.trace(x) Array([ 8, 10], dtype=int32) >>> jnp.trace(x, offset=1) Array([3, 4], dtype=int32) >>> jnp.trace(x, axis1=1, axis2=2) Array([ 5, 13], dtype=int32) >>> jnp.trace(x, offset=1, axis1=1, axis2=2) Array([2, 6], dtype=int32)