jax.numpy.full#

jax.numpy.full(shape, fill_value, dtype=None, *, device=None)[source]#

创建一个填充了指定值的数组。

JAX 对 numpy.full() 的实现。

参数:
  • shape (Any) – 指定所创建数组形状的整数或整数序列。

  • fill_value (Array | ndarray | bool | number | bool | int | float | complex) – 用于填充所创建数组的标量或数组。

  • dtype (str | type[Any] | dtype | SupportsDType | None) – 所创建数组的可选数据类型;默认为填充值的数据类型。

  • device (Device | Sharding | None) – (可选) 所创建数组将被提交到的 DeviceSharding

返回:

具有指定形状和数据类型,并在指定设备(如果已指定)上的数组。

返回类型:

Array

示例

>>> jnp.full(4, 2, dtype=float)
Array([2., 2., 2., 2.], dtype=float32)
>>> jnp.full((2, 3), 0, dtype=bool)
Array([[False, False, False],
       [False, False, False]], dtype=bool)

fill_value 也可以是一个数组,其将被广播到指定的形状

>>> jnp.full((2, 3), fill_value=jnp.arange(3))
Array([[0, 1, 2],
       [0, 1, 2]], dtype=int32)