jax.numpy.nanargmax

jax.numpy.nanargmax(a, axis=None)

Return the indices of the maximum values along the specified axis, ignoring NaNs. For all-NaN slices, NumPy raises a ValueError. Warning: the results cannot be trusted if a slice contains only NaNs and -Infs.

LAX-backend implementation of numpy.nanargmax(). Warning: jax.numpy.nanargmax returns -1 for all-NaN slices and does not raise an error.
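The NaN-ignoring behavior and the -1 result for all-NaN slices can be sketched in a few lines of jax.numpy. This is a minimal illustration, assuming the approach of masking NaNs with -inf before calling jnp.argmax; the function name nanargmax_sketch is hypothetical, and this is not presented as JAX's actual source.

```python
import jax.numpy as jnp


def nanargmax_sketch(a, axis=None):
    """Hypothetical sketch of nanargmax semantics (not JAX's real implementation)."""
    a = jnp.asarray(a)
    if not jnp.issubdtype(a.dtype, jnp.inexact):
        # Integer and boolean arrays cannot hold NaN, so plain argmax is enough.
        return jnp.argmax(a, axis=axis)
    nan_mask = jnp.isnan(a)
    # Replace NaNs with -inf so they never win the max comparison.
    filled = jnp.where(nan_mask, -jnp.inf, a)
    idx = jnp.argmax(filled, axis=axis)
    # Slices that are entirely NaN get -1 instead of raising an error.
    return jnp.where(jnp.all(nan_mask, axis=axis), -1, idx)
```

Because NaNs are mapped to -inf in this sketch, a slice such as [nan, -inf] becomes [-inf, -inf], and argmax picks the first position, which is the NaN. That is one concrete way the "NaNs and -Infs" warning shows up.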

Original docstring below.

Parameters
  • a (array_like) – Input data.

  • axis (int, optional) – Axis along which to operate. By default, the flattened input is used.
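With the default axis=None, the returned value indexes the flattened input. A minimal sketch of mapping that flat index back to row and column coordinates with jnp.unravel_index, reusing the array from the examples (the variable names are illustrative):

```python
import jax.numpy as jnp

a = jnp.array([[jnp.nan, 4.0], [2.0, 3.0]])

# axis=None: the index refers to positions in a.ravel(), i.e. [nan, 4, 2, 3].
flat_idx = jnp.nanargmax(a)

# Convert the flat index into (row, col) coordinates for the 2-D shape.
row, col = jnp.unravel_index(flat_idx, a.shape)
max_value = a[row, col]
```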

Returns

index_array – An array of indices or a single index value.

Return type

ndarray

See also

argmax(), nanargmin()

Examples

>>> import numpy as np
>>> a = np.array([[np.nan, 4], [2, 3]])
>>> np.argmax(a)
0
>>> np.nanargmax(a)
1
>>> np.nanargmax(a, axis=0)
array([1, 0])
>>> np.nanargmax(a, axis=1)
array([1, 1])
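The same examples run through jax.numpy give the same indices; the difference, per the warning about all-NaN slices, is that JAX returns -1 where NumPy raises a ValueError. A short sketch (the array b is an illustrative input, not part of the original docstring):

```python
import jax.numpy as jnp

a = jnp.array([[jnp.nan, 4.0], [2.0, 3.0]])
flat = jnp.nanargmax(a)            # index into the flattened array: 1
cols = jnp.nanargmax(a, axis=0)    # per column: [1, 0]
rows = jnp.nanargmax(a, axis=1)    # per row: [1, 1]

# Row 0 is all NaN: NumPy would raise ValueError here, while JAX returns -1.
b = jnp.array([[jnp.nan, jnp.nan], [1.0, 5.0]])
all_nan_rows = jnp.nanargmax(b, axis=1)  # [-1, 1]
```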