jax.numpy.copy#
- jax.numpy.copy(a, order=None)[source]#
返回数组的副本。
JAX 对
numpy.copy()
的实现。另请参阅
jax.numpy.array()
: 创建带或不带副本的数组。jax.Array.copy()
: 作为数组方法访问的相同函数。
示例
由于 JAX 数组是不可变的,因此在大多数情况下,显式数组副本不是必需的。一个例外是当使用带有捐赠参数的函数时(请参阅
donate_argnums
参数到jax.jit()
)。>>> f = jax.jit(lambda x: 2 * x, donate_argnums=0) >>> x = jnp.arange(4) >>> y = f(x) >>> print(y) [0 2 4 6]
因为我们将
x
标记为被捐赠,所以原始数组不再可用>>> print(x) Traceback (most recent call last): RuntimeError: Array has been deleted with shape=int32[4].
在类似这样的情况下,显式副本将使您可以保留对原始缓冲区的访问权限
>>> x = jnp.arange(4) >>> y = f(x.copy()) >>> print(y) [0 2 4 6] >>> print(x) [0 1 2 3]