jax.Array.astype
- abstract Array.astype(dtype, copy=False, device=None)
Copy the array and cast to a specified dtype.
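The following is a minimal usage sketch. It assumes a standard JAX installation and uses only explicit 32-bit dtypes, so the result does not depend on whether 64-bit mode is enabled. It casts an integer array to `float32` and shows that `astype` returns a new array while the original keeps its dtype.

```python
import jax.numpy as jnp

# Integer source array with an explicit dtype.
x = jnp.array([0, 1, 2], dtype=jnp.int32)

# astype returns a new array with the requested dtype; x itself is unchanged.
y = x.astype(jnp.float32)

print(y.dtype)     # float32
print(x.dtype)     # int32
print(y.tolist())  # [0.0, 1.0, 2.0]
```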
This is implemented via jax.lax.convert_element_type(), which may have slightly different behavior than numpy.ndarray.astype() in some cases. In particular, the details of float-to-int and int-to-float casts are implementation dependent.
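As an illustrative sketch, the snippet below compares `astype` against a direct call to `jax.lax.convert_element_type` and against NumPy's `ndarray.astype`. It deliberately uses only floats that hold exact integer values, where all three should agree. Inputs such as NaN, infinities, or values outside the target integer range are where behavior is implementation dependent, so the sketch makes no claims about them.

```python
import numpy as np
import jax.numpy as jnp
from jax import lax

# Floats holding exact integer values: a well-defined float-to-int cast.
x = jnp.array([0.0, 1.0, 2.0, -3.0], dtype=jnp.float32)

a = x.astype(jnp.int32)                        # method form
b = lax.convert_element_type(x, jnp.int32)     # underlying lax primitive
c = np.asarray(x).astype(np.int32)             # NumPy reference

print(a.tolist())  # [0, 1, 2, -3]

# For in-range, integer-valued inputs the three results should match.
# Avoid relying on this agreement for NaN or out-of-range values.
print(bool((a == b).all()), a.tolist() == c.tolist())  # True True
```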