jax.dtypes.issubdtype

jax.dtypes.issubdtype(a, b)

Returns True if the first argument is a typecode lower than or equal to the second in the type hierarchy.

This is like numpy.issubdtype(), but can handle dtype extensions such as jax.dtypes.bfloat16 and jax.dtypes.prng_key.
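A minimal sketch of how the function classifies standard dtypes and the two extensions named above. It assumes a recent JAX release in which `jax.random.key` creates typed key arrays, and it uses NumPy's abstract scalar types (`np.floating`, `np.integer`, `np.signedinteger`) as the categories to test against.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import dtypes

# Standard dtypes behave as in numpy.issubdtype: concrete dtypes sit under
# NumPy's abstract categories, and every dtype is a subdtype of itself.
assert dtypes.issubdtype(jnp.int32, np.signedinteger)
assert not dtypes.issubdtype(jnp.uint8, np.signedinteger)
assert dtypes.issubdtype(jnp.float32, jnp.float32)

# bfloat16 is an extension dtype; JAX places it under np.floating.
assert dtypes.issubdtype(dtypes.bfloat16, np.floating)
assert not dtypes.issubdtype(dtypes.bfloat16, np.integer)

# Typed PRNG keys (from jax.random.key) carry an extended dtype that is a
# subdtype of both dtypes.prng_key and dtypes.extended.
key = jax.random.key(0)
assert dtypes.issubdtype(key.dtype, dtypes.prng_key)
assert dtypes.issubdtype(key.dtype, dtypes.extended)

# Legacy raw keys (jax.random.PRNGKey) are plain uint32 arrays, not typed keys.
raw_key = jax.random.PRNGKey(0)
assert not dtypes.issubdtype(raw_key.dtype, dtypes.prng_key)
```

Checking against `dtypes.prng_key` is a convenient way to tell typed key arrays apart from raw `uint32` keys.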

Parameters:
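  • a (DTypeLike | ExtendedDType | None) – the dtype or scalar type to test.

  • b (DTypeLike | ExtendedDType | None) – the dtype, scalar type, or abstract category (for example numpy.floating or jax.dtypes.prng_key) to test against.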

Return type:

bool
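Because the result is a plain bool, issubdtype fits naturally into dtype dispatch. The helper `describe_dtype` below is hypothetical, written only for illustration; it checks the PRNG-key category first, since key dtypes are extended rather than numeric.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import dtypes


def describe_dtype(dtype) -> str:
    """Hypothetical helper: name a dtype's category using issubdtype."""
    # Check PRNG keys first: they are extended dtypes, not numeric ones.
    if dtypes.issubdtype(dtype, dtypes.prng_key):
        return "prng key"
    if dtypes.issubdtype(dtype, np.floating):
        return "float"  # also covers bfloat16
    if dtypes.issubdtype(dtype, np.integer):
        return "integer"
    return "other"  # e.g. bool or complex


print(describe_dtype(jnp.float32))               # float
print(describe_dtype(dtypes.bfloat16))           # float
print(describe_dtype(jnp.int8))                  # integer
print(describe_dtype(jnp.bool_))                 # other
print(describe_dtype(jax.random.key(0).dtype))   # prng key
```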