jax.ops.segment_prod#
- jax.ops.segment_prod(data, segment_ids, num_segments=None, indices_are_sorted=False, unique_indices=False, bucket_size=None, mode=None)[源代码]#
计算数组分段内的乘积。
类似于 TensorFlow 的 segment_prod
- 参数:
data (ArrayLike) – 要缩减的值的数组。
segment_ids (ArrayLike) – 整数类型数组,指示要缩减的 data 的段(沿其前导轴)。值可以重复,无需排序。
num_segments (int | None | None) – 可选,一个非负整数值,指示段的数量。默认设置为支持
segment_ids
中所有索引的最小段数,计算为max(segment_ids) + 1
。由于 num_segments 确定输出的大小,因此必须提供静态值才能在 JIT 编译的函数中使用segment_prod
。indices_are_sorted (bool) –
segment_ids
是否已知已排序。unique_indices (bool) – segment_ids 是否已知没有重复项。
bucket_size (int | None | None) – 将索引分组到桶中的桶大小。
segment_prod
在每个桶上单独执行,以提高数值稳定性。默认值None
表示不进行分桶。mode (lax.GatherScatterMode | None | None) –
jax.lax.GatherScatterMode
值,描述应如何处理越界索引。默认情况下,范围 [0, num_segments) 之外的值将被丢弃,并且不计入结果。
- 返回:
形状为
(num_segments,) + data.shape[1:]
的数组,表示分段乘积。- 返回类型:
示例
简单的一维分段乘积
>>> data = jnp.arange(6) >>> segment_ids = jnp.array([0, 0, 1, 1, 2, 2]) >>> segment_prod(data, segment_ids) Array([ 0, 6, 20], dtype=int32)
使用 JIT 需要静态 num_segments
>>> from jax import jit >>> jit(segment_prod, static_argnums=2)(data, segment_ids, 3) Array([ 0, 6, 20], dtype=int32)