jax.numpy.load#

jax.numpy.load(file, *args, **kwargs)[源代码]#

从 npy 文件加载 JAX 数组。

JAX 对 numpy.load() 的封装。

此函数是对 numpy.load() 的简单封装,但在使用 numpy.save()jax.numpy.save() 创建的 .npy 文件时,输出将返回为 jax.Array,并且 bfloat16 数据类型将得到恢复。对于 .npz 文件,结果将以普通 NumPy 数组的形式返回。

此函数需要具体的数组输入,并且与 jax.jit()jax.vmap() 等转换不兼容。

参数:
  • file (IO[bytes] | str | os.PathLike[Any]) – 包含数组数据的字符串、字节串或类路径对象。

  • args (Any) – 有关其他参数,请参阅 numpy.load()

  • kwargs (Any) – 有关其他参数,请参阅 numpy.load()

返回:

文件中存储的数组。

返回类型:

Array

另请参阅

示例

>>> import io
>>> f = io.BytesIO()  # use an in-memory file-like object.
>>> x = jnp.array([2, 4, 6, 8], dtype='bfloat16')
>>> jnp.save(f, x)
>>> f.seek(0)
0
>>> jnp.load(f)
Array([2, 4, 6, 8], dtype=bfloat16)