jax.numpy.moveaxis#
- jax.numpy.moveaxis(a, source, destination)[源代码]#
将数组轴移动到新位置
JAX 实现
numpy.moveaxis(),通过jax.lax.transpose()实现。- 参数:
- 返回:
输入数组
a的副本,其中轴已从source移动到destination。- 返回类型:
注意事项
与
numpy.moveaxis()不同,jax.numpy.moveaxis()将返回输入数组的副本而不是视图。但是,在 JIT 下,编译器会在可能的情况下优化此类副本,因此这在实践中不会影响性能。另请参阅
jax.numpy.swapaxes():交换两个轴。jax.numpy.rollaxis():用于移动轴的旧 API。jax.numpy.transpose():通用轴排列。
示例
>>> a = jnp.ones((2, 3, 4, 5))
将轴
1移动到数组末尾>>> jnp.moveaxis(a, 1, -1).shape (2, 4, 5, 3)
将最后一个轴移动到位置 1
>>> jnp.moveaxis(a, -1, 1).shape (2, 5, 3, 4)
移动多个轴
>>> jnp.moveaxis(a, (0, 1), (-1, -2)).shape (4, 5, 3, 2)
也可以通过
transpose()实现此功能>>> a.transpose(2, 3, 1, 0).shape (4, 5, 3, 2)