jax.numpy.argmin

jax.numpy.argmin(a, axis=None, out=None, keepdims=None)

Return the index of the minimum value of an array.

JAX implementation of numpy.argmin().
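Because jnp.argmin follows numpy.argmin, the two should agree on the same data. The minimal sketch below assumes NumPy is installed alongside JAX. It compares the two functions on small random integers, where ties are common, and then calls jnp.argmin inside a jax.jit-compiled function. The helper name row_argmin is hypothetical and is used only for illustration.

```python
import numpy as np
import jax
import jax.numpy as jnp

rng = np.random.default_rng(0)
# Small integer range, so the row minimum often appears more than once.
data = rng.integers(0, 5, size=(4, 6))

# Row-wise argmin computed by NumPy and by JAX on the same values.
np_result = np.argmin(data, axis=1)
jnp_result = jnp.argmin(jnp.asarray(data), axis=1)
print(np.array_equal(np_result, np.asarray(jnp_result)))  # True

# Hypothetical helper: axis is a Python constant inside the function,
# so it is fixed when jax.jit traces row_argmin.
@jax.jit
def row_argmin(m):
    return jnp.argmin(m, axis=1)

print(np.asarray(row_argmin(jnp.asarray(data))))
```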

Parameters:
  • a (ArrayLike) – input array

  • axis (int | None) – optional integer specifying the axis along which to find the minimum value. If axis is not specified, a is flattened and the returned index refers to the flattened array.

  • out (None) – unused by JAX

  • keepdims (bool | None) – if True, the reduced dimensions are kept with length 1, so the returned array has the same number of dimensions as a.
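The sketch below illustrates how axis and keepdims change the result, using a small made-up array. With the default axis=None, the index refers to the flattened array, and jnp.unravel_index recovers the coordinates. With keepdims=True, the output has the right shape to pass to jnp.take_along_axis.

```python
import jax.numpy as jnp

x = jnp.array([[4, 3, 2],
               [5, 0, 1]])

# axis=None (the default): x is flattened, so the result indexes x.ravel().
flat_idx = jnp.argmin(x)
print(int(flat_idx))  # 4

# jnp.unravel_index maps the flat index back to (row, column) coordinates.
row, col = jnp.unravel_index(flat_idx, x.shape)
print(int(row), int(col))  # 1 1

# keepdims=True keeps the reduced axis with length 1 ...
idx = jnp.argmin(x, axis=0, keepdims=True)
print(idx.shape)  # (1, 3)

# ... so the indices line up with x for jnp.take_along_axis,
# which gathers the column minima themselves.
mins = jnp.take_along_axis(x, idx, axis=0)
print(mins)  # [[4 0 1]]
```

Keeping the reduced axis is mainly a convenience. It avoids an explicit jnp.expand_dims call before gathering or broadcasting against the input.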

Returns:

an array containing the index of the minimum value along the specified axis.

Return type:

Array
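The following short sketch demonstrates a few properties of the returned Array, using made-up inputs. Ties resolve to the first occurrence, which matches NumPy's documented behavior. The result is an integer array that can index the input. Reducing along an axis removes that axis from the output shape.

```python
import jax.numpy as jnp

# The minimum value 1 appears at indices 1 and 3; the first is returned.
ties = jnp.array([3, 1, 2, 1])
first = jnp.argmin(ties)
print(int(first))  # 1

# The result is an integer Array, so it can index back into the input.
print(int(ties[first]))  # 1
print(jnp.issubdtype(first.dtype, jnp.integer))  # True

# Along an axis, the result has the input's shape with that axis removed.
cube = jnp.arange(24).reshape(2, 3, 4)
print(jnp.argmin(cube, axis=1).shape)  # (2, 4)
```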


Examples

>>> x = jnp.array([1, 3, 5, 4, 2])
>>> jnp.argmin(x)
Array(0, dtype=int32)
>>> x = jnp.array([[1, 3, 2],
...                [5, 4, 1]])
>>> jnp.argmin(x, axis=1)
Array([0, 2], dtype=int32)
>>> jnp.argmin(x, axis=1, keepdims=True)
Array([[0],
       [2]], dtype=int32)