jax.numpy.nanargmin(a, axis=None)

Return the indices of the minimum values along the specified axis, ignoring NaNs.

LAX-backend implementation of numpy.nanargmin().

Warning: jax.numpy.nanargmin returns -1 for all-NaN slices and does not raise an error.

Original docstring below.

For all-NaN slices, a ValueError is raised. Warning: the results cannot be trusted if a slice contains only NaNs and Infs.
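To make the difference from NumPy concrete, here is a minimal sketch of how a -1 sentinel for all-NaN slices can be produced with ordinary jax.numpy calls. It assumes NaNs are masked to +inf before a plain argmin and that all-NaN slices are then overwritten with -1; `nanargmin_sketch` is a hypothetical helper for illustration, not part of the JAX API, and it is only compared against `jnp.nanargmin` on simple inputs.

```python
import jax.numpy as jnp


def nanargmin_sketch(a, axis=None):
    """Hypothetical helper mimicking the -1 sentinel behaviour (not JAX API).

    Assumption: NaNs are masked to +inf before an ordinary argmin, and
    slices that contain only NaNs are then overwritten with -1.
    """
    a = jnp.asarray(a)
    nan_mask = jnp.isnan(a)
    # After masking, a NaN position can only be chosen if the slice's
    # minimum is itself +inf.
    masked = jnp.where(nan_mask, jnp.inf, a)
    idx = jnp.argmin(masked, axis=axis)
    # All-NaN slices get -1 instead of raising, unlike NumPy's ValueError.
    return jnp.where(jnp.all(nan_mask, axis=axis), -1, idx)


x = jnp.array([[3.0, jnp.nan, 1.0],
               [jnp.nan, jnp.nan, jnp.nan]])

print(nanargmin_sketch(x, axis=1).tolist())  # [2, -1]
print(jnp.nanargmin(x, axis=1).tolist())     # [2, -1]

# The NaN/Inf caveat: after masking, [nan, inf] becomes [inf, inf], and the
# tie resolves to the first position, which holds the NaN.
print(int(nanargmin_sketch(jnp.array([jnp.nan, jnp.inf]))))  # 0
```

Replacing NaN with +inf is one natural choice for this sketch; it also explains why a slice containing only NaNs and Infs can report the position of a NaN.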

Parameters

  • a (array_like) – Input data.

  • axis (int, optional) – Axis along which to operate. By default, the flattened input is used.
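The effect of the axis parameter can be seen on a small 2-D array. This sketch uses only jnp.nanargmin and jnp.unravel_index; the array values are made up for illustration, and the middle row is deliberately all NaN.

```python
import jax.numpy as jnp

nan = jnp.nan
x = jnp.array([[4.0, nan, 1.0],
               [nan, nan, nan],   # an all-NaN row
               [2.0, 0.5, nan]])

# axis=None (default): x is flattened, so the result is a single index
# into x.ravel(); the smallest non-NaN value, 0.5, sits at flat position 7.
flat_idx = jnp.nanargmin(x)
print(int(flat_idx))                     # 7
row, col = jnp.unravel_index(flat_idx, x.shape)
print(int(row), int(col))                # 2 1

# axis=1: one index per row; the all-NaN row yields -1 instead of raising.
print(jnp.nanargmin(x, axis=1).tolist())  # [2, -1, 1]

# axis=0: one index per column.
print(jnp.nanargmin(x, axis=0).tolist())  # [2, 2, 0]
```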

Returns

index_array – An array of indices or a single index value.
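Because the returned indices are an ordinary integer array, the -1 for an all-NaN slice is easy to mistake for a valid index: NumPy-style indexing, which JAX follows, treats -1 as "the last element". The sketch below uses hypothetical `scores` and `labels` arrays to show the pitfall and one way to guard against it; mapping -1 to a -1 label is an illustrative choice, not something JAX does for you.

```python
import jax.numpy as jnp

# Hypothetical data: per-row scores, with one row that has no valid score.
scores = jnp.array([[3.0, jnp.nan, 1.0],
                    [jnp.nan, jnp.nan, jnp.nan]])
labels = jnp.array([10, 20, 30])  # hypothetical label for each column

idx = jnp.nanargmin(scores, axis=1)
print(idx.tolist())  # [2, -1]

# Indexing with the raw result: -1 wraps to the last element, so the
# all-NaN row silently receives label 30.
print(labels[idx].tolist())  # [30, 30]

# Treat -1 explicitly as "no valid entry" instead.
best = jnp.where(idx >= 0, labels[idx], -1)
print(best.tolist())  # [30, -1]
```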

Return type