jax.numpy.load#

jax.numpy.load(file, *args, **kwargs)[source]#

从 npy 文件加载 JAX 数组。

JAX 对 numpy.load() 的封装。

此函数是 numpy.load() 的一个简单包装器。但如果文件是使用 numpy.save()jax.numpy.save() 创建的 .npy 文件,则输出将作为 jax.Array 返回,并且 bfloat16 数据类型将被恢复。对于 .npz 文件,结果将作为普通的 NumPy 数组返回。

此函数需要具体的数组输入,并且不兼容 jax.jit()jax.vmap() 等转换。

参数:
  • file (IO[bytes] | str | os.PathLike[Any]) – 包含数组数据的字符串、字节或路径类对象。

  • args (Any) – 有关其他参数,请参阅 numpy.load()

  • kwargs (Any) – 有关其他参数,请参阅 numpy.load()

返回:

文件中存储的数组。

返回类型:

Array

另请参阅

示例

>>> import io
>>> f = io.BytesIO()  # use an in-memory file-like object.
>>> x = jnp.array([2, 4, 6, 8], dtype='bfloat16')
>>> jnp.save(f, x)
>>> f.seek(0)
0
>>> jnp.load(f)
Array([2, 4, 6, 8], dtype=bfloat16)