jax.numpy.copy#

jax.numpy.copy(a, order=None)[source]#

返回数组的副本。

JAX 对 numpy.copy() 的实现。

参数:
  • a (ArrayLike) – 要复制的类数组对象

  • order (str | None | None) – 在 JAX 中未实现

返回值:

输入数组 a 的副本。

返回类型:

Array

另请参阅

示例

由于 JAX 数组是不可变的,因此在大多数情况下,显式数组副本不是必需的。一个例外是当使用带有捐赠参数的函数时(请参阅 donate_argnums 参数到 jax.jit())。

>>> f = jax.jit(lambda x: 2 * x, donate_argnums=0)
>>> x = jnp.arange(4)
>>> y = f(x)
>>> print(y)
[0 2 4 6]

因为我们将 x 标记为被捐赠,所以原始数组不再可用

>>> print(x)  
Traceback (most recent call last):
RuntimeError: Array has been deleted with shape=int32[4].

在类似这样的情况下,显式副本将使您可以保留对原始缓冲区的访问权限

>>> x = jnp.arange(4)
>>> y = f(x.copy())
>>> print(y)
[0 2 4 6]
>>> print(x)
[0 1 2 3]