jax.numpy.identity#

jax.numpy.identity(n, dtype=None)[源]#

创建单位方阵

JAX 对 numpy.identity() 的实现。

参数:
  • n (DimSize) – 指定数组每个维度的整数大小。

  • dtype (DTypeLike | None) – 可选的数据类型;默认为浮点数。

返回:

形状为 (n, n) 的单位数组。

返回类型:

Array

另请参阅

jax.numpy.eye(): 非方阵和/或偏移单位矩阵。

示例

一个简单的 3x3 单位矩阵

>>> jnp.identity(3)
Array([[1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.]], dtype=float32)

一个 2x2 整型单位矩阵

>>> jnp.identity(2, dtype=int)
Array([[1, 0],
       [0, 1]], dtype=int32)