jax.numpy.frombuffer#

jax.numpy.frombuffer(buffer, dtype=<class 'float'>, count=-1, offset=0)[source]#

将缓冲区转换为 1-D JAX 数组。

numpy.frombuffer() 的 JAX 实现。

参数:
  • buffer (bytes | Any) – 包含数据的对象。它必须是长度为 dtype 元素大小整数倍的 bytes 对象,或者必须是导出 Python 缓冲区接口的对象。

  • dtype (DTypeLike) – 可选。数组的期望数据类型。默认为 float64。这指定了用于解析缓冲区的 dtype,但请注意,如果 jax_enable_x64 标志设置为 False,则解析后,64 位值将强制转换为 32 位 JAX 数组。

  • count (int) – 可选整数,指定要从缓冲区读取的项目数。如果为 -1(默认值),则读取缓冲区中的所有项目。

  • offset (int) – 可选整数,指定在缓冲区开头要跳过的字节数。默认为 0。

返回:

一个 1-D JAX 数组,表示来自缓冲区的已解释数据。

返回类型:

Array

参见

示例

使用字节缓冲区

>>> buf = b"\x00\x01\x02\x03\x04"
>>> jnp.frombuffer(buf, dtype=jnp.uint8)
Array([0, 1, 2, 3, 4], dtype=uint8)
>>> jnp.frombuffer(buf, dtype=jnp.uint8, offset=1)
Array([1, 2, 3, 4], dtype=uint8)

通过 Python 缓冲区接口构造 JAX 数组,使用 Python 内置的 array 模块。

>>> from array import array
>>> pybuffer = array('i', [0, 1, 2, 3, 4])
>>> jnp.frombuffer(pybuffer, dtype=jnp.int32)
Array([0, 1, 2, 3, 4], dtype=int32)