jax.numpy.result_type



jax.numpy.result_type(*args)

Returns the type that results from applying the NumPy type promotion rules to the arguments.

LAX-backend implementation of numpy.result_type().

Original docstring below.
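A minimal sketch of calling the function on dtypes, assuming a default JAX installation. It also calls numpy.result_type for comparison, because JAX applies its own promotion table, which does not always match NumPy's: for int32 combined with float32, NumPy widens to float64 while JAX keeps float32.

```python
import jax.numpy as jnp
import numpy as np

# Promotion within one kind: the wider integer type wins.
print(jnp.result_type(np.int8, np.int16))     # int16

# Mixed integer/float inputs: JAX's promotion table keeps the float width,
# whereas NumPy widens to float64 so every int32 value is exactly representable.
print(jnp.result_type(np.int32, np.float32))  # float32
print(np.result_type(np.int32, np.float32))   # float64
```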

Type promotion in NumPy works similarly to the rules in languages like C++, with some slight differences. When both scalars and arrays are used, the array’s type takes precedence and the actual value of the scalar is taken into account.

For example, calculating 3*a, where a is an array of 32-bit floats, intuitively should produce a 32-bit float output. If the 3 were treated as a 32-bit integer, the NumPy rules would say that it cannot be converted losslessly into a 32-bit float, so the result type would be a 64-bit float. By examining the value of the constant 3, however, NumPy finds that it fits in an 8-bit integer, which can be cast losslessly to a 32-bit float.
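The 3*a case can be checked directly. This sketch assumes a default JAX installation. Note that JAX reaches the float32 answer by a different route than the value inspection described above: a Python scalar is treated as "weakly typed" and adopts the dtype of the array it is combined with, without its value being examined.

```python
import jax.numpy as jnp

a = jnp.ones(4, dtype=jnp.float32)  # an array of 32-bit floats

# The Python int 3 is weakly typed, so it does not widen the result.
print(jnp.result_type(3, a))  # float32

# The dtype of the actual product agrees with result_type.
print((3 * a).dtype)          # float32
```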

Returns:

out – The result type.

Return type:

dtype

Parameters:

args (Any)
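To illustrate the signature, the sketch below passes a mix of argument kinds (an array, a NumPy dtype object, and Python scalars) and checks that the return value is a numpy.dtype instance. The particular combination is chosen only for illustration and assumes a default JAX installation, where 64-bit mode is disabled.

```python
import jax.numpy as jnp
import numpy as np

x = jnp.zeros(3, dtype=jnp.int16)

# An array, a dtype object, and a weakly typed Python int in one call.
out = jnp.result_type(x, np.dtype("int8"), 2)
print(out)                        # int16
print(isinstance(out, np.dtype))  # True

# A Python float promotes an integer array to a floating type;
# with 64-bit mode disabled (the default) that type is float32.
print(jnp.result_type(x, 1.5))
```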