jax.Array.at

abstract property Array.at
Helper property for index update functionality.

The `at` property provides a functionally pure equivalent of in-place array modifications. In particular:
| Alternate syntax | Equivalent in-place expression |
|---|---|
| `x = x.at[idx].set(y)` | `x[idx] = y` |
| `x = x.at[idx].add(y)` | `x[idx] += y` |
| `x = x.at[idx].subtract(y)` | `x[idx] -= y` |
| `x = x.at[idx].multiply(y)` | `x[idx] *= y` |
| `x = x.at[idx].divide(y)` | `x[idx] /= y` |
| `x = x.at[idx].power(y)` | `x[idx] **= y` |
| `x = x.at[idx].min(y)` | `x[idx] = minimum(x[idx], y)` |
| `x = x.at[idx].max(y)` | `x[idx] = maximum(x[idx], y)` |
| `x = x.at[idx].apply(ufunc)` | `ufunc.at(x, idx)` |
| `x = x.at[idx].get()` | `x = x[idx]` |
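The sketch below checks a few rows of the table against NumPy. It assumes NumPy is installed alongside JAX. The arrays, the index `idx`, and the update value `y` are arbitrary illustrative choices. In each case the `.at` form returns an array equal to what the NumPy in-place form leaves behind.

```python
import numpy as np
import jax.numpy as jnp

idx = 1
y = 3.0

# "add" row: NumPy mutates its buffer; JAX returns a new array that we rebind.
x_np = np.arange(5.0)
x_np[idx] += y
x_jax = jnp.arange(5.0).at[idx].add(y)
print(np.allclose(x_np, x_jax))  # True

# "min" row: x[idx] = minimum(x[idx], y), here with a slice as the index.
m_np = np.arange(5.0)
m_np[2:] = np.minimum(m_np[2:], 2.5)
m_jax = jnp.arange(5.0).at[2:].min(2.5)
print(np.allclose(m_np, m_jax))  # True

# "apply" row: ufunc.at(x, idx) versus x.at[idx].apply(ufunc).
# JAX needs an array (not a Python list) for integer-array indexing.
a_np = np.arange(5.0)
np.negative.at(a_np, [0, 3])
a_jax = jnp.arange(5.0).at[jnp.array([0, 3])].apply(jnp.negative)
print(np.allclose(a_np, a_jax))  # True
```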
None of the `x.at` expressions modify the original `x`; instead they return a modified copy of `x`. However, inside a `jit()`-compiled function, expressions like `x = x.at[idx].set(y)` are guaranteed to be applied in-place.

Unlike NumPy in-place operations such as `x[idx] += y`, if multiple indices refer to the same location, all updates will be applied (NumPy would apply only the last update rather than all of them). The order in which conflicting updates are applied is implementation-defined and may be nondeterministic (e.g., due to concurrency on some hardware platforms).

By default, JAX assumes that all indices are in bounds. Alternative out-of-bounds index semantics can be specified via the `mode` parameter (see below).

Parameters:
- `mode` (str) – Specify the out-of-bounds indexing mode. Options are:
  - `"promise_in_bounds"`: (default) The user promises that indices are in bounds. No additional checking will be performed. In practice, this means that out-of-bounds indices in `get()` will be clipped, and out-of-bounds indices in `set()`, `add()`, etc. will be dropped.
  - `"clip"`: clamp out-of-bounds indices into the valid range.
  - `"drop"`: ignore out-of-bounds indices.
  - `"fill"`: alias for `"drop"`. For `get()`, the optional `fill_value` argument specifies the value that will be returned.

  See `jax.lax.GatherScatterMode` for more details.
- `indices_are_sorted` (bool) – If True, the implementation will assume that the indices passed to `at[]` are sorted in ascending order, which can lead to more efficient execution on some backends.
- `unique_indices` (bool) – If True, the implementation will assume that the indices passed to `at[]` are unique, which can result in more efficient execution on some backends.
- `fill_value` (Any) – Only applies to the `get()` method: the fill value to return for out-of-bounds slices when `mode` is `'fill'`. Ignored otherwise. Defaults to `NaN` for inexact types, the most negative representable value for signed integer types, the largest representable value for unsigned integer types, and `True` for booleans.
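Here is a minimal sketch of how `mode`, `fill_value`, and the two hint flags behave on integer and boolean arrays. It assumes JAX's default 32-bit mode, so `jnp.arange(5)` is `int32`. The out-of-bounds index 10 is arbitrary. Pass `True` for the hint flags only when the indices really are sorted and unique, because the implementation assumes this rather than checking it.

```python
import jax.numpy as jnp

x = jnp.arange(5)  # int32 unless 64-bit mode is enabled

# Scatter-style updates with an out-of-bounds index (10 for a length-5 array).
dropped = x.at[10].set(99)               # default "promise_in_bounds": in practice dropped
clipped = x.at[10].set(99, mode="clip")  # index clamped to 4, the last valid position
print(dropped)  # [0 1 2 3 4]
print(clipped)  # [ 0  1  2  3 99]

# Gather with mode="fill" and no fill_value: the default depends on the dtype.
print(x.at[10].get(mode="fill"))                   # minimum of the signed integer dtype
print(x.astype(jnp.uint8).at[10].get(mode="fill"))  # maximum of uint8, i.e. 255
print((x > 2).at[10].get(mode="fill"))             # True for booleans

# Hint flags: these indices genuinely are sorted and unique.
idx = jnp.array([0, 2, 4])
hinted = x.at[idx].add(1, indices_are_sorted=True, unique_indices=True)
print(hinted)  # [1 1 3 3 5]
```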
Examples
```python
>>> import jax.numpy as jnp
>>> x = jnp.arange(5.0)
>>> x
Array([0., 1., 2., 3., 4.], dtype=float32)
>>> x.at[2].add(10)
Array([ 0.,  1., 12.,  3.,  4.], dtype=float32)
>>> x.at[10].add(10)  # out-of-bounds indices are ignored
Array([0., 1., 2., 3., 4.], dtype=float32)
>>> x.at[20].add(10, mode='clip')
Array([ 0.,  1.,  2.,  3., 14.], dtype=float32)
>>> x.at[2].get()
Array(2., dtype=float32)
>>> x.at[20].get()  # out-of-bounds indices clipped
Array(4., dtype=float32)
>>> x.at[20].get(mode='fill')  # out-of-bounds indices filled with NaN
Array(nan, dtype=float32)
>>> x.at[20].get(mode='fill', fill_value=-1)  # custom fill value
Array(-1., dtype=float32)
```
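The following sketch illustrates two points made earlier on this page. With repeated indices, NumPy's `x[idx] += y` applies only one increment per position, while `x.at[idx].add(y)` applies all of them. Also, the source array is left untouched. It assumes NumPy is installed. The function name `bump_first` is hypothetical. The `jit` part only shows that the rebinding pattern gives the same result under compilation. It does not inspect whether XLA actually reused the buffer.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Index 0 appears three times.
idx = np.array([0, 0, 0, 2])

# NumPy: buffered in-place add, so position 0 is incremented only once.
x_np = np.zeros(4)
x_np[idx] += 1.0
print(x_np)  # [1. 0. 1. 0.]

# JAX: every update is applied, so position 0 receives all three increments.
x = jnp.zeros(4)
y = x.at[idx].add(1.0)
print(y)  # [3. 0. 1. 0.]

# Purity: the original array is unchanged; y is a separate array.
print(x)  # [0. 0. 0. 0.]

# Hypothetical helper: the same rebinding pattern inside a jit-compiled function.
@jax.jit
def bump_first(v):
    v = v.at[0].add(1.0)
    return v

print(bump_first(x))  # [1. 0. 0. 0.]
```

If `.set()` were used instead of `.add()` with repeated indices, the competing writes would race. Which value survives is implementation-defined, as described above.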