jax.numpy.nanargmax

jax.numpy.nanargmax(a, axis=None)
Return the indices of the maximum values in the specified axis, ignoring NaNs. For all-NaN slices, ValueError is raised. Warning: the results cannot be trusted if a slice contains only NaNs and -Infs.

LAX-backend implementation of nanargmax(). Warning: unlike the NumPy function, jax.numpy.nanargmax returns -1 for all-NaN slices and does not raise an error.
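Because no exception is raised, callers that need to detect all-NaN slices have to check for the -1 sentinel themselves. The following is a minimal sketch of one way to do that, assuming the -1 behavior described in the warning above. The helper name `nanargmax_with_mask` is hypothetical and not part of JAX.

```python
import jax.numpy as jnp


def nanargmax_with_mask(a, axis=None):
    """Hypothetical helper: return the nanargmax indices and a validity mask.

    The mask is False wherever the slice contained only NaNs, which
    jax.numpy.nanargmax reports as index -1 instead of raising.
    """
    idx = jnp.nanargmax(a, axis=axis)
    valid = idx >= 0  # -1 marks an all-NaN slice
    return idx, valid


# First row is all-NaN, second row has a well-defined maximum.
a = jnp.array([[jnp.nan, jnp.nan], [2.0, 3.0]])
idx, valid = nanargmax_with_mask(a, axis=1)
# idx is [-1, 1] and valid is [False, True]
```

Returning a mask rather than raising keeps the helper compatible with traced code, where data-dependent Python exceptions are not available.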

Original docstring below.

Returns

index_array : ndarray
    An array of indices or a single index value.

See also

argmax, nanargmin

Examples

>>> import numpy as np
>>> a = np.array([[np.nan, 4], [2, 3]])
>>> np.argmax(a)
0
>>> np.nanargmax(a)
1
>>> np.nanargmax(a, axis=0)
array([1, 0])
>>> np.nanargmax(a, axis=1)
array([1, 1])