jax.numpy.fill_diagonal#
- jax.numpy.fill_diagonal(a, val, wrap=False, *, inplace=True)[源代码]#
返回一个对角线被覆盖的数组副本。
JAX 对
numpy.fill_diagonal()的实现。numpy.fill_diagonal()的语义是原地修改数组,这对于 JAX 的不可变数组是不可能的。JAX 版本返回输入数组的修改副本,并添加了inplace参数,用户必须将其设置为 False`,以提醒用户此 API 差异。- 参数:
- 返回:
一个
a的副本,其中对角线被设置为val。- 返回类型:
示例
>>> x = jnp.zeros((3, 3), dtype=int) >>> jnp.fill_diagonal(x, jnp.array([1, 2, 3]), inplace=False) Array([[1, 0, 0], [0, 2, 0], [0, 0, 3]], dtype=int32)
与
numpy.fill_diagonal()不同,输入x不会被修改。如果对角线值包含过多的条目,则会被截断
>>> jnp.fill_diagonal(x, jnp.arange(100, 200), inplace=False) Array([[100, 0, 0], [ 0, 101, 0], [ 0, 0, 102]], dtype=int32)
如果对角线包含过少的条目,则会被重复
>>> x = jnp.zeros((4, 4), dtype=int) >>> jnp.fill_diagonal(x, jnp.array([3, 4]), inplace=False) Array([[3, 0, 0, 0], [0, 4, 0, 0], [0, 0, 3, 0], [0, 0, 0, 4]], dtype=int32)
对于非方形数组,会填充主导的方形切片的对角线
>>> x = jnp.zeros((3, 5), dtype=int) >>> jnp.fill_diagonal(x, 1, inplace=False) Array([[1, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 1, 0, 0]], dtype=int32)
对于方形 N 维数组,会填充 N 维对角线
>>> y = jnp.zeros((2, 2, 2)) >>> jnp.fill_diagonal(y, 1, inplace=False) Array([[[1., 0.], [0., 0.]], [[0., 0.], [0., 1.]]], dtype=float32)