# Copyright 2019 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Array type functions.
#
# JAX dtypes differ from NumPy in both:
# a) their type promotion rules, and
# b) the set of supported types (e.g., bfloat16),
# so we need our own implementation that deviates from NumPy in places.
from __future__ import annotations
import abc
import builtins
import functools
from typing import cast, overload, Any, Literal, Union
import warnings
import ml_dtypes
import numpy as np
from jax._src import config
from jax._src.typing import DType, DTypeLike
from jax._src.util import set_module, StrictABC
from jax._src import traceback_util
traceback_util.register_exclusion(__file__)
try:
_ml_dtypes_version = tuple(map(int, ml_dtypes.__version__.split('.')[:3]))
except:
pass
else:
if _ml_dtypes_version < (0, 2, 0):
raise ValueError("JAX requires ml_dtypes version 0.2.0 or newer; "
f"installed version is {ml_dtypes.__version__}.")
export = set_module('jax.dtypes')
@export
class extended(np.generic):
"""Scalar class for extended dtypes.
This is an abstract class that should never be instantiated, but rather
exists for the sake of `jnp.issubdtype`.
Examples:
>>> from jax import random
>>> from jax import dtypes
>>> key = random.key(0)
>>> jnp.issubdtype(key.dtype, dtypes.extended)
True
"""
@export
class prng_key(extended):
"""Scalar class for PRNG Key dtypes.
This is an abstract class that should never be instantiated, but rather
exists for the sake of `jnp.issubdtype`.
Examples:
>>> from jax import random
>>> from jax import dtypes
>>> key = random.key(0)
>>> jnp.issubdtype(key.dtype, dtypes.prng_key)
True
"""
class ExtendedDType(StrictABC):
"""Abstract Base Class for extended dtypes"""
@property
@abc.abstractmethod
def type(self) -> type: ...
# fp8 support
float8_e4m3b11fnuz: type[np.generic] = ml_dtypes.float8_e4m3b11fnuz
float8_e4m3fn: type[np.generic] = ml_dtypes.float8_e4m3fn
float8_e4m3fnuz: type[np.generic] = ml_dtypes.float8_e4m3fnuz
float8_e5m2: type[np.generic] = ml_dtypes.float8_e5m2
float8_e5m2fnuz: type[np.generic] = ml_dtypes.float8_e5m2fnuz
_float8_e4m3b11fnuz_dtype: np.dtype = np.dtype(float8_e4m3b11fnuz)
_float8_e4m3fn_dtype: np.dtype = np.dtype(float8_e4m3fn)
_float8_e4m3fnuz_dtype: np.dtype = np.dtype(float8_e4m3fnuz)
_float8_e5m2_dtype: np.dtype = np.dtype(float8_e5m2)
_float8_e5m2fnuz_dtype: np.dtype = np.dtype(float8_e5m2fnuz)
def supports_inf(dtype: DTypeLike) -> bool:
"""Return true if the dtype supports infinity, else return False."""
typ = np.dtype(dtype).type
if typ in {float8_e4m3b11fnuz, float8_e4m3fn, float8_e4m3fnuz, float8_e5m2fnuz}:
return False
return issubdtype(dtype, np.inexact)
# bfloat16 support
bfloat16: type[np.generic] = ml_dtypes.bfloat16
_bfloat16_dtype: np.dtype = np.dtype(bfloat16)
_custom_float_scalar_types = [
float8_e4m3b11fnuz,
float8_e4m3fn,
float8_e4m3fnuz,
float8_e5m2,
float8_e5m2fnuz,
bfloat16,
]
_custom_float_dtypes = [
_float8_e4m3b11fnuz_dtype,
_float8_e4m3fn_dtype,
_float8_e4m3fnuz_dtype,
_float8_e5m2_dtype,
_float8_e5m2fnuz_dtype,
_bfloat16_dtype,
]
_float8_dtypes = [
_float8_e4m3b11fnuz_dtype,
_float8_e4m3fn_dtype,
_float8_e4m3fnuz_dtype,
_float8_e5m2_dtype,
_float8_e5m2fnuz_dtype,
]
# 4-bit integer support
int4: type[np.generic] = ml_dtypes.int4
uint4: type[np.generic] = ml_dtypes.uint4
_int4_dtype: np.dtype = np.dtype(int4)
_uint4_dtype: np.dtype = np.dtype(uint4)
_int4_dtypes = [
_int4_dtype,
_uint4_dtype,
]
# Default types.
bool_ = np.bool_
int_: type[Any]
uint: type[Any]
float_: type[Any]
complex_: type[Any]
if config.default_dtype_bits.value == '32':
int_ = np.int32
uint = np.uint32
float_ = np.float32
complex_ = np.complex64
else:
int_ = np.int64
uint = np.uint64
float_ = np.float64
complex_ = np.complex128
_default_types: dict[str, type[Any]] = {
'b': bool_,
'i': int_,
'u': uint,
'f': float_,
'c': complex_,
}
# Trivial vectorspace datatype needed for tangent values of int/bool primals
float0: np.dtype = np.dtype([('float0', np.void, 0)])
_dtype_to_32bit_dtype: dict[DType, DType] = {
np.dtype('int64'): np.dtype('int32'),
np.dtype('uint64'): np.dtype('uint32'),
np.dtype('float64'): np.dtype('float32'),
np.dtype('complex128'): np.dtype('complex64'),
}
# Note: we promote narrow types to float32 here for backward compatibility
# with earlier approaches. We might consider revisiting this, or perhaps
# tying the logic more closely to the type promotion lattice.
_dtype_to_inexact: dict[DType, DType] = {
np.dtype(k): np.dtype(v) for k, v in [
('bool', 'float32'),
('uint8', 'float32'), ('int8', 'float32'),
('uint16', 'float32'), ('int16', 'float32'),
('uint32', 'float32'), ('int32', 'float32'),
('uint64', 'float64'), ('int64', 'float64')
]
}
def to_numeric_dtype(dtype: DTypeLike) -> DType:
"""Promotes a dtype into an numeric dtype, if it is not already one."""
dtype_ = np.dtype(dtype)
return np.dtype('int32') if dtype_ == np.dtype('bool') else dtype_
def to_inexact_dtype(dtype: DTypeLike) -> DType:
"""Promotes a dtype into an inexact dtype, if it is not already one."""
dtype_ = np.dtype(dtype)
return _dtype_to_inexact.get(dtype_, dtype_)
def to_complex_dtype(dtype: DTypeLike) -> DType:
ftype = to_inexact_dtype(dtype)
if ftype in [np.dtype('float64'), np.dtype('complex128')]:
return np.dtype('complex128')
return np.dtype('complex64')
@functools.cache
def _canonicalize_dtype(x64_enabled: bool, allow_extended_dtype: bool, dtype: Any) -> DType | ExtendedDType:
if issubdtype(dtype, extended):
if not allow_extended_dtype:
raise ValueError(f"Internal: canonicalize_dtype called on extended dtype {dtype} "
"with allow_extended_dtype=False")
return dtype
try:
dtype_ = np.dtype(dtype)
except TypeError as e:
raise TypeError(f'dtype {dtype!r} not understood') from e
if x64_enabled:
return dtype_
else:
return _dtype_to_32bit_dtype.get(dtype_, dtype_)
@overload
def canonicalize_dtype(dtype: Any, allow_extended_dtype: Literal[False] = False) -> DType: ...
@overload
def canonicalize_dtype(dtype: Any, allow_extended_dtype: bool = False) -> DType | ExtendedDType: ...
@export
def canonicalize_dtype(dtype: Any, allow_extended_dtype: bool = False) -> DType | ExtendedDType:
"""Convert from a dtype to a canonical dtype based on config.x64_enabled."""
return _canonicalize_dtype(config.enable_x64.value, allow_extended_dtype, dtype) # pytype: disable=bad-return-type
# Default dtypes corresponding to Python scalars.
python_scalar_dtypes : dict[type, DType] = {
bool: np.dtype('bool'),
int: np.dtype('int64'),
float: np.dtype('float64'),
complex: np.dtype('complex128'),
}
@export
def scalar_type_of(x: Any) -> type:
"""Return the scalar type associated with a JAX value."""
typ = dtype(x)
if typ in _custom_float_dtypes:
return float
elif typ in _int4_dtypes:
return int
elif np.issubdtype(typ, np.bool_):
return bool
elif np.issubdtype(typ, np.integer):
return int
elif np.issubdtype(typ, np.floating):
return float
elif np.issubdtype(typ, np.complexfloating):
return complex
else:
raise TypeError(f"Invalid scalar value {x}")
def _scalar_type_to_dtype(typ: type, value: Any = None) -> DType:
"""Return the numpy dtype for the given scalar type.
Raises
------
OverflowError: if `typ` is `int` and the value is too large for int64.
Examples
--------
>>> _scalar_type_to_dtype(int)
dtype('int32')
>>> _scalar_type_to_dtype(float)
dtype('float32')
>>> _scalar_type_to_dtype(complex)
dtype('complex64')
>>> _scalar_type_to_dtype(int)
dtype('int32')
>>> _scalar_type_to_dtype(int, 0)
dtype('int32')
>>> _scalar_type_to_dtype(int, 1 << 63) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
OverflowError: Python int 9223372036854775808 too large to convert to int32
"""
dtype = canonicalize_dtype(python_scalar_dtypes[typ])
if typ is int and value is not None:
if value < np.iinfo(dtype).min or value > np.iinfo(dtype).max:
raise OverflowError(f"Python int {value} too large to convert to {dtype}")
return dtype
def coerce_to_array(x: Any, dtype: DTypeLike | None = None) -> np.ndarray:
"""Coerces a scalar or NumPy array to an np.array.
Handles Python scalar type promotion according to JAX's rules, not NumPy's
rules.
"""
if dtype is None and type(x) in python_scalar_dtypes:
dtype = _scalar_type_to_dtype(type(x), x)
return np.asarray(x, dtype)
iinfo = ml_dtypes.iinfo
finfo = ml_dtypes.finfo
def _issubclass(a: Any, b: Any) -> bool:
"""Determines if ``a`` is a subclass of ``b``.
Similar to issubclass, but returns False instead of an exception if `a` is not
a class.
"""
try:
return issubclass(a, b)
except TypeError:
return False
# TODO(jakevdp): consider whether to disallow None here. We allow it
# because np.issubdtype allows it (and treats it as equivalent to float64).
[docs]
def issubdtype(a: DTypeLike | None, b: DTypeLike | None) -> bool:
"""Returns True if first argument is a typecode lower/equal in type hierarchy.
This is like :func:`numpy.issubdtype`, but can handle dtype extensions such as
:obj:`jax.dtypes.bfloat16` and `jax.dtypes.prng_key`.
"""
# Main departures from np.issubdtype are:
# - "extended" dtypes (like prng key types) are not normal numpy dtypes, so we
# need to handle them specifically. However, their scalar types do conform to
# the numpy scalar type hierarchy.
# - custom dtypes (like bfloat16, int4, etc.) are normal numpy dtypes, but they
# don't conform to the standard numpy type hierarchy (e.g. the bfloat16 scalar
# type is not a subclass of np.floating) so we must also handle these specially.
# First handle extended dtypes. This is important for performance because
# isinstance(x, extended) is called frequently within JAX internals.
a_is_type = isinstance(a, type)
b_is_type = isinstance(b, type)
if b_is_type and _issubclass(b, extended):
if isinstance(a, ExtendedDType):
return _issubclass(a.type, b)
if a_is_type and _issubclass(a, np.generic):
return _issubclass(a, b)
return _issubclass(np.dtype(a).type, b)
if isinstance(b, ExtendedDType):
return isinstance(a, ExtendedDType) and a == b
if isinstance(a, ExtendedDType):
a = a.type
a_is_type = isinstance(a, type)
# For all others, normalize inputs to scalar types.
a_sctype = a if a_is_type and _issubclass(a, np.generic) else np.dtype(a).type
b_sctype = b if b_is_type and _issubclass(b, np.generic) else np.dtype(b).type
# Now do special handling of custom float and int types, as they don't conform
# to the normal scalar type hierarchy.
if a_sctype in _custom_float_scalar_types:
return b_sctype in {a_sctype, np.floating, np.inexact, np.number, np.generic}
if a_sctype == int4:
return b_sctype in {a_sctype, np.signedinteger, np.integer, np.number, np.generic}
if a_sctype == uint4:
return b_sctype in {a_sctype, np.unsignedinteger, np.integer, np.number, np.generic}
# Otherwise, fall back to numpy.issubdtype
return np.issubdtype(a_sctype, b_sctype)
can_cast = np.can_cast
JAXType = Union[type, DType]
# Enumeration of all valid JAX types in order.
_weak_types: list[JAXType] = [int, float, complex]
_bool_types: list[JAXType] = [np.dtype(bool)]
_signed_types: list[JAXType]
_unsigned_types: list[JAXType]
_int_types: list[JAXType]
_unsigned_types = [
np.dtype(uint4),
np.dtype('uint8'),
np.dtype('uint16'),
np.dtype('uint32'),
np.dtype('uint64'),
]
_signed_types = [
np.dtype(int4),
np.dtype('int8'),
np.dtype('int16'),
np.dtype('int32'),
np.dtype('int64'),
]
_int_types = _unsigned_types + _signed_types
_float_types: list[JAXType] = [
*_custom_float_dtypes,
np.dtype('float16'),
np.dtype('float32'),
np.dtype('float64'),
]
_complex_types: list[JAXType] = [
np.dtype('complex64'),
np.dtype('complex128'),
]
_jax_types = _bool_types + _int_types + _float_types + _complex_types
_jax_dtype_set = {float0, *_bool_types, *_int_types, *_float_types, *_complex_types}
_dtype_kinds: dict[str, set] = {
'bool': {*_bool_types},
'signed integer': {*_signed_types},
'unsigned integer': {*_unsigned_types},
'integral': {*_signed_types, *_unsigned_types},
'real floating': {*_float_types},
'complex floating': {*_complex_types},
'numeric': {*_signed_types, *_unsigned_types, *_float_types, *_complex_types},
}
[docs]
def isdtype(dtype: DTypeLike, kind: str | DType | tuple[str | DType, ...]) -> bool:
"""Returns a boolean indicating whether a provided dtype is of a specified kind.
Args:
dtype : the input dtype
kind : the data type kind.
If ``kind`` is dtype-like, return ``dtype = kind``.
If ``kind`` is a string, then return True if the dtype is in the specified category:
- ``'bool'``: ``{bool}``
- ``'signed integer'``: ``{int4, int8, int16, int32, int64}``
- ``'unsigned integer'``: ``{uint4, uint8, uint16, uint32, uint64}``
- ``'integral'``: shorthand for ``('signed integer', 'unsigned integer')``
- ``'real floating'``: ``{float8_*, float16, bfloat16, float32, float64}``
- ``'complex floating'``: ``{complex64, complex128}``
- ``'numeric'``: shorthand for ``('integral', 'real floating', 'complex floating')``
If ``kind`` is a tuple, then return True if dtype matches any entry of the tuple.
Returns:
True or False
"""
the_dtype = np.dtype(dtype)
kind_tuple: tuple[DType | str, ...] = kind if isinstance(kind, tuple) else (kind,)
options: set[DType] = set()
for kind in kind_tuple:
if isinstance(kind, str):
if kind not in _dtype_kinds:
raise ValueError(f"Unrecognized {kind=} expected one of {list(_dtype_kinds.keys())}")
options.update(_dtype_kinds[kind])
elif isinstance(kind, np.dtype):
options.add(kind)
else:
# TODO(jakevdp): should we handle scalar types or ScalarMeta here?
raise TypeError(f"Expected kind to be a dtype, string, or tuple; got {kind=}")
return the_dtype in options
def _jax_type(dtype: DType, weak_type: bool) -> JAXType:
"""Return the jax type for a dtype and weak type."""
if weak_type:
if dtype == bool:
return dtype
if dtype in _custom_float_dtypes:
return float
return type(dtype.type(0).item())
return dtype
def _dtype_and_weaktype(value: Any) -> tuple[DType, bool]:
"""Return a (dtype, weak_type) tuple for the given input."""
return dtype(value), any(value is typ for typ in _weak_types) or is_weakly_typed(value)
def _type_promotion_lattice(jax_numpy_dtype_promotion: str) -> dict[JAXType, list[JAXType]]:
"""
Return the type promotion lattice in the form of a DAG.
This DAG maps each type to its immediately higher type on the lattice.
"""
b1, = _bool_types
_uint4, u1, u2, u4, u8, _int4, i1, i2, i4, i8 = _int_types
*f1_types, bf, f2, f4, f8 = _float_types
c4, c8 = _complex_types
i_, f_, c_ = _weak_types
if jax_numpy_dtype_promotion == 'standard':
out: dict[JAXType, list[JAXType]]
out = {
b1: [i_],
u1: [i2, u2], u2: [i4, u4], u4: [i8, u8], u8: [f_],
i_: [u1, i1], i1: [i2], i2: [i4], i4: [i8], i8: [f_],
f_: [*f1_types, bf, f2, c_],
**{t: [] for t in f1_types}, bf: [f4], f2: [f4], f4: [f8, c4], f8: [c8],
c_: [c4], c4: [c8], c8: [],
}
if _int4_dtype is not None:
out[i_].append(_int4_dtype)
out[_int4_dtype] = []
if _uint4_dtype is not None:
out[i_].append(_uint4_dtype)
out[_uint4_dtype] = []
return out
elif jax_numpy_dtype_promotion == 'strict':
return {
i_: [f_] + _int_types,
f_: [c_] + _float_types,
c_: _complex_types,
**{t: [] for t in _jax_types}
}
else:
raise ValueError(
f"Unexpected value of jax_numpy_dtype_promotion={jax_numpy_dtype_promotion!r}")
def _make_lattice_upper_bounds(jax_numpy_dtype_promotion: str) -> dict[JAXType, set[JAXType]]:
lattice = _type_promotion_lattice(jax_numpy_dtype_promotion)
upper_bounds = {node: {node} for node in lattice}
for n in lattice:
while True:
new_upper_bounds = set().union(*(lattice[b] for b in upper_bounds[n]))
if n in new_upper_bounds:
raise ValueError(f"cycle detected in type promotion lattice for node {n}")
if new_upper_bounds.issubset(upper_bounds[n]):
break
upper_bounds[n] |= new_upper_bounds
return upper_bounds
_lattice_upper_bounds: dict[str, dict[JAXType, set[JAXType]]] = {
'standard': _make_lattice_upper_bounds('standard'),
'strict': _make_lattice_upper_bounds('strict'),
}
class TypePromotionError(ValueError):
pass
@functools.lru_cache(512) # don't use util.memoize because there is no X64 dependence.
def _least_upper_bound(jax_numpy_dtype_promotion: str, *nodes: JAXType) -> JAXType:
"""Compute the least upper bound of a set of nodes.
Args:
nodes: sequence of entries from _jax_types + _weak_types
Returns:
the _jax_type representing the least upper bound of the input nodes
on the promotion lattice.
"""
# This function computes the least upper bound of a set of nodes N within a partially
# ordered set defined by the lattice generated above.
# Given a partially ordered set S, let the set of upper bounds of n β S be
# UB(n) β‘ {m β S | n β€ m}
# Further, for a set of nodes N β S, let the set of common upper bounds be given by
# CUB(N) β‘ {a β S | β b β N: a β UB(b)}
# Then the least upper bound of N is defined as
# LUB(N) β‘ {c β CUB(N) | β d β CUB(N), c β€ d}
# The definition of an upper bound implies that c β€ d if and only if d β UB(c),
# so the LUB can be expressed:
# LUB(N) = {c β CUB(N) | β d β CUB(N): d β UB(c)}
# or, equivalently:
# LUB(N) = {c β CUB(N) | CUB(N) β UB(c)}
# By definition, LUB(N) has a cardinality of 1 for a partially ordered set.
# Note a potential algorithmic shortcut: from the definition of CUB(N), we have
# β c β N: CUB(N) β UB(c)
# So if N β© CUB(N) is nonempty, if follows that LUB(N) = N β© CUB(N).
N = set(nodes)
UB = _lattice_upper_bounds[jax_numpy_dtype_promotion]
try:
bounds = [UB[n] for n in N]
except KeyError:
dtype = next(n for n in N if n not in UB)
raise ValueError(f"{dtype=} is not a valid dtype for JAX type promotion.")
CUB = set.intersection(*bounds)
LUB = (CUB & N) or {c for c in CUB if CUB.issubset(UB[c])}
if len(LUB) == 1:
return LUB.pop()
elif len(LUB) == 0:
if config.numpy_dtype_promotion.value == 'strict':
msg = (
f"Input dtypes {tuple(str(n) for n in nodes)} have no available implicit dtype "
"promotion path when jax_numpy_dtype_promotion=strict. Try explicitly casting "
"inputs to the desired output type, or set jax_numpy_dtype_promotion=standard.")
elif any(n in _float8_dtypes for n in nodes):
msg = (
f"Input dtypes {tuple(str(n) for n in nodes)} have no available implicit dtype "
"promotion path. To avoid unintended promotion, 8-bit floats do not support "
"implicit promotion. If you'd like your inputs to be promoted to another type, "
"you can do so explicitly using e.g. x.astype('float32')")
elif any(n in _int4_dtypes for n in nodes):
msg = (
f"Input dtypes {tuple(str(n) for n in nodes)} have no available implicit dtype "
"promotion path. To avoid unintended promotion, 4-bit integers do not support "
"implicit promotion. If you'd like your inputs to be promoted to another type, "
"you can do so explicitly using e.g. x.astype('int32')")
else:
msg = (
f"Input dtypes {tuple(str(n) for n in nodes)} have no available implicit dtype "
"promotion path. Try explicitly casting inputs to the desired output type.")
raise TypePromotionError(msg)
else:
# If we get here, it means the lattice is ill-formed.
raise TypePromotionError(
f"Internal Type Promotion error: {nodes} do not have a unique least upper bound "
f"on the specified lattice; options are {LUB}. This is an unexpected error in "
"JAX's internal logic; please report it to the JAX maintainers."
)
def is_weakly_typed(x: Any) -> bool:
if type(x) in _weak_types:
return True
try:
return x.aval.weak_type
except AttributeError:
return False
def is_python_scalar(x: Any) -> bool:
try:
return x.aval.weak_type and np.ndim(x) == 0
except AttributeError:
return type(x) in python_scalar_dtypes
def check_valid_dtype(dtype: DType) -> None:
if dtype not in _jax_dtype_set:
raise TypeError(f"Dtype {dtype} is not a valid JAX array "
"type. Only arrays of numeric types are supported by JAX.")
def dtype(x: Any, *, canonicalize: bool = False) -> DType:
"""Return the dtype object for a value or type, optionally canonicalized based on X64 mode."""
if x is None:
raise ValueError(f"Invalid argument to dtype: {x}.")
is_type = isinstance(x, type)
if is_type and x in python_scalar_dtypes:
dt = python_scalar_dtypes[x]
elif type(x) in python_scalar_dtypes:
dt = python_scalar_dtypes[type(x)]
elif is_type and _issubclass(x, np.generic):
return np.dtype(x)
elif issubdtype(getattr(x, 'dtype', None), extended):
dt = x.dtype
else:
try:
dt = np.result_type(x)
except TypeError as err:
raise TypeError(f"Cannot determine dtype of {x}") from err
if dt not in _jax_dtype_set and not issubdtype(dt, extended):
raise TypeError(f"Value '{x}' with dtype {dt} is not a valid JAX array "
"type. Only arrays of numeric types are supported by JAX.")
# TODO(jakevdp): fix return type annotation and remove this ignore.
return canonicalize_dtype(dt, allow_extended_dtype=True) if canonicalize else dt # type: ignore[return-value]
def _lattice_result_type(*args: Any) -> tuple[DType, bool]:
dtypes, weak_types = zip(*(_dtype_and_weaktype(arg) for arg in args))
if len(dtypes) == 1:
out_dtype = dtypes[0]
out_weak_type = weak_types[0]
elif len(set(dtypes)) == 1 and not all(weak_types):
# Trivial promotion case. This allows extended dtypes through.
out_dtype = dtypes[0]
out_weak_type = False
elif all(weak_types) and config.numpy_dtype_promotion.value != 'strict':
# If all inputs are weakly typed, we compute the bound of the strongly-typed
# counterparts and apply the weak type at the end. This avoids returning the
# incorrect result with non-canonical weak types (e.g. weak int16).
# TODO(jakevdp): explore removing this special case.
result_type = _least_upper_bound(config.numpy_dtype_promotion.value,
*{_jax_type(dtype, False) for dtype in dtypes})
out_dtype = dtype(result_type)
out_weak_type = True
else:
result_type = _least_upper_bound(config.numpy_dtype_promotion.value,
*{_jax_type(d, w) for d, w in zip(dtypes, weak_types)})
out_dtype = dtype(result_type)
out_weak_type = any(result_type is t for t in _weak_types)
return out_dtype, (out_dtype != bool_) and out_weak_type
@overload
def result_type(*args: Any, return_weak_type_flag: Literal[True]) -> tuple[DType, bool]: ...
@overload
def result_type(*args: Any, return_weak_type_flag: Literal[False] = False) -> DType: ...
@overload
def result_type(*args: Any, return_weak_type_flag: bool = False) -> DType | tuple[DType, bool]: ...
@export
def result_type(*args: Any, return_weak_type_flag: bool = False) -> DType | tuple[DType, bool]:
"""Convenience function to apply JAX argument dtype promotion.
Args:
return_weak_type_flag : if True, then return a ``(dtype, weak_type)`` tuple.
If False, just return `dtype`
Returns:
dtype or (dtype, weak_type) depending on the value of the ``return_weak_type`` argument.
"""
if len(args) == 0:
raise ValueError("at least one array or dtype is required")
dtype: DType | ExtendedDType
dtype, weak_type = _lattice_result_type(*(float_ if arg is None else arg for arg in args))
if weak_type:
dtype = canonicalize_dtype(
_default_types['f' if dtype in _custom_float_dtypes else dtype.kind])
else:
dtype = canonicalize_dtype(dtype, allow_extended_dtype=True)
# TODO(jakevdp): fix return type annotation and remove this ignore.
return (dtype, weak_type) if return_weak_type_flag else dtype # type: ignore[return-value]
def check_user_dtype_supported(dtype, fun_name=None):
if issubdtype(dtype, extended):
return
# Avoid using `dtype in [...]` because of numpy dtype equality overloading.
if isinstance(dtype, type) and dtype in {bool, int, float, builtins.complex}:
return
np_dtype = np.dtype(dtype)
is_custom_dtype = np_dtype.type in [*_custom_float_scalar_types, int4, uint4]
if np_dtype.kind not in "biufc" and not is_custom_dtype:
msg = f"JAX only supports number and bool dtypes, got dtype {dtype}"
msg += f" in {fun_name}" if fun_name else ""
raise TypeError(msg)
if dtype is not None and np_dtype != canonicalize_dtype(dtype):
msg = ("Explicitly requested dtype {} {} is not available, "
"and will be truncated to dtype {}. To enable more dtypes, set the "
"jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell "
"environment variable. "
"See https://github.com/google/jax#current-gotchas for more.")
fun_name = f"requested in {fun_name}" if fun_name else ""
truncated_dtype = canonicalize_dtype(dtype).name
warnings.warn(msg.format(dtype, fun_name, truncated_dtype), stacklevel=3)
def safe_to_cast(input_dtype_or_value: Any,
output_dtype_or_value: Any) -> bool:
"""Check if a dtype/value is safe to cast to another dtype/value
Args:
input_dtype_or_value: a dtype or value (to be passed to result_type)
representing the source dtype.
output_dtype_or_value: a dtype or value (to be passed to result_type)
representing the target dtype.
Returns:
boolean representing whether the values are safe to cast according to
default type promotion semantics.
Raises:
TypePromotionError: if the inputs have differing types and no type promotion
path under the current jax_numpy_dtype_promotion setting.
Examples:
>>> safe_to_cast('int32', 'float64')
True
>>> safe_to_cast('float64', 'int32')
False
>>> safe_to_cast('float32', 'complex64')
True
>>> safe_to_cast('complex64', 'float64')
False
"""
input_dtype = dtype(input_dtype_or_value, canonicalize=True)
output_dtype = dtype(output_dtype_or_value, canonicalize=True)
if input_dtype == output_dtype:
return True
# We deliberately use output_dtype rather than output_dtype_or_value here:
# this effectively treats the output dtype as always strongly-typed.
return result_type(input_dtype_or_value, output_dtype) == output_dtype