jax.numpy.load#
- jax.numpy.load(file, *args, **kwargs)[源代码]#
从 npy 文件加载 JAX 数组。
JAX 对
numpy.load()的封装。此函数是对
numpy.load()的简单封装,但在使用numpy.save()或jax.numpy.save()创建的.npy文件时,输出将返回为jax.Array,并且bfloat16数据类型将得到恢复。对于.npz文件,结果将以普通 NumPy 数组的形式返回。此函数需要具体的数组输入,并且与
jax.jit()或jax.vmap()等转换不兼容。- 参数:
file (IO[bytes] | str | os.PathLike[Any]) – 包含数组数据的字符串、字节串或类路径对象。
args (Any) – 有关其他参数,请参阅
numpy.load()kwargs (Any) – 有关其他参数,请参阅
numpy.load()
- 返回:
文件中存储的数组。
- 返回类型:
另请参阅
jax.numpy.save(): 将数组保存到文件。
示例
>>> import io >>> f = io.BytesIO() # use an in-memory file-like object. >>> x = jnp.array([2, 4, 6, 8], dtype='bfloat16') >>> jnp.save(f, x) >>> f.seek(0) 0 >>> jnp.load(f) Array([2, 4, 6, 8], dtype=bfloat16)