jax.numpy.empty#

jax.numpy.empty(shape, dtype=None, *, device=None, out_sharding=None)[source]#

创建一个空数组。

JAX 对 numpy.empty() 的实现。由于 XLA 无法创建未初始化的数组,jax.numpy.empty() 将始终返回一个全零数组。

参数:
  • shape (Any) – 指定所创建数组形状的整数或整数序列。

  • dtype (str | type[Any] | dtype | SupportsDType | None) – 所创建数组的可选数据类型;默认值为 float32 或 float64,具体取决于 X64 配置(参见默认数据类型和 X64 标志)。

  • device (Device | Sharding | None) – (可选)所创建数组将提交到的 DeviceSharding。此参数的存在是为了与 Python 数组 API 标准兼容。

  • out_sharding (NamedSharding | PartitionSpec | None) – (可选)表示所创建数组分片方式的 PartitionSpecNamedSharding(有关详细信息,请参见显式分片)。此参数的存在是为了与 JAX 中其他数组创建例程保持一致。同时指定 out_shardingdevice 将导致错误。

返回:

指定形状和数据类型(如果指定了设备/分片)的数组。

返回类型:

数组

示例

>>> jnp.empty(4)
Array([0., 0., 0., 0.], dtype=float32)
>>> jnp.empty((2, 3), dtype=bool)
Array([[False, False, False],
       [False, False, False]], dtype=bool)