jax.numpy.from_dlpack

jax.numpy.from_dlpack(x, /, *, device=None, copy=None)

Create a NumPy array from an object implementing the __dlpack__ protocol.

LAX-backend implementation of numpy.from_dlpack().

Note

While JAX arrays are always immutable, DLPack buffers cannot be marked as immutable, and processes external to JAX may mutate them in place. If a JAX array is constructed from a DLPack buffer and the buffer is later modified in place, using the associated JAX array may lead to undefined behavior.

Original docstring below.

Generally, the returned NumPy array is a read-only view of the input object. See [1] and [2] for more details.
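The exchange described above can be sketched as follows. This is a minimal example, assuming a CPU backend and a NumPy release whose arrays implement __dlpack__ and __dlpack_device__. On the JAX side the value returned is a jax.Array rather than a numpy.ndarray, and a JAX array can itself act as a DLPack producer.

```python
import jax
import jax.numpy as jnp
import numpy as np

# A NumPy array is a DLPack producer: it exposes __dlpack__ and __dlpack_device__.
src = np.arange(6, dtype=np.float32).reshape(2, 3)

# Import the buffer into JAX through the DLPack protocol.
out = jnp.from_dlpack(src)

# A JAX array also implements the protocol, so it can be fed back in.
again = jnp.from_dlpack(out)

print(type(out), out.shape, out.dtype)
print(again)
```

The source array is never modified after import here, which keeps the example clear of the in-place mutation hazard described in the note.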

Parameters:

x (object) – A Python object that implements the __dlpack__ and __dlpack_device__ methods.
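Any object that provides these two methods can be passed as x. The sketch below uses a hypothetical DLPackWrapper class (not part of JAX or NumPy) that simply forwards both methods to a wrapped NumPy array; it assumes a CPU backend. Forwarding *args and **kwargs keeps the wrapper agnostic to which keyword arguments (such as stream) the consumer chooses to pass.

```python
import jax.numpy as jnp
import numpy as np


class DLPackWrapper:
    """Hypothetical producer that re-exports another array's DLPack interface."""

    def __init__(self, data):
        self._data = data

    def __dlpack__(self, *args, **kwargs):
        # Forward whatever arguments the consumer passes (e.g. stream=...).
        return self._data.__dlpack__(*args, **kwargs)

    def __dlpack_device__(self):
        # Report where the underlying buffer lives as (device_type, device_id).
        return self._data.__dlpack_device__()


wrapped = DLPackWrapper(np.linspace(0.0, 1.0, 5, dtype=np.float32))
out = jnp.from_dlpack(wrapped)
print(out)
```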

Returns:

out

Return type:

ndarray

References

Parameters:
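  • device (xc.Device | Sharding | None) – optional device or sharding on which to place the result. If None, the array stays on the device the DLPack buffer came from.

  • copy (bool | None) – if True, always copy; if False, never copy, and raise an error if a copy would be required (for example, to move the data to another device); if None, copy only when necessary.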
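One way to sidestep the in-place mutation hazard from the note is to request an explicit copy. The sketch below assumes a JAX version whose from_dlpack accepts the device and copy keywords listed above, and a NumPy source buffer on the CPU. Because JAX may dispatch the copy asynchronously, it waits for the result before touching the source buffer.

```python
import jax
import jax.numpy as jnp
import numpy as np

buf = np.zeros(4, dtype=np.float32)
cpu = jax.devices("cpu")[0]

# Place the result on the CPU device and force a copy, so the JAX array
# does not keep aliasing the producer's buffer.
out = jnp.from_dlpack(buf, device=cpu, copy=True)

# The copy may run asynchronously; wait for it before mutating buf.
out.block_until_ready()

# Mutating the producer's buffer afterwards leaves the copy untouched.
buf[0] = 100.0
print(out, out.devices())
```

With copy=None the import may instead alias buf, in which case the mutation above would fall under the undefined behavior the note warns about.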