jax.numpy.transpose#
- jax.numpy.transpose(a, axes=None)[source]#
返回N维数组的转置版本。
JAX 对
numpy.transpose()
的实现,通过jax.lax.transpose()
实现。- 参数:
a (ArrayLike) – 输入数组
axes (Sequence[int] | None) – 可选择使用长度为 a.ndim 的整数序列
i
来指定置换,其中0 <= i < a.ndim
。默认为range(a.ndim)[::-1]
,即反转所有轴的顺序。
- 返回:
数组的转置副本。
- 返回类型:
另请参阅
jax.Array.transpose()
:通过Array
方法实现的等效函数。jax.Array.T
:通过Array
属性实现的等效函数。jax.numpy.matrix_transpose()
:转置数组的最后两个轴。这适用于处理批处理的二维矩阵。jax.numpy.swapaxes()
:交换数组中的任意两个轴。jax.numpy.moveaxis()
:将数组中的一个轴移动到另一个位置。
注意
与
numpy.transpose()
不同,jax.numpy.transpose()
将返回输入数组的副本而非视图。然而,在JIT模式下,编译器在可能的情况下会优化掉这些副本,因此这在实际中不会影响性能。示例
对于一维数组,转置是恒等变换
>>> x = jnp.array([1, 2, 3, 4]) >>> jnp.transpose(x) Array([1, 2, 3, 4], dtype=int32)
对于二维数组,转置是矩阵转置
>>> x = jnp.array([[1, 2], ... [3, 4]]) >>> jnp.transpose(x) Array([[1, 3], [2, 4]], dtype=int32)
对于N维数组,转置反转轴的顺序
>>> x = jnp.zeros(shape=(3, 4, 5)) >>> jnp.transpose(x).shape (5, 4, 3)
可以指定
axes
参数来改变此默认行为>>> jnp.transpose(x, (0, 2, 1)).shape (3, 5, 4)
由于交换最后两个轴是常见操作,因此可以通过其自己的API来完成,即
jax.numpy.matrix_transpose()
>>> jnp.matrix_transpose(x).shape (3, 5, 4)
为了方便起见,转置也可以使用
jax.Array.transpose()
方法或jax.Array.T
属性来执行>>> x = jnp.array([[1, 2], ... [3, 4]]) >>> x.transpose() Array([[1, 3], [2, 4]], dtype=int32) >>> x.T Array([[1, 3], [2, 4]], dtype=int32)