jax.numpy.eye#

jax.numpy.eye(N, M=None, k=0, dtype=None, *, device=None)[源代码]#

创建一个方阵或矩形单位矩阵

JAX 对 numpy.eye() 的实现。

参数:
  • N (DimSize) – 指定数组第一维度的整数。

  • M (DimSize | None) – 可选整数,指定数组的第二维度;默认为 N 的相同值。

  • k (int | ArrayLike) – 可选整数,指定对角线的偏移量。正值用于上对角线,负值用于下对角线。默认值为零。

  • dtype (DTypeLike | None) – 可选数据类型;默认为浮点类型。

  • device (xc.Device | Sharding | None) – 可选的 DeviceSharding,创建的数组将提交到该设备或分片。

返回:

形状为 (N, M) 的单位数组,如果未指定 M 则为 (N, N)

返回类型:

数组

另请参阅

jax.numpy.identity():用于生成方阵单位矩阵的更简单 API。

示例

一个简单的 3x3 单位矩阵

>>> jnp.eye(3)
Array([[1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.]], dtype=float32)

带有偏移对角线的整数单位矩阵

>>> jnp.eye(3, k=1, dtype=int)
Array([[0, 1, 0],
       [0, 0, 1],
       [0, 0, 0]], dtype=int32)
>>> jnp.eye(3, k=-1, dtype=int)
Array([[0, 0, 0],
       [1, 0, 0],
       [0, 1, 0]], dtype=int32)

非方阵单位矩阵

>>> jnp.eye(3, 5, k=1)
Array([[0., 1., 0., 0., 0.],
       [0., 0., 1., 0., 0.],
       [0., 0., 0., 1., 0.]], dtype=float32)