jax.numpy.astype

jax.numpy.astype(x, dtype, /, *, copy=False, device=None)

Convert an array to a specified dtype.

This is implemented via jax.lax.convert_element_type(), which may behave slightly differently from numpy.astype() in some cases. In particular, the details of float-to-int and int-to-float casts are implementation-dependent.

Parameters:
  • x (ArrayLike)
  • dtype (DTypeLike | None)
  • copy (bool)
  • device (xc.Device | Sharding | None)

Return type:

Array
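
A minimal usage sketch of the signature above, assuming a default JAX configuration (64-bit types not enabled). The variable names are illustrative only. The values are chosen to be small and exactly representable in both dtypes, so the result should not depend on the implementation-dependent casting details noted above.

```python
import jax.numpy as jnp

# Illustrative input: small integers that float32 can represent exactly.
x = jnp.array([0, 1, 2, 3], dtype=jnp.int32)

# int32 -> float32; x is positional-only, dtype is the second positional argument.
y = jnp.astype(x, jnp.float32)

# float32 -> int32 with whole-number, in-range values.
z = jnp.astype(jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32), jnp.int32)

# copy is keyword-only; copy=True requests a new array even if no cast is needed.
w = jnp.astype(x, jnp.int32, copy=True)

print(y.dtype, z.dtype, w.dtype)
```

For fractional, out-of-range, or non-finite inputs, the result of a float-to-int cast may differ from NumPy's, so it is safer to round or clip explicitly before casting than to rely on a particular behavior.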