jax.ops.segment_max#
- jax.ops.segment_max(data, segment_ids, num_segments=None, indices_are_sorted=False, unique_indices=False, bucket_size=None, mode=None)[源代码]#
计算数组分段的最大值。
类似于 TensorFlow 的 segment_max
- 参数:
data (ArrayLike) – 要缩减值的数组。
segment_ids (ArrayLike) – 一个整数 dtype 的数组,指示要缩减的 data 的段(沿其前导轴)。值可以重复,不需要排序。
num_segments (int | None) – 可选参数,一个非负整数,表示分段的数量。默认值为支持
segment_ids中所有索引的最小分段数,计算方式为max(segment_ids) + 1。由于 num_segments 决定了输出的大小,因此在使用 JIT 编译函数中的segment_max时必须提供一个静态值。indices_are_sorted (bool) –
segment_ids是否已知已排序。unique_indices (bool) – segment_ids 是否已知不包含重复项。
bucket_size (int | None) – 用于将索引分组的桶的大小。
segment_max在每个桶中单独执行。默认值None表示不分桶。mode (slicing.GatherScatterMode | str | None) – 一个
jax.lax.GatherScatterMode值,描述了如何处理越界索引。默认情况下,范围 [0, num_segments) 之外的值将被丢弃,并且不会对结果做出贡献。
- 返回:
一个形状为
(num_segments,) + data.shape[1:]的数组,表示分段的最大值。- 返回类型:
示例
简单的 1D 分段最大值
>>> data = jnp.arange(6) >>> segment_ids = jnp.array([0, 0, 1, 1, 2, 2]) >>> segment_max(data, segment_ids) Array([1, 3, 5], dtype=int32)
使用 JIT 需要静态 num_segments
>>> from jax import jit >>> jit(segment_max, static_argnums=2)(data, segment_ids, 3) Array([1, 3, 5], dtype=int32)