jax.numpy.delete

jax.numpy.delete(arr, obj, axis=None, *, assume_unique_indices=False)

Delete entry or entries from an array.

JAX implementation of numpy.delete().

Parameters:
  • arr (ArrayLike) – array from which entries will be deleted.

  • obj (ArrayLike | slice) – index, indices, or slice to be deleted.

  • axis (int | None) – axis along which entries will be deleted. If None (default), arr is flattened before the entries are deleted.

  • assume_unique_indices (bool) – When obj is an array of integer (not boolean) indices, assume the indices are unique and perform the deletion in a way that is compatible with JIT and other JAX transformations.

Returns:

Copy of arr with specified indices deleted.

Return type:

Array

Note

delete() usually requires the index specification to be static. If the index is an integer array that is guaranteed to contain unique entries, you may specify assume_unique_indices=True to perform the operation in a manner that does not require static indices.
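One way to satisfy the static-index requirement under jit, sketched here on the assumption that the indices are known before tracing, is to hold them in a NumPy array that the jitted function captures by closure. The names static_indices and drop_static are illustrative and not part of the JAX API.

```python
import jax
import jax.numpy as jnp
import numpy as np

# A NumPy array is a concrete value at trace time, so jnp.delete can
# inspect it even while tracing under jit.
static_indices = np.array([0, 1, 3])

@jax.jit
def drop_static(x):
    # static_indices is captured by closure, not passed as a traced argument.
    return jnp.delete(x, static_indices)

a = jnp.array([4, 5, 6, 7, 8, 9])
result_static = drop_static(a)  # entries at positions 0, 1 and 3 are removed
```

Because the indices are baked into the compiled function, a different set of indices needs a different jitted function (or a retrace).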

Examples

Delete entries from a 1D array:

>>> a = jnp.array([4, 5, 6, 7, 8, 9])
>>> jnp.delete(a, 2)
Array([4, 5, 7, 8, 9], dtype=int32)
>>> jnp.delete(a, slice(1, 4))  # delete a[1:4]
Array([4, 8, 9], dtype=int32)
>>> jnp.delete(a, slice(None, None, 2))  # delete a[::2]
Array([5, 7, 9], dtype=int32)

Delete entries from a 2D array along a specified axis:

>>> a2 = jnp.array([[4, 5, 6],
...                 [7, 8, 9]])
>>> jnp.delete(a2, 1, axis=1)
Array([[4, 6],
       [7, 9]], dtype=int32)
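For comparison, a short sketch of how the axis argument changes the result on the same 2D input. With the default axis=None the array is flattened first, following the numpy.delete() convention. The variable names are illustrative.

```python
import jax.numpy as jnp

a2 = jnp.array([[4, 5, 6],
                [7, 8, 9]])

# axis=None (the default): a2 is flattened, then flat position 1 is removed.
flat_result = jnp.delete(a2, 1)

# axis=0: the whole first row is removed, leaving a (1, 3) array.
row_result = jnp.delete(a2, 0, axis=0)
```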

Delete multiple entries via a sequence of indices:

>>> indices = jnp.array([0, 1, 3])
>>> jnp.delete(a, indices)
Array([6, 8, 9], dtype=int32)
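The assume_unique_indices description distinguishes integer indices from boolean ones. A minimal sketch of the boolean form, assuming a one-dimensional mask whose length matches the axis being deleted from, and in which True marks the entries to remove:

```python
import jax.numpy as jnp

a = jnp.array([4, 5, 6, 7, 8, 9])

# True marks an entry for deletion, so this mask removes positions 0, 1 and 3.
mask = jnp.array([True, True, False, True, False, False])
masked_result = jnp.delete(a, mask)
```

As with integer index arrays, the mask must be a concrete value, since the output length depends on how many of its entries are True.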

This will fail under jit() and other transformations: because the indices may contain duplicates, the output shape depends on their values and cannot be determined from the shape of the index array alone:

>>> jax.jit(jnp.delete)(a, indices)  
Traceback (most recent call last):
  ...
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[3].
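The dependence of the output shape on the index values can be checked outside of jit. In this small sketch two index arrays of the same length produce outputs of different lengths, because a repeated index removes only one entry. The variable names are illustrative.

```python
import jax.numpy as jnp

a = jnp.array([4, 5, 6, 7, 8, 9])

unique_idx = jnp.array([0, 1, 3])      # three distinct positions
duplicate_idx = jnp.array([0, 0, 3])   # position 0 is listed twice

out_unique = jnp.delete(a, unique_idx)        # removes 3 entries
out_duplicate = jnp.delete(a, duplicate_idx)  # removes only 2 entries
```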

If you can ensure that the indices are unique, pass assume_unique_indices=True to allow this to be executed under JIT:

>>> jit_delete = jax.jit(jnp.delete, static_argnames=['assume_unique_indices'])
>>> jit_delete(a, indices, assume_unique_indices=True)
Array([6, 8, 9], dtype=int32)
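The same flag should also let the deletion take part in other transformations. A hedged sketch using jax.vmap to delete a different pair of unique indices in each batch element. The helper name delete_pair and the example indices are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp

a = jnp.array([4, 5, 6, 7, 8, 9])

# Each row holds one set of unique indices to delete from `a`.
batched_indices = jnp.array([[0, 1],
                             [4, 5]])

def delete_pair(idx):
    # idx is traced under vmap, so assume_unique_indices=True is needed here.
    return jnp.delete(a, idx, assume_unique_indices=True)

batched_result = jax.vmap(delete_pair)(batched_indices)
```

Every row must delete the same number of entries, since each batch element has to produce an output of the same shape.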