jax.numpy.repeat#
- jax.numpy.repeat(a, repeats, axis=None, *, total_repeat_length=None, out_sharding=None)[source]#
从重复元素构造数组。
numpy.repeat()
的 JAX 实现。- 参数:
a (ArrayLike) – N 维数组
repeats (ArrayLike) – 1D 整数数组,指定重复次数。 必须与重复轴的长度匹配。
axis (int | None) – 整数,指定沿着其构建重复数组的
a
的轴。 如果为 None(默认),则首先展平a
。total_repeat_length (int | None) – 对于
jnp.repeat
与jit()
和其他 JAX 转换兼容,必须静态指定此值。 如果sum(repeats)
大于指定的total_repeat_length
,则将丢弃剩余的值。 如果sum(repeats)
小于total_repeat_length
,则将重复最终值。out_sharding (NamedSharding | P | None)
- 返回:
由
a
的重复值构建的数组。- 返回类型:
另请参阅
jax.numpy.tile()
:重复整个数组而不是单个值。
示例
沿最后一个轴将每个值重复两次
>>> a = jnp.array([[1, 2], ... [3, 4]]) >>> jnp.repeat(a, 2, axis=-1) Array([[1, 1, 2, 2], [3, 3, 4, 4]], dtype=int32)
如果未指定
axis
,则将展平输入数组>>> jnp.repeat(a, 2) Array([1, 1, 2, 2, 3, 3, 4, 4], dtype=int32)
将数组传递给
repeats
以将每个值重复不同的次数>>> repeats = jnp.array([2, 3]) >>> jnp.repeat(a, repeats, axis=1) Array([[1, 1, 2, 2, 2], [3, 3, 4, 4, 4]], dtype=int32)
为了在
jit
和其他 JAX 转换中使用repeat
,必须使用total_repeat_length
静态指定输出的大小>>> jit_repeat = jax.jit(jnp.repeat, static_argnames=['axis', 'total_repeat_length']) >>> jit_repeat(a, repeats, axis=1, total_repeat_length=5) Array([[1, 1, 2, 2, 2], [3, 3, 4, 4, 4]], dtype=int32)
如果 total_repeat_length 小于
sum(repeats)
,则结果将被截断>>> jit_repeat(a, repeats, axis=1, total_repeat_length=4) Array([[1, 1, 2, 2], [3, 3, 4, 4]], dtype=int32)
如果它更大,则将使用最终值填充其他条目
>>> jit_repeat(a, repeats, axis=1, total_repeat_length=7) Array([[1, 1, 2, 2, 2, 2, 2], [3, 3, 4, 4, 4, 4, 4]], dtype=int32)