jax.numpy.ones#

jax.numpy.ones(shape, dtype=None, *, device=None, out_sharding=None)[源代码]#

创建一个充满 1 的数组。

numpy.ones() 的 JAX 实现。

参数:
  • shape (Any) – 指定所创建数组形状的整数或整数序列。

  • dtype (str | type[Any] | dtype | SupportsDType | None) – 为创建的数组选择的可选 dtype;默认为 float32 或 float64,具体取决于 X64 配置(参见 默认 dtypes 和 X64 标志)。

  • device (Device | Sharding | None) – (可选) DeviceSharding,创建的数组将提交到该设备或分片。 此参数的存在是为了与 Python Array API 标准 兼容。

  • out_sharding (NamedSharding | PartitionSpec | None) – (可选) 代表已创建数组分片的 PartitionSpecNamedSharding (有关更多详细信息,请参阅 显式分片)。 此参数的存在是为了与其他 JAX 数组创建例程保持一致。 指定 out_shardingdevice 都会导致错误。

返回:

指定形状和数据类型(如果指定了设备/分片)的数组。

返回类型:

数组

示例

>>> jnp.ones(4)
Array([1., 1., 1., 1.], dtype=float32)
>>> jnp.ones((2, 3), dtype=bool)
Array([[ True,  True,  True],
       [ True,  True,  True]], dtype=bool)