jax.dtypes.issubdtype

jax.dtypes.issubdtype(a, b)

Returns True if the first argument is a typecode lower than or equal to the second argument in the type hierarchy.
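A minimal sketch of the ordering with standard dtypes. It assumes `jax` is installed and uses the abstract categories (`jnp.floating`, `jnp.number`) that `jax.numpy` re-exports from the NumPy scalar-type hierarchy. The check is directional: a concrete dtype can sit below an abstract category, but not the reverse.

```python
import jax.numpy as jnp
from jax import dtypes

# A concrete dtype counts as "equal", so it is a subtype of itself.
print(dtypes.issubdtype(jnp.float32, jnp.float32))  # True
# float32 sits below the abstract categories floating and number.
print(dtypes.issubdtype(jnp.float32, jnp.floating))  # True
print(dtypes.issubdtype(jnp.float32, jnp.number))    # True
# Integer dtypes are not in the floating branch.
print(dtypes.issubdtype(jnp.int32, jnp.floating))    # False
# An abstract category is not lower than a concrete dtype.
print(dtypes.issubdtype(jnp.floating, jnp.float32))  # False
```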

This is like numpy.issubdtype(), but can handle dtype extensions such as jax.dtypes.bfloat16 and jax.dtypes.prng_key.
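A sketch of the two extensions named above. It assumes a JAX version that provides `jax.random.key` for creating typed PRNG keys (older releases only offered raw `uint32` keys via `jax.random.PRNGKey`, whose dtype is not an extended dtype).

```python
import jax
import jax.numpy as jnp
from jax import dtypes

# bfloat16 is a dtype extension; JAX places it under the floating category.
print(dtypes.issubdtype(jnp.bfloat16, jnp.floating))  # True
# It remains distinct from the built-in concrete float types.
print(dtypes.issubdtype(jnp.bfloat16, jnp.float32))   # False

# A typed PRNG key carries an extended dtype that falls under prng_key.
key = jax.random.key(0)
print(dtypes.issubdtype(key.dtype, dtypes.prng_key))    # True
# Ordinary numeric dtypes are not PRNG key dtypes.
print(dtypes.issubdtype(jnp.float32, dtypes.prng_key))  # False
```

A typical use is guarding code that should accept any floating input, bfloat16 included, or that must distinguish typed keys from ordinary arrays, without enumerating every concrete dtype.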

Parameters:
  • a (DTypeLike | None)

  • b (DTypeLike | None)

Return type:

bool