jax.numpy.rollaxis#
- jax.numpy.rollaxis(a, axis, start=0)[源代码]#
将指定的轴滚动到给定位置。
JAX 实现
numpy.rollaxis()。此函数用于与 NumPy 兼容,但在大多数情况下,建议使用较新的
jax.numpy.moveaxis(),因为其参数的含义更直观。- 参数:
- 返回:
滚动轴后的
a的副本。- 返回类型:
注意事项
与
numpy.rollaxis()不同,jax.numpy.rollaxis()返回的是输入数组的副本而不是视图。但是,在 JIT 下,编译器会尽可能优化掉这些副本,因此这在实践中不会影响性能。另请参阅
jax.numpy.moveaxis():较新的 API,语义比rollaxis更清晰;在大多数情况下应优先使用此函数而非rollaxis。jax.numpy.swapaxes():交换两个轴。jax.numpy.transpose():对轴进行通用置换。
示例
>>> a = jnp.ones((2, 3, 4, 5))
将轴 2 滚动到数组的开头
>>> jnp.rollaxis(a, 2).shape (4, 2, 3, 5)
将轴 1 滚动到数组的末尾
>>> jnp.rollaxis(a, 1, a.ndim).shape (2, 4, 5, 3)
使用
moveaxis()实现的等效操作>>> jnp.moveaxis(a, 2, 0).shape (4, 2, 3, 5) >>> jnp.moveaxis(a, 1, -1).shape (2, 4, 5, 3)