jax.numpy.frombuffer#
- jax.numpy.frombuffer(buffer, dtype=<class 'float'>, count=-1, offset=0)[源代码]#
将缓冲区转换为一维 JAX 数组。
numpy.frombuffer()
的 JAX 实现。- 参数:
buffer (bytes | Any) – 包含数据的对象。它必须是一个长度为 dtype 元素大小的整数倍的 bytes 对象,或者必须是一个导出 Python 缓冲区接口 的对象。
dtype (DTypeLike) – 可选。数组的所需数据类型。默认值为
float64
。这指定了用于解析缓冲区的 dtype,但请注意,在解析后,如果jax_enable_x64
标志设置为False
,则 64 位值将被强制转换为 32 位 JAX 数组。count (int) – 可选整数,指定要从缓冲区读取的项目数。如果为 -1(默认值),则读取缓冲区中的所有项目。
offset (int) – 可选整数,指定在缓冲区开头跳过的字节数。默认为 0。
- 返回值:
一个一维 JAX 数组,表示从缓冲区解释的数据。
- 返回类型:
另请参阅
jax.numpy.fromstring()
:将文本字符串转换为一维 JAX 数组。
示例
使用字节缓冲区
>>> buf = b"\x00\x01\x02\x03\x04" >>> jnp.frombuffer(buf, dtype=jnp.uint8) Array([0, 1, 2, 3, 4], dtype=uint8) >>> jnp.frombuffer(buf, dtype=jnp.uint8, offset=1) Array([1, 2, 3, 4], dtype=uint8)
通过 Python 缓冲区接口,使用 Python 内置的
array
模块构造 JAX 数组。>>> from array import array >>> pybuffer = array('i', [0, 1, 2, 3, 4]) >>> jnp.frombuffer(pybuffer, dtype=jnp.int32) Array([0, 1, 2, 3, 4], dtype=int32)