jax.numpy.identity#

jax.numpy.identity(n, dtype=None)[源代码]#

创建一个方阵单位矩阵

JAX 对 numpy.identity() 的实现。

参数:
  • n (DimSize) – 指定每个数组维度大小的整数。

  • dtype (DTypeLike | None | None) – 可选数据类型;默认为浮点数。

返回:

形状为 (n, n) 的单位矩阵。

返回类型:

数组

另请参阅

jax.numpy.eye():非方阵和/或偏移单位矩阵。

示例

一个简单的 3x3 单位矩阵

>>> jnp.identity(3)
Array([[1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.]], dtype=float32)

一个 2x2 整数单位矩阵

>>> jnp.identity(2, dtype=int)
Array([[1, 0],
       [0, 1]], dtype=int32)