jax.numpy.tri#
- jax.numpy.tri(N, M=None, k=0, dtype=None)[source]#
返回一个数组,其对角线和对角线下方为 1,其他位置为 0。
numpy.tri()
的 JAX 实现- 参数:
N (int) – int. 返回数组的行维度。
M (int | None) – 可选, int. 返回数组的列维度。 如果未指定,则
M = N
。k (int) – 可选, int, 默认=0. 指定数组填充 1 的子对角线及其下方的位置。
k=0
指主对角线,k<0
指主对角线下方的子对角线,k>0
指主对角线上方的子对角线。dtype (DTypeLike | None) – 可选,返回数组的数据类型。 默认类型是浮点型。
- 返回:
一个形状为
(N, M)
的数组,包含下三角,其中由k
指定的子对角线以下的元素设置为 1,其他位置为 0。- 返回类型:
另请参阅
jax.numpy.tril()
: 返回数组的下三角。jax.numpy.triu()
: 返回数组的上三角。
示例
>>> jnp.tri(3) Array([[1., 0., 0.], [1., 1., 0.], [1., 1., 1.]], dtype=float32)
当
M
不等于N
时>>> jnp.tri(3, 4) Array([[1., 0., 0., 0.], [1., 1., 0., 0.], [1., 1., 1., 0.]], dtype=float32)
当
k>0
时>>> jnp.tri(3, k=1) Array([[1., 1., 0.], [1., 1., 1.], [1., 1., 1.]], dtype=float32)
当
k<0
时>>> jnp.tri(3, 4, k=-1) Array([[0., 0., 0., 0.], [1., 0., 0., 0.], [1., 1., 0., 0.]], dtype=float32)