jax.numpy.asarray

jax.numpy.asarray(a, dtype=None, order=None, *, copy=None, device=None)

Convert an object to a JAX array.

JAX implementation of numpy.asarray().
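As a rough illustration of how this relates to numpy.asarray(), the sketch below converts the same nested list with both functions. It assumes a default JAX installation and only checks that the values and shape agree while the result types differ.

```python
import numpy as np
import jax
import jax.numpy as jnp

data = [[1.0, 2.0], [3.0, 4.0]]

np_arr = np.asarray(data)    # NumPy ndarray
jax_arr = jnp.asarray(data)  # JAX Array holding the same values

print(type(np_arr).__name__)           # ndarray
print(isinstance(jax_arr, jax.Array))  # True
print(jax_arr.shape == np_arr.shape)   # True
print(np.allclose(np_arr, jax_arr))    # True
```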

Parameters:
  • a (Any) – an object that is convertible to an array. This includes JAX arrays, NumPy arrays, Python scalars, Python collections like lists and tuples, objects with an __array__ method, and objects supporting the Python buffer protocol.

  • dtype (DTypeLike | None) – optionally specify the dtype of the output array. If not specified, it will be inferred from the input.

  • order (str | None) – not implemented in JAX.

  • copy (bool | None) – optional boolean specifying the copy mode. If True, always return a copy. If False, raise an error if a copy is necessary. The default, None, copies only when necessary.

  • device (xc.Device | Sharding | None) – optional Device or Sharding to which the created array will be committed.

Returns:

A JAX array constructed from the input.

Return type:

Array

See also

Examples

Constructing JAX arrays from Python scalars:

>>> import jax.numpy as jnp
>>> jnp.asarray(True)
Array(True, dtype=bool)
>>> jnp.asarray(42)
Array(42, dtype=int32, weak_type=True)
>>> jnp.asarray(3.5)
Array(3.5, dtype=float32, weak_type=True)
>>> jnp.asarray(1 + 1j)
Array(1.+1.j, dtype=complex64, weak_type=True)
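The weak_type=True in the outputs above means the scalar's dtype yields to the other operand during type promotion. The sketch below, assuming JAX's default promotion rules, contrasts a weakly typed scalar with one whose dtype is given explicitly when each is added to a float16 array.

```python
import jax.numpy as jnp

half = jnp.ones(3, dtype=jnp.float16)

weak = jnp.asarray(3.5)                       # weakly typed float32
strong = jnp.asarray(3.5, dtype=jnp.float32)  # explicit dtype, so not weakly typed

print((weak + half).dtype)    # float16: the weak type defers to float16
print((strong + half).dtype)  # float32: a strongly typed float32 wins promotion
```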

Constructing JAX arrays from Python collections:

>>> jnp.asarray([1, 2, 3])  # list of ints -> 1D array
Array([1, 2, 3], dtype=int32)
>>> jnp.asarray([(1, 2, 3), (4, 5, 6)])  # list of tuples of ints -> 2D array
Array([[1, 2, 3],
       [4, 5, 6]], dtype=int32)
>>> jnp.asarray(range(5))
Array([0, 1, 2, 3, 4], dtype=int32)
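Two further collection cases, sketched under JAX's default configuration: a list mixing Python ints and floats is promoted to a single floating dtype, and a list of equally shaped JAX arrays is stacked along a new leading axis.

```python
import jax.numpy as jnp

# Mixed ints and floats: elements are promoted to a common dtype.
a = jnp.asarray([1, 2.5, 3])
print(a.dtype)  # float32

# Equally shaped JAX arrays are stacked into one higher-dimensional array.
b = jnp.asarray([jnp.zeros(2), jnp.ones(2)])
print(b.shape)  # (2, 2)
```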

Constructing JAX arrays from NumPy arrays:

>>> import numpy as np
>>> jnp.asarray(np.linspace(0, 2, 5))
Array([0. , 0.5, 1. , 1.5, 2. ], dtype=float32)
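The float32 result above comes from a float64 NumPy input. The sketch below makes that explicit, assuming 64-bit mode is disabled (JAX's default), and converts the result back to NumPy with np.asarray.

```python
import numpy as np
import jax.numpy as jnp

np_values = np.linspace(0, 2, 5)
print(np_values.dtype)  # float64

# With 64-bit mode disabled, float64 input is stored as float32.
jax_values = jnp.asarray(np_values)
print(jax_values.dtype)  # float32

# np.asarray converts the JAX array back into a NumPy ndarray.
back = np.asarray(jax_values)
print(type(back).__name__, back.dtype)  # ndarray float32
```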

Constructing a JAX array via the Python buffer protocol, using Python’s built-in array module:

>>> from array import array
>>> pybuffer = array('i', [2, 3, 5, 7])
>>> jnp.asarray(pybuffer)
Array([2, 3, 5, 7], dtype=int32)
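The element type of a buffer comes from its typecode. This sketch assumes the common platform sizes for C double (64-bit) and C short (16-bit), and JAX's default of 64-bit mode disabled.

```python
from array import array
import jax.numpy as jnp

doubles = array('d', [0.5, 1.5])  # C double: 64-bit float
shorts = array('h', [1, -1])      # C short: 16-bit signed integer on common platforms

print(jnp.asarray(doubles).dtype)  # float32 (64-bit mode disabled)
print(jnp.asarray(shorts).dtype)   # int16
```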