jax.numpy.swapaxes#

jax.numpy.swapaxes(a, axis1, axis2)[源代码]#

交换数组的两个轴。

JAX 对 numpy.swapaxes() 的实现,它基于 jax.lax.transpose() 实现。

参数:
  • a (ArrayLike) – 输入数组

  • axis1 (int) – 第一个轴的索引

  • axis2 (int) – 第二个轴的索引

返回:

输入数组 a 的副本,其中指定的轴已交换。

返回类型:

Array

注意事项

numpy.swapaxes() 不同,jax.numpy.swapaxes() 返回的是输入数组的副本而不是视图。然而,在 JIT 编译下,编译器会在可能的情况下优化掉这种复制,因此在实践中不会影响性能。

另请参阅

示例

>>> a = jnp.ones((2, 3, 4, 5))
>>> jnp.swapaxes(a, 1, 3).shape
(2, 5, 4, 3)

通过 `swapaxes` 数组方法实现的等效输出

>>> a.swapaxes(1, 3).shape
(2, 5, 4, 3)

通过 transpose() 实现的等效输出

>>> a.transpose(0, 3, 2, 1).shape
(2, 5, 4, 3)