jax.numpy.astype

jax.numpy.astype(x, dtype, /, *, copy=True)

Convert an array to a specified dtype. This is implemented via jax.lax.convert_element_type(), which may behave slightly differently from numpy.astype() in some cases. In particular, the details of float-to-int and int-to-float casts are implementation dependent.
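A minimal usage sketch, assuming a JAX release that provides jnp.astype and the default 32-bit mode. It shows a float-to-int cast followed by an int-to-float cast. Only in-range, non-NaN inputs are used, because the exact results of other float-to-int casts are implementation dependent. The variable names are illustrative only.

```python
import jax.numpy as jnp

# A float array (float32 by default when 64-bit mode is off).
x = jnp.array([1.7, -2.5, 3.0])

# Float-to-int: for these in-range values the fractional part is dropped
# (truncation toward zero). NaN and out-of-range inputs are deliberately
# avoided, since their results are implementation dependent.
y = jnp.astype(x, jnp.int32)

# Int-to-float: small integers convert exactly, but the fractional parts
# dropped by the previous cast are not recovered.
z = jnp.astype(y, jnp.float32)

# Each call returns an array with the requested dtype; x itself is not
# modified, since JAX arrays are immutable.
print(x.dtype, y.dtype, z.dtype)
print(y)
print(z)
```

The round trip through int32 makes the lossy float-to-int step visible: z holds whole numbers, not the original values of x.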

Parameters:
  • x (ArrayLike) – the array or scalar to convert.

  • dtype (DTypeLike | None) – the data type to convert x to.

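  • copy (bool) – keyword-only flag from the array API signature; True (the default) requests a new array rather than the input itself.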

Return type:

Array