jax.numpy.repeat

jax.numpy.repeat(a, repeats, axis=None, *, total_repeat_length=None)

Construct an array from repeated elements.

JAX implementation of numpy.repeat().

Parameters:
  • a (ArrayLike) – N-dimensional array

  • repeats (ArrayLike) – integer or 1D integer array specifying the number of repeats for each element. If an array, its length must match the length of the repeated axis.

  • axis (int | None) – integer specifying the axis of a along which to construct the repeated array. If None (default), a is flattened first.

  • total_repeat_length (int | None) – the length of the output along the repeated axis. This must be specified statically for jnp.repeat to be compatible with jit() and other JAX transformations. If sum(repeats) is larger than the specified total_repeat_length, the remaining values are discarded. If sum(repeats) is smaller than total_repeat_length, the final value is repeated to fill the output.

Returns:

an array constructed from repeated values of a.

Return type:

Array

Examples

Repeat each value twice along the last axis:

>>> import jax
>>> import jax.numpy as jnp
>>> a = jnp.array([[1, 2],
...                [3, 4]])
>>> jnp.repeat(a, 2, axis=-1)
Array([[1, 1, 2, 2],
       [3, 3, 4, 4]], dtype=int32)

If axis is not specified, the input array will be flattened:

>>> jnp.repeat(a, 2)
Array([1, 1, 2, 2, 3, 3, 4, 4], dtype=int32)

Pass an array to repeats to repeat each value a different number of times:

>>> repeats = jnp.array([2, 3])
>>> jnp.repeat(a, repeats, axis=1)
Array([[1, 1, 2, 2, 2],
       [3, 3, 4, 4, 4]], dtype=int32)
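
As a sanity check, the per-column counts above can be compared against NumPy. The following is a minimal sketch, assuming NumPy is installed alongside JAX; the names expected and result are illustrative only:

```python
import numpy as np
import jax.numpy as jnp

# Same input and per-column repeat counts as in the example above.
a = np.array([[1, 2],
              [3, 4]])
repeats = np.array([2, 3])

# Reference result from NumPy.
expected = np.repeat(a, repeats, axis=1)

# JAX result, computed eagerly (outside of jit) so repeats is concrete.
result = jnp.repeat(jnp.asarray(a), jnp.asarray(repeats), axis=1)

# Compare values only; the integer dtypes may differ between the libraries.
print(np.array_equal(np.asarray(result), expected))
```

Outside of jit the output shape can be computed from the concrete values of repeats, which is why no total_repeat_length is needed here.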

To use repeat with an array-valued repeats inside jit and other JAX transformations, the size of the output must be specified statically using total_repeat_length:

>>> jit_repeat = jax.jit(jnp.repeat, static_argnames=['axis', 'total_repeat_length'])
>>> jit_repeat(a, repeats, axis=1, total_repeat_length=5)
Array([[1, 1, 2, 2, 2],
       [3, 3, 4, 4, 4]], dtype=int32)
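
An alternative to listing total_repeat_length in static_argnames is to fix it as a Python constant inside the jitted function. A minimal sketch, where expand, values, counts, and the length 6 are hypothetical choices made for illustration:

```python
import jax
import jax.numpy as jnp

@jax.jit
def expand(values, counts):
    # counts is traced under jit, so the output length cannot be inferred
    # from it. The Python int 6 keeps the output shape static.
    return jnp.repeat(values, counts, total_repeat_length=6)

values = jnp.array([10, 20, 30])
counts = jnp.array([1, 2, 3])   # sums to 6, matching total_repeat_length

out = expand(values, counts)
print(out)
```

This keeps the call site free of static arguments, at the cost of hard-coding the output length for every call to expand.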

If total_repeat_length is smaller than sum(repeats), the result will be truncated:

>>> jit_repeat(a, repeats, axis=1, total_repeat_length=4)
Array([[1, 1, 2, 2],
       [3, 3, 4, 4]], dtype=int32)

If total_repeat_length is larger than sum(repeats), the additional entries are filled with the final value:

>>> jit_repeat(a, repeats, axis=1, total_repeat_length=7)
Array([[1, 1, 2, 2, 2, 2, 2],
       [3, 3, 4, 4, 4, 4, 4]], dtype=int32)
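
Because the padded entries look the same as genuine repeats of the final value, downstream code may want a validity mask. One possible sketch, where repeat_with_mask, total, and valid are hypothetical names and not part of the jnp.repeat API:

```python
import jax
import jax.numpy as jnp

def repeat_with_mask(values, counts, total):
    # Repeat with a static output length; surplus slots hold the final value.
    out = jnp.repeat(values, counts, total_repeat_length=total)
    # Positions at or beyond sum(counts) are padding rather than real repeats.
    valid = jnp.arange(total) < jnp.sum(counts)
    return out, valid

# total determines the output shape, so it is marked static.
jitted = jax.jit(repeat_with_mask, static_argnames=['total'])

out, valid = jitted(jnp.array([1, 2]), jnp.array([2, 3]), total=7)
print(out)
print(valid)
```

The mask can then be passed to jnp.where or used as a weight so that padded entries do not contribute to later reductions.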