jax.numpy.copy#
- jax.numpy.copy(a, order=None)[源代码]#
- 返回数组的副本。 - JAX 对 - numpy.copy()的实现。- 另请参阅 - jax.numpy.array():创建数组(可选择是否复制)。
- jax.Array.copy():作为数组方法访问的相同函数。
 - 示例 - 由于 JAX 数组是不可变的,在大多数情况下不需要显式的数组复制。一个例外是使用具有被捐赠参数的函数时(请参阅 - jax.jit()的- donate_argnums参数)。- >>> f = jax.jit(lambda x: 2 * x, donate_argnums=0) >>> x = jnp.arange(4) >>> y = f(x) >>> print(y) [0 2 4 6] - 由于我们将 - x标记为被捐赠,因此原始数组不再可用- >>> print(x) Traceback (most recent call last): RuntimeError: Array has been deleted with shape=int32[4]. - 在这种情况下,显式复制将允许您保留对原始缓冲区的访问权限 - >>> x = jnp.arange(4) >>> y = f(x.copy()) >>> print(y) [0 2 4 6] >>> print(x) [0 1 2 3] 
