jax.numpy.nanargmax

jax.numpy.nanargmax(a, axis=None, out=None, keepdims=None)

Return the indices of the maximum values in the specified axis, ignoring NaNs.

LAX-backend implementation of numpy.nanargmax().

Warning: jax.numpy.nanargmax returns -1 for all-NaN slices and does not raise an error.

Original docstring below.

For all-NaN slices, numpy.nanargmax raises ValueError. Warning: the results cannot be trusted if a slice contains only NaNs and -Infs.
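The NaN handling described above can be illustrated with a minimal sketch. The helper `nanargmax_sketch` is hypothetical, written only to show one plausible masking strategy (replace NaNs with -inf, take an ordinary argmax, then substitute -1 where a slice is entirely NaN); it is not claimed to be JAX's own source. The strategy also suggests why slices containing only NaNs and -Infs are unreliable: once the NaNs are masked to -inf, every entry ties, and argmax returns the first position, which may be a NaN.

```python
import jax.numpy as jnp

def nanargmax_sketch(a, axis=None):
    # Hypothetical helper for illustration, not JAX's own implementation.
    nan_mask = jnp.isnan(a)
    # Hide NaNs behind -inf so an ordinary argmax never prefers them
    # (unless nothing larger than -inf remains in the slice).
    idx = jnp.argmax(jnp.where(nan_mask, -jnp.inf, a), axis=axis)
    # Report -1 for slices that were entirely NaN instead of raising.
    return jnp.where(jnp.all(nan_mask, axis=axis), -1, idx)

m = jnp.array([[1.0, jnp.nan, 3.0],             # NaN ignored; max 3.0 is at index 2
               [jnp.nan, jnp.nan, jnp.nan],     # all-NaN slice: -1, no ValueError
               [jnp.nan, -jnp.inf, -jnp.inf]])  # only NaN and -inf: untrustworthy

result = jnp.nanargmax(m, axis=1)
sketch = nanargmax_sketch(m, axis=1)
print(result, sketch)  # the sketch picks index 0 (the NaN) for the last row
```

Only the first two rows have well-defined answers; the last row shows how the -inf masking can land on a NaN's position.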

Parameters
  • a (array_like) – Input data.

  • axis (int, optional) – Axis along which to operate. By default, the flattened input is used.

  • keepdims (bool, optional) – If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the array.

  • out (Optional[Any]) – Not supported, since JAX arrays are immutable; leave as None.
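A short sketch of how axis and keepdims shape the result; the array values are made up for illustration. With axis=None the returned index refers to the flattened array, so `jnp.unravel_index` can map it back to coordinates, while keepdims=True yields indices that `jnp.take_along_axis` can apply directly to the input.

```python
import jax.numpy as jnp

x = jnp.array([[4.0, jnp.nan, 1.0],
               [jnp.nan, 2.0, 7.0]])

# axis=None (default): a single index into the flattened array x.ravel().
flat = jnp.nanargmax(x)
row, col = jnp.unravel_index(flat, x.shape)  # map back to 2-D coordinates

# axis=0: one index per column, result shape (3,).
per_col = jnp.nanargmax(x, axis=0)

# axis=1 with keepdims=True: result shape (2, 1), which broadcasts against x.
per_row = jnp.nanargmax(x, axis=1, keepdims=True)
row_max = jnp.take_along_axis(x, per_row, axis=1)  # NaN-ignoring row maxima

print(flat, (row, col), per_col, per_row.shape, row_max)
```

Keeping the reduced axis is what lets `take_along_axis` gather values without any reshaping.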

Returns

index_array – An array of indices or a single index value.

Return type

ndarray