jax.numpy.load#
- jax.numpy.load(file, *args, **kwargs)[源代码]#
从 npy 文件加载 JAX 数组。
NumPy
numpy.load()
的 JAX 包装器。此函数是
numpy.load()
的简单包装器,但在使用numpy.save()
或jax.numpy.save()
创建的.npy
文件的情况下,输出将作为jax.Array
返回,并且bfloat16
数据类型将被恢复。对于.npz
文件,结果将作为正常的 NumPy 数组返回。此函数需要具体数组输入,并且与
jax.jit()
或jax.vmap()
等转换不兼容。- 参数:
file (IO[bytes] | str | os.PathLike[Any]) – 字符串、字节或包含数组数据的类路径对象。
args (Any) – 有关其他参数,请参阅
numpy.load()
kwargs (Any) – 有关其他参数,请参阅
numpy.load()
- 返回:
文件中存储的数组。
- 返回类型:
参见
jax.numpy.save()
: 将数组保存到文件。
示例
>>> import io >>> f = io.BytesIO() # use an in-memory file-like object. >>> x = jnp.array([2, 4, 6, 8], dtype='bfloat16') >>> jnp.save(f, x) >>> f.seek(0) 0 >>> jnp.load(f) Array([2, 4, 6, 8], dtype=bfloat16)