jax.lax.argmin

jax.lax.argmin(operand, axis, index_dtype)

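Computes the index of the minimum element of operand along axis. The result has the shape of operand with axis removed, and its dtype is index_dtype. If the minimum value occurs more than once along axis, the index of the first occurrence is returned.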
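A minimal usage sketch, assuming a standard JAX installation on any backend. It reduces the same 2-D array along each of its two axes to show which dimension disappears from the result.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([[3.0, 1.0, 2.0],
               [0.5, 4.0, 5.0]])

# Reduce over axis 0 (down each column): one index per column.
col_idx = lax.argmin(x, axis=0, index_dtype=jnp.int32)

# Reduce over axis 1 (across each row): one index per row.
row_idx = lax.argmin(x, axis=1, index_dtype=jnp.int32)

print(col_idx)  # [1 0 0]
print(row_idx)  # [1 0]
```

Both axis and index_dtype are required here; the higher-level jax.numpy.argmin wraps this primitive with NumPy-style defaults.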

Parameters
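  • operand (Any) – the input array.

  • axis (int) – the dimension of operand along which to search for the minimum; it must be a valid axis of operand.

  • index_dtype (Any) – the integer dtype of the returned indices, for example jnp.int32.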
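The sketch below checks the shape and dtype rules described above on a 3-D input, and cross-checks the result against NumPy's argmin. The random input and the tie example are illustrative only.

```python
import numpy as np
import jax.numpy as jnp
from jax import lax

# A 3-D input of shape (2, 5, 3); reducing along axis=1 removes that dimension.
x = jnp.array(np.random.default_rng(0).normal(size=(2, 5, 3)), dtype=jnp.float32)
idx = lax.argmin(x, axis=1, index_dtype=jnp.int32)

print(idx.shape)  # (2, 3)
print(idx.dtype)  # int32

# Should agree with NumPy's argmin along the same axis.
print(np.array_equal(np.asarray(idx), np.argmin(np.asarray(x), axis=1)))  # True

# Repeated minimum: the first occurrence (index 1) is reported.
t = jnp.array([2.0, 0.0, 1.0, 0.0])
print(lax.argmin(t, axis=0, index_dtype=jnp.int32))  # 1
```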

Return type

Any