jax.numpy.diagflat#

jax.numpy.diagflat(v, k=0)[source]#

返回一个二维数组,其中扁平化的输入数组沿对角线排列。

JAX 对 numpy.diagflat() 的实现。

对于 v 的某些标量值,这与 np.diagflat 不同。JAX 始终返回二维数组,而 NumPy 可能会根据 v 的类型返回标量。

参数:
  • v (ArrayLike) – 输入数组。可以是 N 维的,但会被扁平化为 1 维。

  • k (int) – 可选,默认值=0。对角线偏移量。正值将对角线放置在主对角线上方,负值将其放置在主对角线下方。

返回值:

一个二维数组,其输入元素沿对角线放置,并具有指定的偏移量 (k)。剩余的条目用零填充。

返回类型:

Array

示例

>>> jnp.diagflat(jnp.array([1, 2, 3]))
Array([[1, 0, 0],
       [0, 2, 0],
       [0, 0, 3]], dtype=int32)
>>> jnp.diagflat(jnp.array([1, 2, 3]), k=1)
Array([[0, 1, 0, 0],
       [0, 0, 2, 0],
       [0, 0, 0, 3],
       [0, 0, 0, 0]], dtype=int32)
>>> a = jnp.array([[1, 2],
...                [3, 4]])
>>> jnp.diagflat(a)
Array([[1, 0, 0, 0],
       [0, 2, 0, 0],
       [0, 0, 3, 0],
       [0, 0, 0, 4]], dtype=int32)