jax.numpy.promote_types

jax.numpy.promote_types(a, b)

Returns the type to which a binary operation should cast its arguments.

For details of JAX’s type promotion semantics, which differ from NumPy’s in several cases, see Type promotion semantics.
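As a minimal sketch, assuming a recent JAX with the default jax_numpy_dtype_promotion="standard" setting, the calls below show one case where JAX's answer differs from numpy.promote_types and one case (bfloat16) that plain NumPy has no dtype for:

```python
import jax.numpy as jnp
import numpy as np

# JAX keeps mixed int16/float16 arithmetic in float16 ...
jax_result = jnp.promote_types(jnp.int16, jnp.float16)
# ... whereas NumPy widens to float32 so every int16 value stays exact.
numpy_result = np.promote_types(np.int16, np.float16)
print(jax_result, numpy_result)  # float16 float32

# bfloat16 and float16 share no 16-bit supertype, so JAX promotes to float32.
print(jnp.promote_types(jnp.bfloat16, jnp.float16))  # float32
```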

Parameters:
  • a (jax.typing.DTypeLike) – a numpy.dtype or a dtype specifier.

  • b (jax.typing.DTypeLike) – a numpy.dtype or a dtype specifier.
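Both arguments accept any DTypeLike value. The sketch below (the list of spellings is illustrative, not exhaustive) passes the same int8 dtype in four forms and promotes each against the string "uint8":

```python
import jax.numpy as jnp
import numpy as np

# Four spellings of the same int8 dtype: a string, a NumPy scalar type,
# a jax.numpy scalar type, and a numpy.dtype instance.
int8_specs = ["int8", np.int8, jnp.int8, np.dtype("int8")]

for spec in int8_specs:
    # int8 and uint8 need a signed type wide enough for both ranges: int16.
    print(spec, "->", jnp.promote_types(spec, "uint8"))
```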

Returns:

A numpy.dtype object representing the promoted type.

Return type:

dtype
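Because the result is an ordinary numpy.dtype, it can be passed anywhere a dtype is expected. The sketch below uses a hypothetical helper, cast_to_common (not part of JAX), that casts two arrays to their promoted dtype, mirroring the implicit promotion a binary operation such as + performs under the default promotion mode:

```python
import jax.numpy as jnp
import numpy as np

def cast_to_common(x, y):
    """Hypothetical helper: cast both arrays to their promoted dtype."""
    common = jnp.promote_types(x.dtype, y.dtype)  # a numpy.dtype instance
    return x.astype(common), y.astype(common)

x = jnp.arange(3, dtype=jnp.int8)
y = jnp.arange(3, dtype=jnp.uint8)
xc, yc = cast_to_common(x, y)

# The explicit casts agree with the dtype that x + y produces implicitly.
print(xc.dtype, yc.dtype, (x + y).dtype)  # int16 int16 int16
```

Since promote_types only inspects dtypes, calling it on x.dtype and y.dtype never touches the array data.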