jax.lax.bitcast_convert_type


jax.lax.bitcast_convert_type(operand, new_dtype)

Elementwise bitcast.

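Wraps XLA’s BitcastConvertType operator, which reinterprets the bits of each element as another type without changing them. This differs from a numeric conversion such as astype, which changes the bits in order to preserve the value.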
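To illustrate the difference between reinterpreting bits and converting values, the sketch below bitcasts float32 values to int32 (both 4 bytes, so the shape is unchanged) and contrasts the result with astype. The NumPy view call is used only as an independent reference for the expected bit patterns.

```python
import numpy as np
import jax.numpy as jnp
from jax import lax

x = jnp.array([1.0, -0.0, 2.0], dtype=jnp.float32)

# Same itemsize (4 bytes -> 4 bytes): the IEEE 754 bit patterns are read as int32.
bits = lax.bitcast_convert_type(x, jnp.int32)

# A numeric conversion instead produces the integer *values* 1, 0, 2.
values = x.astype(jnp.int32)

# NumPy's view performs the same reinterpretation on the host, as a cross-check.
reference = np.asarray(x).view(np.int32)

print(bits)
print(values)
print(reference)
```

For example, 1.0 in float32 has the bit pattern 0x3F800000, so its bitcast to int32 is 1065353216, while astype gives 1.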

The output shape depends on the relative item sizes (in bytes) of the input and output dtypes, according to the following logic:

if new_dtype.itemsize == operand.dtype.itemsize:
  output_shape = operand.shape
elif new_dtype.itemsize < operand.dtype.itemsize:
  output_shape = (*operand.shape, operand.dtype.itemsize // new_dtype.itemsize)
else:  # new_dtype.itemsize > operand.dtype.itemsize
  assert operand.shape[-1] * operand.dtype.itemsize == new_dtype.itemsize
  output_shape = operand.shape[:-1]
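The shape rule above can be checked against the real function. In this sketch, expected_bitcast_shape is a hypothetical helper (not part of JAX) that restates the rule in plain Python; each case exercises one of the three branches.

```python
import numpy as np
import jax.numpy as jnp
from jax import lax

def expected_bitcast_shape(shape, old_dtype, new_dtype):
    """Hypothetical helper: the documented shape rule, written as a function."""
    old_size = np.dtype(old_dtype).itemsize
    new_size = np.dtype(new_dtype).itemsize
    if new_size == old_size:
        return tuple(shape)
    elif new_size < old_size:
        # Each input element is split into old_size // new_size smaller elements.
        return (*shape, old_size // new_size)
    else:
        # The trailing axis must hold exactly enough small elements for one new element.
        if shape[-1] * old_size != new_size:
            raise ValueError("last dimension does not match the itemsize ratio")
        return tuple(shape[:-1])

cases = [
    (jnp.zeros((2, 3), np.float32), np.int32),   # equal sizes: shape unchanged
    (jnp.zeros((2, 3), np.float32), np.uint8),   # 4 -> 1 byte: new trailing axis of 4
    (jnp.zeros((5, 2), np.uint16), np.uint32),   # 2 -> 4 bytes: trailing axis of 2 consumed
]

results = []
for operand, new_dtype in cases:
    out = lax.bitcast_convert_type(operand, new_dtype)
    expected = expected_bitcast_shape(operand.shape, operand.dtype, new_dtype)
    results.append((out.shape, expected))
    print(operand.dtype, operand.shape, "->", out.dtype, out.shape)
```

NumPy types are used for new_dtype here to match the parameter description; the helper only mirrors the documented rule and is not how JAX implements the check.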
Parameters:
  • operand (ArrayLike) – an array or scalar value to be cast

  • new_dtype (DTypeLike) – the new type. Should be a NumPy type.
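Because operand may be a scalar, the same rule applies to 0-d inputs. The sketch below assumes a NumPy float32 scalar as input; a plain Python float is avoided because its default precision depends on whether 64-bit mode is enabled, which would change the itemsize and therefore the output shape.

```python
import numpy as np
from jax import lax

# Equal item sizes: a 0-d float32 input gives a 0-d int32 output.
out = lax.bitcast_convert_type(np.float32(1.0), np.int32)

# Smaller target type: a 0-d float32 input gives shape (2,) of uint16.
# The order of the two halves depends on byte order, so only the shape is relied on here.
halves = lax.bitcast_convert_type(np.float32(1.0), np.uint16)

print(out.shape, out.dtype)
print(halves.shape, halves.dtype)
```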

Returns:

An array of shape output_shape (see above) and type new_dtype, constructed from the same bits as operand.
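Since the result is built from the same bits as operand, bitcasting to a smaller type and back should recover the input exactly. This sketch splits float32 values into bytes and reassembles them; it deliberately does not assert the byte order within each element, which is platform-dependent.

```python
import numpy as np
import jax.numpy as jnp
from jax import lax

x = jnp.array([1.0, -2.5, 3.25], dtype=jnp.float32)

# Split each float32 into four uint8 values: shape (3,) -> (3, 4).
as_bytes = lax.bitcast_convert_type(x, jnp.uint8)

# Reassemble: the trailing axis of 4 is consumed, shape (3, 4) -> (3,).
restored = lax.bitcast_convert_type(as_bytes, jnp.float32)

# Only bits were reinterpreted, so the round trip is exact.
print(np.array_equal(np.asarray(restored), np.asarray(x)))
```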

Return type:

Array