jax.numpy.swapaxes#
- jax.numpy.swapaxes(a, axis1, axis2)[源代码]#
交换数组的两个轴。
JAX 对
numpy.swapaxes()的实现,它基于jax.lax.transpose()实现。- 参数:
- 返回:
输入数组
a的副本,其中指定的轴已交换。- 返回类型:
注意事项
与
numpy.swapaxes()不同,jax.numpy.swapaxes()返回的是输入数组的副本而不是视图。然而,在 JIT 编译下,编译器会在可能的情况下优化掉这种复制,因此在实践中不会影响性能。另请参阅
jax.numpy.moveaxis(): 移动数组的单个轴。jax.numpy.rollaxis():moveaxis的旧 API。jax.lax.transpose(): 更通用的轴置换。jax.Array.swapaxes(): 通过数组方法实现相同的功能。
示例
>>> a = jnp.ones((2, 3, 4, 5)) >>> jnp.swapaxes(a, 1, 3).shape (2, 5, 4, 3)
通过 `swapaxes` 数组方法实现的等效输出
>>> a.swapaxes(1, 3).shape (2, 5, 4, 3)
通过
transpose()实现的等效输出>>> a.transpose(0, 3, 2, 1).shape (2, 5, 4, 3)