jax.numpy.trim_zeros


jax.numpy.trim_zeros(filt, trim='fb')

Trim leading and/or trailing zeros of the input array.

JAX implementation of numpy.trim_zeros().

Parameters:
  • filt (ArrayLike) – input array. Must have filt.ndim == 1.

  • trim (str) –

    optional, default 'fb'. Specifies which end(s) of the input are trimmed.

    • f - trims only the leading zeros.

    • b - trims only the trailing zeros.

    • fb - trims both leading and trailing zeros.

Returns:

An array containing the trimmed input, with the same dtype as filt.

Return type:

Array

Examples

>>> import jax.numpy as jnp
>>> x = jnp.array([0, 0, 2, 0, 1, 4, 3, 0, 0, 0])
>>> jnp.trim_zeros(x)
Array([2, 0, 1, 4, 3], dtype=int32)
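
The example above uses the default trim='fb'. As a minimal sketch of the other two modes, the snippet below trims the same input from one end only. The names front, back, and empty are illustrative and not part of the API, and the all-zeros case reflects the behavior expected from the NumPy function this mirrors.

```python
import jax.numpy as jnp

x = jnp.array([0, 0, 2, 0, 1, 4, 3, 0, 0, 0])

# trim='f' removes only the leading zeros; trailing zeros are kept.
front = jnp.trim_zeros(x, trim='f')

# trim='b' removes only the trailing zeros; leading zeros are kept.
back = jnp.trim_zeros(x, trim='b')

# An input consisting only of zeros is expected to trim to an empty array.
empty = jnp.trim_zeros(jnp.array([0, 0, 0]))

print(front.tolist())  # [2, 0, 1, 4, 3, 0, 0, 0]
print(back.tolist())   # [0, 0, 2, 0, 1, 4, 3]
print(empty.shape)     # (0,)
```

In every mode the interior zero between 2 and 1 is preserved; only zeros at the selected end or ends are removed.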