jax.ops.segment_min
- jax.ops.segment_min(data, segment_ids, num_segments=None, indices_are_sorted=False, unique_indices=False, bucket_size=None, mode=None)
Computes the minimum within segments of an array.
Similar to TensorFlow’s segment_min.
- Parameters:
data (ArrayLike) – an array with the values to be reduced.
segment_ids (ArrayLike) – an array with integer dtype that indicates the segments of data (along its leading axis) to be reduced. Values can be repeated and need not be sorted. Values outside of the range [0, num_segments) are dropped and do not contribute to the result.
num_segments (int | None) – optional, an int with nonnegative value indicating the number of segments. The default is the minimum number of segments that would support all indices in segment_ids, calculated as max(segment_ids) + 1. Since num_segments determines the size of the output, a static value must be provided to use segment_min in a JIT-compiled function.
indices_are_sorted (bool) – whether segment_ids is known to be sorted.
unique_indices (bool) – whether segment_ids is known to be free of duplicates.
bucket_size (int | None) – size of bucket to group indices into. segment_min is performed on each bucket separately. Default None means no bucketing.
mode (lax.GatherScatterMode | None) – a jax.lax.GatherScatterMode value describing how out-of-bounds indices should be handled. By default, values outside of the range [0, num_segments) are dropped and do not contribute to the minimum.
- Returns:
An array with shape (num_segments,) + data.shape[1:] representing the segment minimums.
- Return type:
Array
Examples
Simple 1D segment min:
>>> import jax.numpy as jnp
>>> from jax.ops import segment_min
>>> data = jnp.arange(6)
>>> segment_ids = jnp.array([0, 0, 1, 1, 2, 2])
>>> segment_min(data, segment_ids)
Array([0, 2, 4], dtype=int32)
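The documented output shape, (num_segments,) + data.shape[1:], is easier to see with 2D input, where the reduction runs along the leading axis only. The following is a minimal sketch; the array values and variable names are illustrative and are not taken from the JAX documentation:

```python
import jax.numpy as jnp
from jax.ops import segment_min

# Four rows of two columns; rows 0-1 belong to segment 0, rows 2-3 to segment 1.
data_2d = jnp.array([[4, 1],
                     [2, 8],
                     [7, 3],
                     [5, 6]])
row_segments = jnp.array([0, 0, 1, 1])

# The minimum is taken per column within each segment.
result_2d = segment_min(data_2d, row_segments, num_segments=2)

print(result_2d.shape)  # (2, 2), i.e. (num_segments,) + data_2d.shape[1:]
print(result_2d)        # rows [2, 1] and [5, 3]
```

Each output row holds the column-wise minimum of the input rows assigned to that segment.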
Using JIT requires static num_segments:
>>> from jax import jit
>>> jit(segment_min, static_argnums=2)(data, segment_ids, 3)
Array([0, 2, 4], dtype=int32)
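The parameter notes state that segment_ids may be repeated and unsorted, and that ids outside [0, num_segments) are dropped. The sketch below illustrates both points, and also shows one possible alternative to static_argnums: binding num_segments with functools.partial so that it is a plain Python int under jit. The input values and variable names are assumptions made for illustration:

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.ops import segment_min

values = jnp.array([5.0, 1.0, 7.0, 3.0, 9.0])
# Unsorted, repeated ids; the id 5 lies outside [0, 3), so 9.0 is dropped.
ids = jnp.array([2, 0, 1, 0, 5])

# Eager call with an explicit number of segments.
eager = segment_min(values, ids, num_segments=3)

# Binding num_segments before jit keeps it static during tracing.
jitted_segment_min = jax.jit(partial(segment_min, num_segments=3))
compiled = jitted_segment_min(values, ids)

print(eager)     # segment minimums 1.0, 7.0, 5.0
print(compiled)  # same values as the eager call
```

Binding the argument with partial avoids having to track the positional index of num_segments, at the cost of creating one jitted function per segment count.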