jax.numpy.empty#
- jax.numpy.empty(shape, dtype=None, *, device=None, out_sharding=None)[来源]#
创建一个空数组。
JAX 实现
numpy.empty()。由于 XLA 无法创建未初始化的数组,jax.numpy.empty()将始终返回一个全零数组。- 参数:
shape (Any) – 指定创建数组形状的整数或整数序列。
dtype (str | type[Any] | dtype | SupportsDType | None) – 创建数组的可选 dtype;默认为 float32 或 float64,具体取决于 X64 配置(参见 默认 dtype 和 X64 标志)。
device (Device | Sharding | None) – (可选)创建的数组将被提交到的
Device或Sharding。此参数是为了与 Python Array API 标准 兼容而存在的。out_sharding (NamedSharding | PartitionSpec | None) – (可选)
PartitionSpec或NamedSharding,表示创建数组的分片(有关更多详细信息,请参阅 显式分片)。此参数的存在是为了与其他 JAX 中的数组创建例程保持一致。同时指定out_sharding和device将导致错误。
- 返回:
指定形状和数据类型(如果指定了设备/分片)的数组。
- 返回类型:
示例
>>> jnp.empty(4) Array([0., 0., 0., 0.], dtype=float32) >>> jnp.empty((2, 3), dtype=bool) Array([[False, False, False], [False, False, False]], dtype=bool)