jax.lax.bitcast_convert_type
- jax.lax.bitcast_convert_type(operand, new_dtype)
Elementwise bitcast.
Wraps XLA’s BitcastConvertType operator, which reinterprets the bits of each element as the new type instead of converting its numeric value.
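The difference between a bitcast and an ordinary value conversion is easiest to see side by side. The following is a minimal sketch that assumes a standard JAX install in its default 32-bit mode. The variable names are illustrative only, and NumPy's `ndarray.view` is used purely as a host-side cross-check of the same reinterpretation.

```python
import numpy as np
import jax.numpy as jnp
from jax import lax

x = jnp.array([1.0, -2.0, 0.5], dtype=jnp.float32)

# Bitcast: reinterpret each element's 32 bits as an int32; no rounding happens.
bits = lax.bitcast_convert_type(x, jnp.int32)

# Value conversion, for contrast: each float is truncated toward zero.
converted = x.astype(jnp.int32)

# NumPy's ndarray.view performs the same bit reinterpretation on the host.
host_bits = np.asarray(x).view(np.int32)

# Bitcasting back to float32 recovers the original values exactly.
restored = lax.bitcast_convert_type(bits, jnp.float32)

print(bits.tolist())       # [1065353216, -1073741824, 1056964608]
print(converted.tolist())  # [1, -2, 0]
```

Here 1065353216 is 0x3F800000, the IEEE 754 single-precision encoding of 1.0. `astype` returns the nearest integer value instead.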
The output shape depends on the size of the input and output dtypes with the following logic:
    if new_dtype.itemsize == operand.dtype.itemsize:
      output_shape = operand.shape
    if new_dtype.itemsize < operand.dtype.itemsize:
      output_shape = (*operand.shape, operand.dtype.itemsize // new_dtype.itemsize)
    if new_dtype.itemsize > operand.dtype.itemsize:
      assert operand.shape[-1] * operand.dtype.itemsize == new_dtype.itemsize
      output_shape = operand.shape[:-1]
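The three shape rules can be checked directly against the operator. The following is a minimal sketch, assuming a standard JAX install. `predicted_output_shape` is a hypothetical helper that restates the logic above in plain Python and is not part of the JAX API. New dtypes are passed as NumPy types, as the parameter description suggests.

```python
import numpy as np
import jax.numpy as jnp
from jax import lax


def predicted_output_shape(shape, in_itemsize, out_itemsize):
    """Hypothetical helper: restates the documented shape rule in plain Python."""
    if out_itemsize == in_itemsize:
        return tuple(shape)
    if out_itemsize < in_itemsize:
        # Narrowing: one new trailing axis holds the pieces of each input element.
        return (*shape, in_itemsize // out_itemsize)
    # Widening: the last axis must hold exactly enough elements for one output element.
    assert shape[-1] * in_itemsize == out_itemsize
    return tuple(shape[:-1])


cases = [
    (jnp.zeros((2, 3), jnp.float32), np.int32),       # 4 -> 4 bytes: shape unchanged
    (jnp.zeros((2, 3), jnp.float32), np.uint8),       # 4 -> 1 byte: trailing axis of size 4 added
    (jnp.zeros((2, 3, 2), jnp.float16), np.float32),  # 2 -> 4 bytes: trailing axis of size 2 removed
]

for x, new_dtype in cases:
    y = lax.bitcast_convert_type(x, new_dtype)
    expected = predicted_output_shape(x.shape, x.dtype.itemsize, np.dtype(new_dtype).itemsize)
    print(x.shape, x.dtype, "->", y.shape, y.dtype)
    assert y.shape == expected
```

The loop should report output shapes (2, 3), (2, 3, 4) and (2, 3). In the widening case the last input axis must have exactly `new_dtype.itemsize // operand.dtype.itemsize` elements. A `(2, 3)` float16 operand, for example, cannot be bitcast to float32 and is expected to be rejected.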
- Parameters:
operand (ArrayLike) – an array or scalar value to be cast
new_dtype (DTypeLike) – the new type. Should be a NumPy type.
- Returns:
An array of shape output_shape (see above) and type new_dtype, constructed from the same bits as operand.
- Return type: