jax.numpy.reshape#
- jax.numpy.reshape(a, shape, order='C', *, copy=None)[source]#
返回数组的重塑副本。
JAX 实现的
numpy.reshape()
,根据jax.lax.reshape()
实现。- 参数:
- 返回:
具有指定形状的输入数组的重塑副本。
- 返回类型:
Notes
与
numpy.reshape()
不同,jax.numpy.reshape()
将返回输入数组的副本而不是视图。但是,在 JIT 下,编译器会在可能的情况下优化掉这些副本,因此这在实践中不会对性能产生影响。参见
jax.Array.reshape()
:通过数组方法实现的等效功能。jax.numpy.ravel()
:将数组展平为 1D 形状。jax.numpy.squeeze()
:从数组的形状中删除一个或多个长度为 1 的轴。
示例
>>> x = jnp.array([[1, 2, 3], ... [4, 5, 6]]) >>> jnp.reshape(x, 6) Array([1, 2, 3, 4, 5, 6], dtype=int32) >>> jnp.reshape(x, (3, 2)) Array([[1, 2], [3, 4], [5, 6]], dtype=int32)
您可以使用
-1
自动计算与输入大小一致的形状>>> jnp.reshape(x, -1) # -1 is inferred to be 6 Array([1, 2, 3, 4, 5, 6], dtype=int32) >>> jnp.reshape(x, (-1, 2)) # -1 is inferred to be 3 Array([[1, 2], [3, 4], [5, 6]], dtype=int32)
reshape 中轴的默认顺序是 C 风格的行优先顺序。要使用 Fortran 风格的列优先顺序,请指定
order='F'
>>> jnp.reshape(x, 6, order='F') Array([1, 4, 2, 5, 3, 6], dtype=int32) >>> jnp.reshape(x, (3, 2), order='F') Array([[1, 5], [4, 3], [2, 6]], dtype=int32)
为方便起见,此功能也可以通过
jax.Array.reshape()
方法使用>>> x.reshape(3, 2) Array([[1, 2], [3, 4], [5, 6]], dtype=int32)