jax.numpy.nanargmin



jax.numpy.nanargmin(a, axis=None, out=None, keepdims=None)

Return the indices of the minimum values in the specified axis, ignoring NaNs.

LAX-backend implementation of numpy.nanargmin().

Warning: jax.numpy.nanargmin returns -1 for all-NaN slices and does not raise an error.

Original docstring below.

For all-NaN slices, numpy.nanargmin raises ValueError. Warning: the results cannot be trusted if a slice contains only NaNs and Infs.
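A minimal sketch of the two edge cases described above, assuming a recent JAX release. An all-NaN row yields -1 instead of an error, and a slice holding only NaNs and Infs can return an index that points at a NaN. The latter assumes NaNs are treated as +inf before the minimum is taken, so every entry ties and the first position wins. The variable names are illustrative only.

```python
import jax.numpy as jnp

# Row 0 is all NaN; row 1 has a genuine minimum (0.0) at index 1.
x = jnp.array([[jnp.nan, jnp.nan],
               [1.0, 0.0]])
idx = jnp.nanargmin(x, axis=1)  # no error is raised; the all-NaN row gets -1

# A slice containing only NaNs and Infs: the NaN is (assumed to be) treated
# as +inf, so all entries tie and the first position, which is a NaN, wins.
y = jnp.array([jnp.nan, jnp.inf, jnp.inf])
j = jnp.nanargmin(y)
points_at_nan = bool(jnp.isnan(y[j]))  # True: the "minimum" is a NaN slot
```

If such slices are possible in your data, checking the selected value (for example with jnp.isfinite) is a cheap way to detect untrustworthy indices.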

Parameters:
  • a (array_like) – Input data.

  • axis (int, optional) – Axis along which to operate. By default, the flattened input is used.

  • out (None) – Not supported in JAX; only None is accepted.

  • keepdims (bool, optional) – If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the array.
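The following sketch shows how axis and keepdims change the result on a small array containing NaNs. The array a is a hypothetical example input.

```python
import jax.numpy as jnp

a = jnp.array([[3.0, jnp.nan, 1.0],
               [jnp.nan, 2.0, 5.0]])

flat = jnp.nanargmin(a)          # axis=None: index into the flattened array
cols = jnp.nanargmin(a, axis=0)  # one index per column
rows = jnp.nanargmin(a, axis=1)  # one index per row
rows_kd = jnp.nanargmin(a, axis=1, keepdims=True)  # reduced axis kept with size 1
```

Here flat is 2 (the value 1.0 in the flattened input), cols is [0, 1, 0], rows is [2, 1], and rows_kd has shape (2, 1), so it broadcasts against a.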

Returns:

index_array – An array of indices or a single index value.

Return type:

ndarray
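A common use of the returned index array is to gather the corresponding values. This sketch pairs keepdims=True with jnp.take_along_axis and checks the result against jnp.nanmin; the input a is a hypothetical example with no all-NaN rows, since a -1 index would otherwise select the last element of that row.

```python
import jax.numpy as jnp

a = jnp.array([[3.0, jnp.nan, 1.0],
               [jnp.nan, 2.0, 5.0]])

# keepdims=True keeps the index array 2-D, as take_along_axis expects.
idx = jnp.nanargmin(a, axis=1, keepdims=True)
vals = jnp.take_along_axis(a, idx, axis=1)  # minimum value of each row

# The gathered values should match the NaN-ignoring row minima.
matches = bool(jnp.array_equal(vals, jnp.nanmin(a, axis=1, keepdims=True)))
```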