jax.lax.argmin

jax.lax.argmin(operand, axis, index_dtype)

Computes the index of the minimum element of operand along the dimension axis.
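A minimal usage sketch, assuming a small 2-D float array chosen so that no row or column contains tied minima. It shows how the choice of axis changes which indices are returned. The array values and variable names are illustrative only.

```python
import jax.numpy as jnp
from jax import lax

# Illustrative 2x3 input with no ties in any row or column.
x = jnp.array([[3.0, 1.0, 2.0],
               [0.5, 4.0, 6.0]])

# axis=1: index of the minimum within each row.
row_idx = lax.argmin(x, 1, jnp.int32)

# axis=0: index of the minimum within each column.
col_idx = lax.argmin(x, 0, jnp.int32)

print(row_idx)  # [1 0]
print(col_idx)  # [1 0 0]
```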

Parameters:
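  • operand (jax.typing.ArrayLike) – the input array.

  • axis (int) – the dimension of operand along which the minimum is located. It must be a static Python int.

  • index_dtype (jax.typing.DTypeLike) – the integer dtype of the returned indices, for example jnp.int32.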

Return type:

Array – the indices, with dtype index_dtype and the shape of operand with dimension axis removed.
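The shape and dtype of the result can be checked directly. The following is a sketch assuming a 3-D input built with jnp.arange, whose values increase along the last dimension, so every minimum along axis 2 sits at index 0. It also compares the result with jax.numpy.argmin, which offers a NumPy-style interface to the same computation.

```python
import jax.numpy as jnp
from jax import lax

# Illustrative input of shape (2, 3, 4); values increase along axis 2.
x = jnp.arange(24.0).reshape(2, 3, 4)

# Reducing over axis 2 removes that dimension from the result.
idx = lax.argmin(x, 2, jnp.int32)

print(idx.shape, idx.dtype)  # (2, 3) int32

# jnp.argmin gives the same indices; compare values only, since its
# default index dtype depends on whether 64-bit mode is enabled.
same = bool(jnp.array_equal(idx, jnp.argmin(x, axis=2)))
print(same)  # True
```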