jax.Array.astype

abstract Array.astype(dtype, copy=False, device=None)

Copy the array and cast to a specified dtype.
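A minimal usage sketch, assuming a standard JAX installation. `astype` returns a new array with the requested element type and leaves the input unchanged, since JAX arrays are immutable.

```python
import jax.numpy as jnp

# An int32 input array: [0, 1, 2, 3].
x = jnp.arange(4, dtype=jnp.int32)

# Cast to float32; the dtype may also be given as a string such as "float32".
y = x.astype(jnp.float32)

print(x.dtype)  # int32 (the input is not modified)
print(y.dtype)  # float32
```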

This is implemented via jax.lax.convert_element_type(), which may behave slightly differently from numpy.ndarray.astype() in some cases. In particular, the details of float-to-int and int-to-float casts are implementation dependent.
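The sketch below, assuming a standard JAX installation, checks that `astype` agrees with calling `jax.lax.convert_element_type()` directly. It deliberately uses only finite inputs that fit in the target type, where float-to-int conversion truncates toward zero. Inputs such as NaN, infinity, or values outside the target range are where implementation-dependent details are most likely to show up, so results for them are best not relied upon.

```python
import jax.numpy as jnp
from jax import lax

# Finite values that fit comfortably in int32.
f = jnp.array([1.7, -1.7, 2.5], dtype=jnp.float32)

via_astype = f.astype(jnp.int32)
via_lax = lax.convert_element_type(f, jnp.int32)

# Both paths produce the same result; the fractional part is truncated
# toward zero, so 1.7 -> 1 and -1.7 -> -1.
print(via_astype.tolist())                         # [1, -1, 2]
print(bool(jnp.array_equal(via_astype, via_lax)))  # True
```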

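Parameters:
  • self (Array) – the array to convert.

  • dtype (DTypeLike | None) – the target dtype.

  • copy (bool) – if True, always return a new array; if False (the default), copy only when necessary.

  • device (xc.Device | Sharding | None) – optional device or sharding on which to place the result.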

Return type:

Array
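A sketch of the `copy` and `device` keywords, assuming a standard JAX installation. Choosing `jax.devices()[0]` as the target is only an illustrative assumption; on a multi-device machine another entry of `jax.devices()`, or a `Sharding`, could be passed instead.

```python
import jax
import jax.numpy as jnp

x = jnp.ones(3, dtype=jnp.float32)

# Ask for a new array to be returned (copy=True) while changing the dtype.
y = x.astype(jnp.int32, copy=True)

# Place the converted array on a specific device. The first available
# device is used here purely for illustration.
dev = jax.devices()[0]
z = x.astype(jnp.float16, device=dev)

print(y.tolist())   # [1, 1, 1]
print(z.devices())  # a set containing only `dev`
```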