jax.ops.segment_prod
- jax.ops.segment_prod(data, segment_ids, num_segments=None, indices_are_sorted=False, unique_indices=False, bucket_size=None, mode=None)
Computes the product within segments of an array.
Similar to TensorFlow’s segment_prod.
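The behavior described above can be written out as a plain NumPy loop. This sketch is illustrative only and is not how JAX implements segment_prod; the helper name segment_prod_reference is hypothetical. It assumes the default mode, in which ids outside [0, num_segments) are dropped, and checks the loop against jax.ops.segment_prod on unsorted, repeated ids.

```python
import numpy as np
import jax.numpy as jnp
from jax.ops import segment_prod


def segment_prod_reference(data, segment_ids, num_segments):
    """Hypothetical NumPy loop mirroring segment_prod's default behavior."""
    data = np.asarray(data)
    # Every segment starts at the multiplicative identity, 1.
    out = np.ones((num_segments,) + data.shape[1:], dtype=data.dtype)
    for row, seg in zip(data, np.asarray(segment_ids)):
        # Ids outside [0, num_segments) are skipped, as in the default mode.
        if 0 <= seg < num_segments:
            out[seg] *= row
    return out


data = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
segment_ids = jnp.array([2, 0, 2, 0, 1])  # unsorted and repeated ids are allowed

expected = segment_prod_reference(data, segment_ids, 3)
result = segment_prod(data, segment_ids, num_segments=3)
np.testing.assert_allclose(np.asarray(result), expected)
print(result)  # [8. 5. 3.]
```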
- Parameters:
  - data (Any) – an array with the values to be reduced.
  - segment_ids (Any) – an array with integer dtype that indicates the segments of data (along its leading axis) to be reduced. Values can be repeated and need not be sorted. Values outside of the range [0, num_segments) are dropped and do not contribute to the result.
  - num_segments (Optional[int]) – optional, a non-negative int giving the number of segments. The default is the minimum number of segments that covers every index in segment_ids, computed as max(segment_ids) + 1. Since num_segments determines the size of the output, a static value must be provided to use segment_prod in a JIT-compiled function.
  - indices_are_sorted (bool) – whether segment_ids is known to be sorted.
  - unique_indices (bool) – whether segment_ids is known to be free of duplicates.
  - bucket_size (Optional[int]) – size of bucket to group indices into. segment_prod is performed on each bucket separately to improve numerical stability, and the per-bucket results are then multiplied together. Default None means no bucketing.
  - mode (Optional[GatherScatterMode]) – a jax.lax.GatherScatterMode value describing how out-of-bounds indices should be handled. By default, values outside of the range [0, num_segments) are dropped and do not contribute to the product.
- Return type:
  Array
- Returns:
  An array with shape (num_segments,) + data.shape[1:] representing the segment products.
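A short sketch of how these parameters shape the result, assuming 2D integer data: rows are grouped along the leading axis, the output has shape (num_segments,) + data.shape[1:], an id outside [0, num_segments) is dropped under the default mode, and a segment that receives no rows keeps the multiplicative identity 1. The CLIP call assumes jax.lax.GatherScatterMode.CLIP clamps the out-of-range id to the last valid segment, as it does for other JAX scatter operations.

```python
import jax
import jax.numpy as jnp
from jax.ops import segment_prod

# Rows of 2D data are grouped by segment_ids along the leading axis.
data = jnp.array([[1, 2],
                  [3, 4],
                  [5, 6],
                  [7, 8]])
segment_ids = jnp.array([0, 0, 3, 1])  # id 3 is out of range for num_segments=3

# Default mode: the row tagged 3 is dropped, and segment 2 receives no rows,
# so it keeps the multiplicative identity 1.
out = segment_prod(data, segment_ids, num_segments=3)
print(out.shape)     # (3, 2), i.e. (num_segments,) + data.shape[1:]
print(out.tolist())  # [[3, 8], [7, 8], [1, 1]]

# CLIP mode: the out-of-range id 3 is clamped to 2, so row [5, 6] lands in segment 2.
clipped = segment_prod(data, segment_ids, num_segments=3,
                       mode=jax.lax.GatherScatterMode.CLIP)
print(clipped.tolist())  # [[3, 8], [7, 8], [5, 6]]
```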
Examples
Simple 1D segment product:
>>> import jax.numpy as jnp
>>> from jax.ops import segment_prod
>>> data = jnp.arange(6)
>>> segment_ids = jnp.array([0, 0, 1, 1, 2, 2])
>>> segment_prod(data, segment_ids)
Array([ 0,  6, 20], dtype=int32)
Using JIT requires static num_segments:
>>> from jax import jit
>>> jit(segment_prod, static_argnums=2)(data, segment_ids, 3)
Array([ 0,  6, 20], dtype=int32)
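Because num_segments must be static under JIT, a jitted function that calls segment_prod can mark it with static_argnames. The helper below, segment_geometric_mean, is hypothetical and not part of JAX; it is a sketch assuming positive input values. Each distinct num_segments value compiles a separate version, while data and segment_ids remain traced.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.ops import segment_prod


@partial(jax.jit, static_argnames="num_segments")
def segment_geometric_mean(data, segment_ids, num_segments):
    """Hypothetical helper: per-segment geometric mean of positive values."""
    # num_segments is a concrete Python int here, so the output shape is known.
    prods = segment_prod(data, segment_ids, num_segments=num_segments)
    counts = jnp.bincount(segment_ids, length=num_segments)
    # Empty segments have product 1 and count 0; clamp the count to avoid 1/0.
    return prods ** (1.0 / jnp.maximum(counts, 1))


data = jnp.array([2.0, 8.0, 3.0, 3.0, 5.0])
segment_ids = jnp.array([0, 0, 1, 1, 2])
print(segment_geometric_mean(data, segment_ids, num_segments=3))  # approximately [4. 3. 5.]
```

The product form is used here only to exercise segment_prod. For long segments or large values, exponentiating the mean of logarithms computed with jax.ops.segment_sum avoids overflow.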