jax.numpy.roll#
- jax.numpy.roll(a, shift, axis=None)[source]#
沿指定轴滚动数组的元素。
numpy.roll()
的 JAX 实现。- 参数::
- 返回::
沿指定轴或轴滚动的
a
的副本。- 返回类型::
另请参阅
jax.numpy.rollaxis()
:将指定的轴滚动到给定位置。
示例
>>> a = jnp.array([0, 1, 2, 3, 4, 5]) >>> jnp.roll(a, 2) Array([4, 5, 0, 1, 2, 3], dtype=int32)
沿特定轴滚动元素
>>> a = jnp.array([[ 0, 1, 2, 3], ... [ 4, 5, 6, 7], ... [ 8, 9, 10, 11]]) >>> jnp.roll(a, 1, axis=0) Array([[ 8, 9, 10, 11], [ 0, 1, 2, 3], [ 4, 5, 6, 7]], dtype=int32) >>> jnp.roll(a, [2, 3], axis=[0, 1]) Array([[ 5, 6, 7, 4], [ 9, 10, 11, 8], [ 1, 2, 3, 0]], dtype=int32)