jax.numpy.argmin

jax.numpy.argmin(a, axis=None, out=None)

Returns the indices of the minimum values along an axis.

LAX-backend implementation of numpy.argmin().

Original docstring below.

Parameters
  • a (array_like) – Input array.

  • axis (int, optional) – If None (the default), the index is into the flattened array; otherwise, indices are computed along the specified axis.
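A minimal sketch of how axis changes the result, using a small hand-picked array (the values are illustrative only). With the default axis=None the result is a single index into the flattened array, which jnp.unravel_index can map back to a coordinate in the original shape.

```python
import jax.numpy as jnp

a = jnp.array([[4, 1, 7],
               [0, 5, 2]])

# axis=None (default): index into a.ravel() == [4, 1, 7, 0, 5, 2]
flat_idx = jnp.argmin(a)           # minimum 0 sits at flat position 3

# axis=0: for each column, the row index of its minimum
col_idx = jnp.argmin(a, axis=0)    # columns [4,0], [1,5], [7,2] -> [1, 0, 1]

# axis=1: for each row, the column index of its minimum
row_idx = jnp.argmin(a, axis=1)    # rows [4,1,7], [0,5,2] -> [1, 0]

# Convert the flat index back into a (row, column) coordinate
row, col = jnp.unravel_index(flat_idx, a.shape)   # (1, 0)
```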

Returns

index_array – Array of indices into the array. It has the same shape as a.shape with the dimension along axis removed. If axis is None, the result is a scalar (0-d) index into the flattened array.

Return type

ndarray of ints
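The shape rule for the returned indices can be checked with a short sketch. It assumes a randomly generated 3-D input (only the shapes are meaningful, not the values) and uses jnp.take_along_axis to gather the minimum values back out; take_along_axis expects indices with the same number of dimensions as the input, so the removed axis is restored with jnp.expand_dims first.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
a = jax.random.normal(key, (2, 3, 4))

# Reducing over axis 1 removes that dimension: (2, 3, 4) -> (2, 4)
idx = jnp.argmin(a, axis=1)

# Re-insert axis 1 so the indices line up with `a`, gather, then drop it again
mins = jnp.take_along_axis(a, jnp.expand_dims(idx, axis=1), axis=1).squeeze(axis=1)

# With the default axis=None, the result is a 0-d integer array
flat = jnp.argmin(a)
```

The gathered values should match a.min(axis=1), which is a convenient way to sanity-check that the indices point where you expect.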