jax.numpy.triu#
- jax.numpy.triu(m, k=0)[源码]#
返回数组的上三角部分。
numpy.triu()
的 JAX 实现- 参数:
m (ArrayLike) – 输入数组。必须具有
m.ndim >= 2
。k (int) – 可选,int,默认值为 0。指定数组元素的下对角线,低于该对角线的元素将设置为零。
k=0
指主对角线,k<0
指主对角线下方的下对角线,k>0
指主对角线上方的上对角线。
- 返回:
与输入数组形状相同的数组,包含给定数组的上三角部分,其中由
k
指定的子对角线以下的元素设置为零。- 返回类型:
另请参阅
jax.numpy.tril()
:返回数组的下三角部分。jax.numpy.tri()
:返回一个数组,对角线及其下方为 1,其他地方为零。
示例
>>> x = jnp.array([[1, 2, 3], ... [4, 5, 6], ... [7, 8, 9], ... [10, 11, 12]]) >>> jnp.triu(x) Array([[1, 2, 3], [0, 5, 6], [0, 0, 9], [0, 0, 0]], dtype=int32) >>> jnp.triu(x, k=1) Array([[0, 2, 3], [0, 0, 6], [0, 0, 0], [0, 0, 0]], dtype=int32) >>> jnp.triu(x, k=-1) Array([[ 1, 2, 3], [ 4, 5, 6], [ 0, 8, 9], [ 0, 0, 12]], dtype=int32)
当
m.ndim > 2
时,jnp.triu
在尾部轴上按批处理方式运算。>>> x1 = jnp.array([[[1, 2], ... [3, 4]], ... [[5, 6], ... [7, 8]]]) >>> jnp.triu(x1) Array([[[1, 2], [0, 4]], [[5, 6], [0, 8]]], dtype=int32)