jax.numpy.swapaxes#
- jax.numpy.swapaxes(a, axis1, axis2)[来源]#
交换数组的两个轴。
numpy.swapaxes()
的 JAX 实现,使用jax.lax.transpose()
实现。注意事项
与
numpy.swapaxes()
不同,jax.numpy.swapaxes()
将返回输入数组的副本,而不是视图。 但是,在 JIT 下,编译器会在可能的情况下优化掉这些副本,因此实际上不会对性能产生影响。另请参阅
jax.numpy.moveaxis()
: 移动数组的单个轴。jax.numpy.rollaxis()
:moveaxis
的旧 API。jax.lax.transpose()
: 更通用的轴置换。jax.Array.swapaxes()
: 通过数组方法实现的相同功能。
示例
>>> a = jnp.ones((2, 3, 4, 5)) >>> jnp.swapaxes(a, 1, 3).shape (2, 5, 4, 3)
通过
swapaxes
数组方法等效输出>>> a.swapaxes(1, 3).shape (2, 5, 4, 3)
通过
transpose()
等效输出>>> a.transpose(0, 3, 2, 1).shape (2, 5, 4, 3)