jax.numpy.insert

jax.numpy.insert(arr, obj, values, axis=None)

Insert values along the given axis before the given indices.
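A minimal sketch of the basic behaviour on a 1-D input, where flattening (the axis=None default) changes nothing. The variable names are illustrative only.

```python
import jax.numpy as jnp

x = jnp.array([10, 20, 30])

# Insert 99 before index 1.
y = jnp.insert(x, 1, 99)
print(y)  # [10 99 20 30]

# Negative indices count from the end: 99 goes before x[-1].
z = jnp.insert(x, -1, 99)
print(z)  # [10 20 99 30]
```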

LAX-backend implementation of numpy.insert().
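Because this is an implementation of numpy.insert(), one way to check it is to compare it directly against NumPy. The sketch below does this for a few hand-picked inputs. It is a spot check on these cases only, not a proof of full equivalence. The index and value sequences are passed as arrays rather than Python lists, because many jax.numpy functions reject list arguments.

```python
import numpy as np
import jax.numpy as jnp

a = np.array([[1, 1], [2, 2], [3, 3]])

# Each case is (obj, values, axis). Together they cover flattening, a scalar
# index along an axis, and a length-1 index sequence along an axis.
cases = [
    (1, 5, None),
    (1, np.array([7, 8, 9]), 1),
    (np.array([1]), np.array([[7], [8], [9]]), 1),
]

matches = []
for obj, values, axis in cases:
    expected = np.insert(a, obj, values, axis=axis)
    got = np.asarray(jnp.insert(jnp.asarray(a), obj, values, axis=axis))
    matches.append(bool(np.array_equal(expected, got)))

print(matches)  # [True, True, True]
```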

Original docstring below.

Parameters:
  • arr (array_like) – Input array.

  • obj (int, slice or sequence of ints) – Object that defines the index or indices before which values are inserted.

  • values (array_like) – Values to insert into arr. If the type of values is different from that of arr, values is converted to the type of arr. values should be shaped so that arr[...,obj,...] = values is legal.

  • axis (int, optional) – Axis along which to insert values. If axis is None then arr is flattened first.
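The parameters above can be illustrated with a short sketch. It shows the three accepted forms of obj, the cast of values to the dtype of arr, and insertion along a non-default axis. Sequences are passed as jnp arrays. The inputs mirror the examples in NumPy's own insert documentation.

```python
import jax.numpy as jnp

b = jnp.array([1, 1, 2, 2, 3, 3])

# obj as a sequence: both values go before original index 2, in order.
seq = jnp.insert(b, jnp.array([2, 2]), jnp.array([5, 6]))
# -> [1 1 5 6 2 2 3 3]

# obj as a slice: slice(2, 4) expands to indices [2, 3], so 7 goes before
# b[2] and 8 goes before b[3].
sl = jnp.insert(b, slice(2, 4), jnp.array([7, 8]))
# -> [1 1 7 2 8 2 3 3]

# values is converted to the dtype of arr: 7.13 becomes 7 in an int array.
cast = jnp.insert(b, 2, 7.13)
# -> [1 1 7 2 2 3 3]

# axis=1 on a 2-D array: the scalar 5 is broadcast into a new column
# placed before column 1.
m = jnp.array([[1, 1], [2, 2], [3, 3]])
col = jnp.insert(m, 1, 5, axis=1)
# -> [[1 5 1]
#     [2 5 2]
#     [3 5 3]]
```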

Returns:

out – A copy of arr with values inserted. The insertion does not happen in place; a new array is returned. If axis is None, out is a flattened array.
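A small sketch of the return value's shape. With axis=None the result is 1-D even for a 2-D input, while an explicit axis keeps the original rank. The input is left unchanged in both cases. JAX arrays are immutable, so a new array is always produced.

```python
import jax.numpy as jnp

m = jnp.array([[1, 2], [3, 4]])

# axis=None (the default): m is flattened first, so the result is 1-D.
flat = jnp.insert(m, 2, 0)
print(flat.shape)  # (5,)
print(flat)        # [1 2 0 3 4]

# An explicit axis keeps the rank: insert a row before row 1.
rows = jnp.insert(m, 1, jnp.array([9, 9]), axis=0)
print(rows.shape)  # (3, 2)

# The input array is untouched.
print(m.tolist())  # [[1, 2], [3, 4]]
```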

Return type:

ndarray