jax.numpy.repeat
- jax.numpy.repeat(a, repeats, axis=None, *, total_repeat_length=None)
Construct an array from repeated elements.
JAX implementation of numpy.repeat().
- Parameters:
a (ArrayLike) – N-dimensional array
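repeats (ArrayLike) – integer or 1D integer array specifying the number of repeats. A scalar repeats every element the same number of times; a 1D array gives a separate count for each element and must match the length of the repeated axis.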
axis (int | None) – integer specifying the axis of a along which to construct the repeated array. If None (default), then a is first flattened.
total_repeat_length (int | None) – this must be specified statically for jnp.repeat to be compatible with jit() and other JAX transformations. If sum(repeats) is larger than the specified total_repeat_length, the remaining values will be discarded. If sum(repeats) is smaller than total_repeat_length, the final value will be repeated.
- Returns:
an array constructed from repeated values of a.
- Return type:
Array
See also
jax.numpy.tile(): repeat a full array rather than individual values.
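The difference from jax.numpy.tile() is easiest to see side by side. A minimal sketch on an illustrative 1D array x (the values are arbitrary):

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3])

# repeat duplicates each element in place.
repeated = jnp.repeat(x, 2)
# tile duplicates the whole array end to end.
tiled = jnp.tile(x, 2)

print(repeated.tolist())  # [1, 1, 2, 2, 3, 3]
print(tiled.tolist())     # [1, 2, 3, 1, 2, 3]
```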
Examples
Repeat each value twice along the last axis:
>>> a = jnp.array([[1, 2],
...                [3, 4]])
>>> jnp.repeat(a, 2, axis=-1)
Array([[1, 1, 2, 2],
       [3, 3, 4, 4]], dtype=int32)
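Repeating along the first axis instead duplicates whole rows. A short sketch reusing the same 2x2 input (rows is just an illustrative name):

```python
import jax.numpy as jnp

a = jnp.array([[1, 2],
               [3, 4]])

# axis=0 repeats each row, so the shape goes from (2, 2) to (4, 2).
rows = jnp.repeat(a, 2, axis=0)

print(rows.tolist())  # [[1, 2], [1, 2], [3, 4], [3, 4]]
```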
If axis is not specified, the input array will be flattened:
>>> jnp.repeat(a, 2)
Array([1, 1, 2, 2, 3, 3, 4, 4], dtype=int32)
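The flattening here is assumed to be the default row-major order used by ravel(), so the call should be equivalent to flattening by hand and then repeating along axis 0. A quick check of that equivalence:

```python
import jax.numpy as jnp

a = jnp.array([[1, 2],
               [3, 4]])

# axis=None (the default) flattens a before repeating.
implicit = jnp.repeat(a, 2)
# The same thing done by hand: flatten in row-major order, then repeat along axis 0.
explicit = jnp.repeat(a.ravel(), 2, axis=0)

print(bool(jnp.array_equal(implicit, explicit)))  # True
```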
Pass an array to repeats to repeat each value a different number of times:
>>> repeats = jnp.array([2, 3])
>>> jnp.repeat(a, repeats, axis=1)
Array([[1, 1, 2, 2, 2],
       [3, 3, 4, 4, 4]], dtype=int32)
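Because this function implements numpy.repeat(), per-element counts should give the same result as NumPy on concrete inputs. The sketch below checks that on the example data and also applies the same counts along axis 0, so the number of rows grows from 2 to 5:

```python
import numpy as np
import jax.numpy as jnp

a = jnp.array([[1, 2],
               [3, 4]])
repeats = jnp.array([2, 3])

# Along axis 1: column 0 is repeated twice, column 1 three times.
jax_cols = jnp.repeat(a, repeats, axis=1)
np_cols = np.repeat(np.asarray(a), np.asarray(repeats), axis=1)
print(np.array_equal(np.asarray(jax_cols), np_cols))  # True

# Along axis 0: row 0 is repeated twice, row 1 three times.
jax_rows = jnp.repeat(a, repeats, axis=0)
print(jax_rows.tolist())  # [[1, 2], [1, 2], [3, 4], [3, 4], [3, 4]]
```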
In order to use repeat within jit and other JAX transformations, the size of the output must be specified statically using total_repeat_length:
>>> jit_repeat = jax.jit(jnp.repeat, static_argnames=['axis', 'total_repeat_length'])
>>> jit_repeat(a, repeats, axis=1, total_repeat_length=5)
Array([[1, 1, 2, 2, 2],
       [3, 3, 4, 4, 4]], dtype=int32)
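The static size matters most when repeats is itself a traced argument of a jitted function, because its sum is then unknown at trace time. A minimal sketch, assuming the caller knows the output length ahead of time; the function name repeat_rows and the length of 4 are illustrative choices, not part of the JAX API:

```python
import jax
import jax.numpy as jnp

@jax.jit
def repeat_rows(x, counts):
    # counts is traced inside jit, so its sum cannot be read at trace time;
    # the output length along axis 0 is fixed statically instead.
    return jnp.repeat(x, counts, axis=0, total_repeat_length=4)

x = jnp.array([[1, 2],
               [3, 4]])

out = repeat_rows(x, jnp.array([1, 3]))  # counts sum to exactly 4
print(out.tolist())  # [[1, 2], [3, 4], [3, 4], [3, 4]]

# Different counts with the same input shapes reuse the compiled function.
out2 = repeat_rows(x, jnp.array([2, 2]))
print(out2.tolist())  # [[1, 2], [1, 2], [3, 4], [3, 4]]
```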
If total_repeat_length is smaller than sum(repeats), the result will be truncated:
>>> jit_repeat(a, repeats, axis=1, total_repeat_length=4)
Array([[1, 1, 2, 2],
       [3, 3, 4, 4]], dtype=int32)
If it is larger, then the additional entries will be filled with the final value:
>>> jit_repeat(a, repeats, axis=1, total_repeat_length=7)
Array([[1, 1, 2, 2, 2, 2, 2],
       [3, 3, 4, 4, 4, 4, 4]], dtype=int32)
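One way to understand the truncate-or-fill behaviour is as a gather with a fixed-size index array. The sketch below is an illustration under that assumption, for a 1D input only; the helper repeat_static is hypothetical and is not necessarily how JAX implements jnp.repeat internally. For each of the total_repeat_length output slots it computes the index of the input element that fills it:

```python
import jax
import jax.numpy as jnp

def repeat_static(x, repeats, total_repeat_length):
    """Hypothetical 1D helper: repeat x[i] repeats[i] times into a fixed-length output."""
    # Exclusive end of each element's run in the output, e.g. [2, 3] -> [2, 5].
    ends = jnp.cumsum(repeats)
    # Output slot j belongs to the first element whose run ends after j.
    slots = jnp.arange(total_repeat_length)
    idx = jnp.searchsorted(ends, slots, side="right")
    # Slots at or past sum(repeats) map one past the end; clamping them to the
    # last element reproduces the "repeat the final value" behaviour.
    idx = jnp.minimum(idx, x.shape[0] - 1)
    # Slots beyond total_repeat_length are never generated, which gives truncation.
    return x[idx]

x = jnp.array([1, 2])
counts = jnp.array([2, 3])

# total_repeat_length only fixes shapes, so it is marked static for jit.
repeat_static_jit = jax.jit(repeat_static, static_argnames="total_repeat_length")

print(repeat_static_jit(x, counts, total_repeat_length=4).tolist())  # [1, 1, 2, 2]
print(repeat_static_jit(x, counts, total_repeat_length=7).tolist())  # [1, 1, 2, 2, 2, 2, 2]
```

Because repeats enters only through cumsum and searchsorted, it can remain a traced value; only the output length has to be known at trace time, which mirrors the requirement stated for total_repeat_length.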