jax.numpy.trace#
- jax.numpy.trace(a, offset=0, axis1=0, axis2=1, dtype=None, out=None)[源码]#
计算输入数组沿给定轴对角的元素之和。
JAX 对
numpy.trace()的实现。- 参数:
- 返回:
一个维度为 x.ndim-2 的数组,其中包含沿轴 (axis1, axis2) 的对角线元素之和。
- 返回类型:
另请参阅
jax.numpy.diag(): 返回指定的对角线或构造一个对角线数组。jax.numpy.diagonal(): 返回数组的指定对角线。jax.numpy.diagflat(): 返回一个 2D 数组,其中展平的输入数组位于对角线上。
示例
>>> x = jnp.arange(1, 9).reshape(2, 2, 2) >>> x Array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=int32) >>> jnp.trace(x) Array([ 8, 10], dtype=int32) >>> jnp.trace(x, offset=1) Array([3, 4], dtype=int32) >>> jnp.trace(x, axis1=1, axis2=2) Array([ 5, 13], dtype=int32) >>> jnp.trace(x, offset=1, axis1=1, axis2=2) Array([2, 6], dtype=int32)