jax.numpy.roll#

jax.numpy.roll(a, shift, axis=None)[source]#

沿指定轴滚动数组的元素。

numpy.roll() 的 JAX 实现。

参数::
  • a (ArrayLike) – 输入数组。

  • shift (ArrayLike | Sequence[int]) – 沿指定轴移动的位置数。如果是一个整数,则所有轴都移动相同的量。如果是一个元组,则每个轴的移动量分别指定。

  • axis (int | Sequence[int] | None | None) – 要滚动的轴或轴。如果为 None,则数组将被展平、移动,然后重塑为其原始形状。

返回::

沿指定轴或轴滚动的 a 的副本。

返回类型::

Array

另请参阅

示例

>>> a = jnp.array([0, 1, 2, 3, 4, 5])
>>> jnp.roll(a, 2)
Array([4, 5, 0, 1, 2, 3], dtype=int32)

沿特定轴滚动元素

>>> a = jnp.array([[ 0,  1,  2,  3],
...                [ 4,  5,  6,  7],
...                [ 8,  9, 10, 11]])
>>> jnp.roll(a, 1, axis=0)
Array([[ 8,  9, 10, 11],
       [ 0,  1,  2,  3],
       [ 4,  5,  6,  7]], dtype=int32)
>>> jnp.roll(a, [2, 3], axis=[0, 1])
Array([[ 5,  6,  7,  4],
       [ 9, 10, 11,  8],
       [ 1,  2,  3,  0]], dtype=int32)