jax.numpy.copy
- jax.numpy.copy(a, order=None)
Return a copy of the array.
JAX implementation of numpy.copy().
- Parameters:
a (ArrayLike) – array-like object to copy
order (str | None) – not implemented in JAX
- Returns:
a copy of the input array a.
- Return type:
Array
See also
jax.numpy.array(): create an array with or without a copy.
jax.Array.copy(): same function accessed as an array method.
Examples
Since JAX arrays are immutable, explicit array copies are unnecessary in most cases. One exception is when calling a function with donated arguments (see the donate_argnums argument to jax.jit()).
>>> f = jax.jit(lambda x: 2 * x, donate_argnums=0)
>>> x = jnp.arange(4)
>>> y = f(x)
>>> print(y)
[0 2 4 6]
Because we marked x as being donated, the original array is no longer available:
>>> print(x)
Traceback (most recent call last):
RuntimeError: Array has been deleted with shape=int32[4].
In situations like this, an explicit copy will let you keep access to the original buffer:
>>> x = jnp.arange(4)
>>> y = f(x.copy())
>>> print(y)
[0 2 4 6]
>>> print(x)
[0 1 2 3]
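The same pattern can be written as a standalone script. The sketch below wraps the copy in a small helper, call_preserving_input, which is a hypothetical name introduced here for illustration and not part of JAX. It uses jax.Array.is_deleted() to confirm that the caller's array survives. Whether a donated buffer is actually consumed can depend on the backend, so the script only checks the array that was not donated.

```python
import jax
import jax.numpy as jnp

# Same donated function as in the example above.
f = jax.jit(lambda x: 2 * x, donate_argnums=0)


def call_preserving_input(fn, x):
    """Hypothetical helper: donate a copy so the caller's array stays valid."""
    return fn(jnp.copy(x))


x = jnp.arange(4)
y = call_preserving_input(f, x)

print(y)               # result of the donated computation
print(x.is_deleted())  # the original was never donated, only its copy
print(x)               # still readable
```

The cost of this approach is one extra buffer per call, so it is worth reserving for cases where the original array is genuinely needed afterwards.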