jax.numpy.delete

jax.numpy.delete(arr, obj, axis=None, *, assume_unique_indices=False)

Return a new array with sub-arrays along an axis deleted.

LAX-backend implementation of numpy.delete().

delete() usually requires the index specification obj to be static, meaning its values must be known at trace time rather than being a traced value inside a transformation such as jax.jit(). If obj is an integer array that is guaranteed to contain unique entries, you may specify assume_unique_indices=True to perform the deletion in a manner that does not require static indices.

Original docstring below.

For a one-dimensional array, this returns those entries not returned by arr[obj].
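A minimal sketch of the one-dimensional case, assuming a recent JAX release. The index array is passed as a jnp.array rather than a Python list, since JAX array functions generally reject list inputs. The variable names are illustrative only.

```python
import jax.numpy as jnp

x = jnp.array([10, 20, 30, 40, 50])

# Static integer: drop x[1].
by_int = jnp.delete(x, 1)                    # [10, 30, 40, 50]

# Static slice: drop x[1:4].
by_slice = jnp.delete(x, slice(1, 4))        # [10, 50]

# Integer index array: drop x[0] and x[2].
by_array = jnp.delete(x, jnp.array([0, 2]))  # [20, 40, 50]

# Negative indices count from the end, as in NumPy: drop the last entry.
by_negative = jnp.delete(x, -1)              # [10, 20, 30, 40]
```

Each call returns a new array; x itself is left unchanged, since JAX arrays are immutable.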

Parameters:
  • arr (array_like) – Input array.

  • obj (slice, int or array of ints) –

    Indicates the indices of the sub-arrays to remove along the specified axis.

    Changed in NumPy version 1.19.0: Boolean indices are now treated as a mask of elements to remove, rather than being cast to the integers 0 and 1.

  • axis (int, optional) – The axis along which to delete the subarray defined by obj. If axis is None, obj is applied to the flattened array.

  • assume_unique_indices (bool, optional (default=False)) – For array-like integer (not boolean) indices, assume the indices are unique and perform the deletion in a way that is compatible with JIT and other JAX transformations.
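The sketch below, assuming a recent JAX release, illustrates two forms of obj: a boolean mask, which must be concrete (computed outside of jit), and a traced integer index array inside jax.jit, which relies on assume_unique_indices=True. The function name drop_unique is hypothetical.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(10) * 10  # [0, 10, 20, ..., 90]

# Boolean obj: True marks an element to delete. The mask must have the same
# length as the axis being deleted from and must be concrete (not traced).
mask = x % 20 == 0
without_multiples_of_20 = jnp.delete(x, mask)  # [10, 30, 50, 70, 90]


@jax.jit
def drop_unique(arr, idx):
    # Inside jit, idx is a tracer whose values are unknown at trace time.
    # assume_unique_indices=True selects a code path that does not need
    # concrete index values; the caller must guarantee idx has no duplicates.
    return jnp.delete(arr, idx, assume_unique_indices=True)


# Indices need not be sorted: this removes x[2], x[5], and x[7].
dropped = drop_unique(x, jnp.array([7, 2, 5]))  # [0, 10, 30, 40, 60, 80, 90]
```

Calling jnp.delete on a traced index array without assume_unique_indices=True raises an error under jit, because the indices cannot be made concrete. If the uniqueness guarantee is violated, the result will not match numpy.delete.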

Returns:

out – A copy of arr with the elements specified by obj removed. Note that delete does not occur in-place. If axis is None, out is a flattened array.
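A short sketch of how axis changes the result, assuming a recent JAX release: with an explicit axis the remaining dimensions are preserved, while axis=None (the default) flattens the input before deleting.

```python
import jax.numpy as jnp

a = jnp.arange(12).reshape(3, 4)
# a = [[ 0,  1,  2,  3],
#      [ 4,  5,  6,  7],
#      [ 8,  9, 10, 11]]

# Remove row 1 along axis 0; result has shape (2, 4).
rows = jnp.delete(a, 1, axis=0)

# Remove columns 0 and 3 along axis 1; result has shape (3, 2).
cols = jnp.delete(a, jnp.array([0, 3]), axis=1)

# axis=None: `a` is flattened first, then flat index 1 is removed;
# result has shape (11,).
flat = jnp.delete(a, 1)
```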

Return type:

ndarray