jax.numpy.from_dlpack#

jax.numpy.from_dlpack(x, /, *, device=None, copy=None)[source]#

通过 DLPack 构造 JAX 数组。

JAX 对 numpy.from_dlpack() 的实现。

参数:
返回:

输入缓冲区的 JAX 数组。

返回类型:

数组

注意

虽然 JAX 数组始终是不可变的,但 DLPack 缓冲区不能被标记为不可变,并且 JAX 外部的进程可能会原地修改它们。如果一个 JAX 数组是从 DLPack 缓冲区构造而未进行复制,并且源缓冲区随后被原地修改,那么在使用关联的 JAX 数组时可能会导致未定义行为。

示例

通过 DLPack 在 NumPy 和 JAX 之间传递数据

>>> import numpy as np
>>> rng = np.random.default_rng(42)
>>> x_numpy = rng.random(4, dtype='float32')
>>> print(x_numpy)
[0.08925092 0.773956   0.6545715  0.43887842]
>>> hasattr(x_numpy, "__dlpack__")  # NumPy supports the DLPack interface
True
>>> import jax.numpy as jnp
>>> x_jax = jnp.from_dlpack(x_numpy)
>>> print(x_jax)
[0.08925092 0.773956   0.6545715  0.43887842]
>>> hasattr(x_jax, "__dlpack__")  # JAX supports the DLPack interface
True
>>> x_numpy_round_trip = np.from_dlpack(x_jax)
>>> print(x_numpy_round_trip)
[0.08925092 0.773956   0.6545715  0.43887842]