jax.ops.segment_sum
- jax.ops.segment_sum(data, segment_ids, num_segments=None, indices_are_sorted=False, unique_indices=False, bucket_size=None, mode=None)
Computes the sum within segments of an array.
Similar to TensorFlow’s segment_sum.
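To make "sum within segments" concrete, the minimal sketch below compares segment_sum against two reference computations: a scatter-add into a zero-initialized accumulator using JAX's .at[...].add indexed-update syntax, and a plain Python loop. It is meant only to illustrate the semantics, not to describe how segment_sum is implemented internally. The input values and variable names are hypothetical.

```python
import jax.numpy as jnp
from jax.ops import segment_sum

# Hypothetical inputs: segment ids are unsorted and repeated.
data = jnp.array([5, 1, 7, 2, 3])
segment_ids = jnp.array([2, 0, 2, 1, 0])
num_segments = 3

# Result from the library function.
via_segment_sum = segment_sum(data, segment_ids, num_segments=num_segments)

# Reference 1: scatter-add each value into a zero-initialized accumulator.
# Repeated indices accumulate under .at[...].add.
via_scatter = jnp.zeros(num_segments, dtype=data.dtype).at[segment_ids].add(data)

# Reference 2: an explicit Python loop over (value, segment) pairs.
via_loop = [0] * num_segments
for value, seg in zip(data.tolist(), segment_ids.tolist()):
    via_loop[seg] += value

print(via_segment_sum.tolist())  # [4, 2, 12]
print(via_scatter.tolist())      # [4, 2, 12]
print(via_loop)                  # [4, 2, 12]
```

Segment 0 collects 1 + 3, segment 1 collects 2, and segment 2 collects 5 + 7, regardless of the order in which the ids appear.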
- Parameters:
  - data (Any) – an array with the values to be summed.
  - segment_ids (Any) – an array with integer dtype that indicates the segments of data (along its leading axis) to be summed. Values can be repeated and need not be sorted.
  - num_segments (Optional[int]) – optional, a nonnegative int giving the number of segments. Defaults to the minimum number of segments that covers every index in segment_ids, calculated as max(segment_ids) + 1. Since num_segments determines the size of the output, a static value must be provided to use segment_sum in a JIT-compiled function.
  - indices_are_sorted (bool) – whether segment_ids is known to be sorted.
  - unique_indices (bool) – whether segment_ids is known to be free of duplicates.
  - bucket_size (Optional[int]) – size of the buckets into which indices are grouped. segment_sum is performed on each bucket separately to improve the numerical stability of the addition. The default, None, means no bucketing.
  - mode (Optional[GatherScatterMode]) – a jax.lax.GatherScatterMode value describing how out-of-bounds indices should be handled. By default, values outside of the range [0, num_segments) are dropped and do not contribute to the sum.
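The parameter behaviors above can be exercised in a short sketch. It assumes the default JAX configuration (32-bit integers) and uses hypothetical input values. It shows the default num_segments, the dropping of out-of-range ids, clamping with jax.lax.GatherScatterMode.CLIP, that bucket_size leaves the result unchanged for exactly representable values, and indices_are_sorted used as a hint on ids that really are sorted.

```python
import jax
import jax.numpy as jnp
from jax.ops import segment_sum

data = jnp.array([10, 20, 30, 40, 50])
segment_ids = jnp.array([0, 2, 2, 5, -1])

# num_segments omitted: defaults to max(segment_ids) + 1 = 6 segments.
# The id -1 lies outside [0, 6), so the value 50 is dropped.
default_out = segment_sum(data, segment_ids)

# Explicit num_segments=3: ids 5 and -1 fall outside [0, 3) and are dropped.
dropped_out = segment_sum(data, segment_ids, num_segments=3)

# mode=CLIP clamps out-of-range ids instead of dropping them: 5 -> 2, -1 -> 0.
clipped_out = segment_sum(data, segment_ids, num_segments=3,
                          mode=jax.lax.GatherScatterMode.CLIP)

# bucket_size only changes how the additions are grouped; with values that
# are exactly representable in float32, the result matches the plain sum.
floats = jnp.arange(8.0)
float_ids = jnp.array([0, 0, 1, 1, 0, 0, 1, 1])
plain = segment_sum(floats, float_ids, num_segments=2)
bucketed = segment_sum(floats, float_ids, num_segments=2, bucket_size=2)

# indices_are_sorted is a promise about segment_ids; pass True only when it holds.
sorted_ids = jnp.array([0, 0, 1, 1, 2])
hinted = segment_sum(jnp.arange(5), sorted_ids, num_segments=3,
                     indices_are_sorted=True)

print(default_out.tolist())  # [10, 0, 50, 0, 0, 40]
print(dropped_out.tolist())  # [10, 0, 50]
print(clipped_out.tolist())  # [60, 0, 90]
print(plain.tolist(), bucketed.tolist())  # [10.0, 18.0] [10.0, 18.0]
print(hinted.tolist())       # [1, 5, 4]
```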
- Return type:
  Array
- Returns:
  An array with shape (num_segments,) + data.shape[1:] representing the segment sums.
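Because segments index only the leading axis of data, trailing dimensions pass through unchanged. The sketch below uses a hypothetical 4×3 input whose rows are assigned to two segments and checks that the output shape is (num_segments,) + data.shape[1:].

```python
import jax.numpy as jnp
from jax.ops import segment_sum

# Hypothetical 2D input: 4 rows of 3 features each.
data = jnp.arange(12).reshape(4, 3)
segment_ids = jnp.array([0, 1, 0, 1])  # one id per row (leading axis)

out = segment_sum(data, segment_ids, num_segments=2)

# Shape is (num_segments,) + data.shape[1:] == (2, 3).
print(out.shape)     # (2, 3)
# Segment 0 sums rows 0 and 2; segment 1 sums rows 1 and 3.
print(out.tolist())  # [[6, 8, 10], [12, 14, 16]]
```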
Examples
Simple 1D segment sum:
>>> import jax.numpy as jnp
>>> from jax.ops import segment_sum
>>> data = jnp.arange(5)
>>> segment_ids = jnp.array([0, 0, 1, 1, 2])
>>> segment_sum(data, segment_ids)
Array([1, 5, 4], dtype=int32)
Using JIT requires static num_segments:
>>> from jax import jit
>>> jit(segment_sum, static_argnums=2)(data, segment_ids, 3)
Array([1, 5, 4], dtype=int32)
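When segment_sum is called inside a larger jitted function, the same requirement applies: num_segments fixes the output shape, so it has to be a static (compile-time) value, and the max(segment_ids) + 1 default cannot be computed from a traced array. The sketch below assumes a hypothetical helper, mean_per_segment (not part of jax.ops), that marks num_segments static with static_argnames and calls segment_sum twice to compute per-segment means.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.ops import segment_sum


# Hypothetical helper: per-segment mean, jitted with num_segments static.
@partial(jax.jit, static_argnames="num_segments")
def mean_per_segment(data, segment_ids, num_segments):
    totals = segment_sum(data, segment_ids, num_segments=num_segments)
    counts = segment_sum(jnp.ones_like(data), segment_ids,
                         num_segments=num_segments)
    # Avoid dividing by zero for segments that receive no values.
    return totals / jnp.maximum(counts, 1)


data = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
segment_ids = jnp.array([0, 0, 1, 1, 2])
means = mean_per_segment(data, segment_ids, num_segments=4)
print(means.tolist())  # [1.5, 3.5, 5.0, 0.0]
```

Each distinct num_segments value triggers a separate compilation, so it works best when the number of segments is fixed across calls.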