jax.numpy.compress

jax.numpy.compress(condition, a, axis=None, out=None)

Return selected slices of an array along given axis.

LAX-backend implementation of numpy.compress().

Original docstring below.

When working along a given axis, a slice along that axis is returned in output for each index where condition evaluates to True. When working on a 1-D array, compress is equivalent to extract.
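A minimal sketch of this behavior, assuming a working JAX installation. The array `a`, the 1-D array `x`, and the mask values are arbitrary illustrations. The sketch shows one output slice per True entry along the chosen axis, and that on a 1-D input `jnp.compress` selects the same elements as `jnp.extract`.

```python
import jax.numpy as jnp

a = jnp.array([[1, 2],
               [3, 4],
               [5, 6]])

# axis=0: keep rows 1 and 2 (one output row per True entry).
rows = jnp.compress(jnp.array([False, True, True]), a, axis=0)
# rows.tolist() -> [[3, 4], [5, 6]]

# axis=1: keep column 1 only.
cols = jnp.compress(jnp.array([False, True]), a, axis=1)
# cols.tolist() -> [[2], [4], [6]]

# On a 1-D array, compress and extract select the same elements.
x = jnp.array([10, 20, 30, 40])
mask = jnp.array([True, False, True, False])
via_compress = jnp.compress(mask, x)
via_extract = jnp.extract(mask, x)
# both .tolist() -> [10, 30]
```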

Parameters:
  • condition (1-D array of bools) – Array that selects which entries to return. If len(condition) is less than the size of a along the given axis, then output is truncated to the length of the condition array.

  • a (array_like) – Array from which to extract a part.

  • axis (int, optional) – Axis along which to take slices. If None (default), work on the flattened array.

  • out (None) – Not supported by JAX; must be left as None.
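The truncation rule for `condition` and the flattening done when `axis` is None can be seen in a short sketch. It assumes a working JAX installation, and the input values are arbitrary illustrations.

```python
import jax.numpy as jnp

a = jnp.array([[1, 2],
               [3, 4],
               [5, 6]])

# condition has 2 entries but axis 0 has length 3, so only rows 0 and 1
# are considered. Row 2 is dropped because no condition entry covers it.
short = jnp.compress(jnp.array([True, True]), a, axis=0)
# short.tolist() -> [[1, 2], [3, 4]]

# axis=None (the default): a is flattened to [1, 2, 3, 4, 5, 6] first,
# then truncated to the 4 entries covered by the condition.
flat = jnp.compress(jnp.array([False, True, False, True]), a)
# flat.tolist() -> [2, 4]
```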

Returns:

compressed_array – A copy of a without the slices along axis for which condition is false.

Return type:

ndarray