jax.numpy.result_type

jax.numpy.result_type(*args)[source]
Returns the type that results from applying the NumPy type promotion rules to the arguments.

LAX-backend implementation of result_type(). Original docstring below.

result_type(*arrays_and_dtypes)

Type promotion in NumPy works similarly to the rules in languages like C++, with some slight differences. When both scalars and arrays are used, the array’s type takes precedence and the actual value of the scalar is taken into account.

For example, calculating 3*a, where a is an array of 32-bit floats, intuitively should result in a 32-bit float output. If the 3 is a 32-bit integer, the NumPy rules indicate it can’t convert losslessly into a 32-bit float, so a 64-bit float should be the result type. By examining the value of the constant, ‘3’, we see that it fits in an 8-bit integer, which can be cast losslessly into the 32-bit float.
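
The 3*a reasoning above can be checked step by step. The following is a minimal sketch using only `np.promote_types`, `np.can_cast`, `np.min_scalar_type` and `np.result_type`; the variable names are illustrative and not part of the original docstring.

```python
import numpy as np

a = np.ones(4, dtype=np.float32)  # an array of 32-bit floats

# Treating 3 as a 32-bit integer dtype: int32 cannot be cast safely to
# float32, so promoting the two dtypes gives a 64-bit float.
int32_is_safe = np.can_cast(np.int32, np.float32)   # False
as_int32 = np.promote_types(np.int32, np.float32)   # float64

# Looking at the value instead: 3 fits in an 8-bit integer type,
# and that type can be cast safely to float32.
smallest = np.min_scalar_type(3)                      # an 8-bit integer dtype
smallest_is_safe = np.can_cast(smallest, np.float32)  # True

# The Python scalar 3 therefore leaves the array's float32 dtype unchanged.
combined = np.result_type(3, a)   # float32
product_dtype = (3 * a).dtype     # float32

print(as_int32, smallest, combined, product_dtype)
```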

Returns
out : dtype

The result type.

See also
dtype, promote_types, min_scalar_type, can_cast

New in version 1.6.0.

The specific algorithm used is as follows.

Each operand is first assigned a category: boolean, integer (int/uint), or floating point (float/complex). The maximum category is then determined separately for the arrays and for the scalars.

If there are only scalars or the maximum category of the scalars is higher than the maximum category of the arrays, the data types are combined with promote_types() to produce the return value.

Otherwise, min_scalar_type is called on each scalar, and the resulting data types are all combined with promote_types() to produce the return value.
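
The two branches described above can be sketched in a few lines of Python. This is a simplified illustration, not NumPy's or JAX's implementation: `sketch_result_type`, `max_category` and `_RANK` are hypothetical names, array operands are passed as dtypes, `min_scalar_type` is applied to the scalar operands (whose values are what the rules inspect), and the int/uint special case is left out. Newer NumPy releases have changed how scalars are promoted, so `np.result_type` itself may not match this sketch in every case.

```python
from functools import reduce

import numpy as np

# Category ranks from the text: boolean < integer < floating point.
_RANK = {"b": 0, "i": 1, "u": 1, "f": 2, "c": 2}


def max_category(dtypes):
    """Highest category among the dtypes, or -1 if there are none."""
    return max((_RANK[dt.kind] for dt in dtypes), default=-1)


def sketch_result_type(array_dtypes, scalars):
    """Hypothetical helper mirroring the described algorithm (no int/uint special case)."""
    array_dtypes = [np.dtype(dt) for dt in array_dtypes]
    scalar_dtypes = [np.asarray(s).dtype for s in scalars]

    if not array_dtypes or max_category(scalar_dtypes) > max_category(array_dtypes):
        # Only scalars, or the scalars have the higher category:
        # combine the plain dtypes of all operands.
        candidates = array_dtypes + scalar_dtypes
    else:
        # The arrays decide the category: each scalar contributes only
        # the smallest dtype that can hold its value.
        candidates = array_dtypes + [np.min_scalar_type(s) for s in scalars]

    return reduce(np.promote_types, candidates)


print(sketch_result_type(["f4"], [3]))     # float32
print(sketch_result_type([], [3.0, -2]))   # float64
print(sketch_result_type(["i1"], [3.0]))   # float64
```

In the first call the array has the higher category, so the scalar 3 contributes only an 8-bit type and the array's float32 is kept. In the third call the scalar's category is higher, so its full float64 dtype takes part in the promotion.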

The set of int values is not a subset of the uint values for types with the same number of bits, something not reflected in min_scalar_type(), but handled as a special case in result_type.
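
The int/uint special case can be seen by comparing a naive combination of `min_scalar_type` and `promote_types` with what `result_type` returns. A minimal sketch follows; the variable names are illustrative.

```python
import numpy as np

# min_scalar_type picks an unsigned type for a non-negative integer.
unsigned_min = np.min_scalar_type(3)                  # uint8

# Promoting int8 with uint8 needs a wider signed type, because neither
# 8-bit type can hold every value of the other.
naive = np.promote_types(np.int8, unsigned_min)       # int16

# result_type recognizes that the value 3 also fits in int8.
actual = np.result_type(3, np.arange(7, dtype="i1"))  # int8

print(unsigned_min, naive, actual)
```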

Examples
>>> np.result_type(3, np.arange(7, dtype='i1'))
dtype('int8')
>>> np.result_type('i4', 'c8')
dtype('complex128')
>>> np.result_type(3.0, -2)
dtype('float64')