jax.lax package

jax.lax is a library of primitive operations that underpins libraries such as jax.numpy. Transformation rules, such as JVP and batching rules, are typically defined as transformations on jax.lax primitives.
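As a rough illustration of this layering, the sketch below traces a small jax.numpy function with jax.make_jaxpr. It assumes a recent JAX release. The printed program is expressed in terms of lax primitives rather than jax.numpy calls. The function f is a hypothetical example, and the exact jaxpr text varies between JAX versions.

```python
import jax
import jax.numpy as jnp

def f(x):
    # A jax.numpy function; tracing lowers it to jax.lax primitives.
    return jnp.sum(jnp.sin(x) * 2.0)

jaxpr = jax.make_jaxpr(f)(jnp.arange(3.0))
print(jaxpr)  # mentions primitives such as sin, mul and reduce_sum
```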

Many of the primitives are thin wrappers around equivalent XLA operations, described by the XLA operation semantics documentation. In a few cases JAX diverges from XLA, usually to ensure that the set of operations is closed under the operation of JVP and transpose rules.
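One way to observe this closure property, sketched here under the assumption of a recent JAX release, is to trace the derivatives of lax.sin. Forward mode (JVP) and reverse mode (JVP followed by transposition) both produce programs built from lax primitives such as cos and mul. The helper name sin_jvp is hypothetical.

```python
import jax
from jax import lax

def sin_jvp(x, t):
    # Forward-mode derivative of lax.sin at x, applied to the tangent t.
    return jax.jvp(lax.sin, (x,), (t,))

jvp_jaxpr = jax.make_jaxpr(sin_jvp)(0.5, 1.0)
grad_jaxpr = jax.make_jaxpr(jax.grad(lax.sin))(0.5)
print(jvp_jaxpr)   # built from lax primitives: sin, cos and mul
print(grad_jaxpr)  # the reverse-mode program also uses cos
```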

Where possible, prefer to use libraries such as jax.numpy instead of using jax.lax directly. The jax.numpy API follows NumPy, and is therefore more stable and less likely to change than the jax.lax API.

Operators

abs(x) Elementwise absolute value: \(|x|\).
add(x, y) Elementwise addition: \(x + y\).
acos(x) Elementwise arc cosine: \(\mathrm{acos}(x)\).
asin(x) Elementwise arc sine: \(\mathrm{asin}(x)\).
atan(x) Elementwise arc tangent: \(\mathrm{atan}(x)\).
atan2(x, y) Elementwise arc tangent of two variables: \(\mathrm{atan}({x \over y})\), with the quadrant of the result determined by the signs of x and y.
batch_matmul(lhs, rhs[, precision]) Batch matrix multiplication.
bessel_i0e(x) Exponentially scaled modified Bessel function of order 0: \(\mathrm{i0e}(x) = e^{-|x|} \mathrm{i0}(x)\).
bessel_i1e(x) Exponentially scaled modified Bessel function of order 1: \(\mathrm{i1e}(x) = e^{-|x|} \mathrm{i1}(x)\).
betainc(a, b, x) Elementwise regularized incomplete beta integral.
bitcast_convert_type(operand, new_dtype) Elementwise bitcast.
bitwise_not(x) Elementwise NOT: \(\neg x\).
bitwise_and(x, y) Elementwise AND: \(x \wedge y\).
bitwise_or(x, y) Elementwise OR: \(x \vee y\).
bitwise_xor(x, y) Elementwise exclusive OR: \(x \oplus y\).
broadcast(operand, sizes) Broadcasts an array, adding new major dimensions.
broadcasted_iota(dtype, shape, dimension) Convenience wrapper around iota.
broadcast_in_dim(operand, shape, …) Wraps XLA’s BroadcastInDim operator.
ceil(x) Elementwise ceiling: \(\left\lceil x \right\rceil\).
clamp(min, x, max) Elementwise clamp.
collapse(operand, start_dimension, …) Collapses dimensions of an array into a single dimension.
complex(x, y) Elementwise make complex number: \(x + jy\).
concatenate(operands, dimension) Concatenates a sequence of arrays along dimension.
conj(x) Elementwise complex conjugate function: \(\overline{x}\).
conv(lhs, rhs, window_strides, padding[, …]) Convenience wrapper around conv_general_dilated.
convert_element_type(operand, new_dtype) Elementwise cast.
conv_general_dilated(lhs, rhs, …[, …]) General n-dimensional convolution operator, with optional dilation.
conv_with_general_padding(lhs, rhs, …[, …]) Convenience wrapper around conv_general_dilated.
conv_transpose(lhs, rhs, strides, padding[, …]) Convenience wrapper for calculating the N-d convolution “transpose”.
cos(x) Elementwise cosine: \(\mathrm{cos}(x)\).
cosh(x) Elementwise hyperbolic cosine: \(\mathrm{cosh}(x)\).
digamma(x) Elementwise digamma: \(\psi(x)\).
div(x, y) Elementwise division: \(x \over y\).
dot(lhs, rhs[, precision]) Vector/vector, matrix/vector, and matrix/matrix multiplication.
dot_general(lhs, rhs, dimension_numbers[, …]) More general contraction operator.
dynamic_index_in_dim(operand, index[, axis, …]) Convenience wrapper around dynamic_slice to perform int indexing.
dynamic_slice(operand, start_indices, …) Wraps XLA’s DynamicSlice operator.
dynamic_slice_in_dim(operand, start_index, …) Convenience wrapper around dynamic_slice applying to one dimension.
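dynamic_update_index_in_dim(operand, update, …) Convenience wrapper around dynamic_update_slice to update a slice of size 1 in a single axis.
dynamic_update_slice_in_dim(operand, update, …) Convenience wrapper around dynamic_update_slice to update a slice in a single axis.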
eq(x, y) Elementwise equals: \(x = y\).
erf(x) Elementwise error function: \(\mathrm{erf}(x)\).
erfc(x) Elementwise complementary error function: \(\mathrm{erfc}(x) = 1 - \mathrm{erf}(x)\).
erf_inv(x) Elementwise inverse error function: \(\mathrm{erf}^{-1}(x)\).
exp(x) Elementwise exponential: \(e^x\).
expm1(x) Elementwise \(e^{x} - 1\).
fft(x, fft_type, fft_lengths)
floor(x) Elementwise floor: \(\left\lfloor x \right\rfloor\).
full(shape, fill_value[, dtype]) Returns an array of the given shape filled with fill_value.
full_like(x, fill_value[, dtype, shape]) Create a full array like np.full based on the example array x.
gather(operand, start_indices, …) Gather operator.
ge(x, y) Elementwise greater-than-or-equals: \(x \geq y\).
gt(x, y) Elementwise greater-than: \(x > y\).
igamma(a, x) Elementwise regularized incomplete gamma function.
igammac(a, x) Elementwise complementary regularized incomplete gamma function.
imag(x) Elementwise extract imaginary part: \(\mathrm{Im}(x)\).
index_in_dim(operand, index[, axis, keepdims]) Convenience wrapper around slice to perform int indexing.
index_take(src, idxs, axes)
iota(dtype, size) Wraps XLA’s Iota operator.
is_finite(x) Elementwise \(\mathrm{isfinite}\).
le(x, y) Elementwise less-than-or-equals: \(x \leq y\).
lt(x, y) Elementwise less-than: \(x < y\).
lgamma(x) Elementwise log gamma: \(\mathrm{log}(\Gamma(x))\).
log(x) Elementwise natural logarithm: \(\mathrm{log}(x)\).
log1p(x) Elementwise \(\mathrm{log}(1 + x)\).
max(x, y) Elementwise maximum: \(\mathrm{max}(x, y)\).
min(x, y) Elementwise minimum: \(\mathrm{min}(x, y)\).
mul(x, y) Elementwise multiplication: \(x \times y\).
ne(x, y) Elementwise not-equals: \(x \neq y\).
neg(x) Elementwise negation: \(-x\).
nextafter(x1, x2) Returns the next representable value after x1 in the direction of x2.
pad(operand, padding_value, padding_config) Wraps XLA’s Pad operator.
pow(x, y) Elementwise power: \(x^y\).
real(x) Elementwise extract real part: \(\mathrm{Re}(x)\).
reciprocal(x) Elementwise reciprocal: \(1 \over x\).
reduce(operand, init_value, computation, …) Wraps XLA’s Reduce operator.
reduce_window(operand, init_value, …) Wraps XLA’s ReduceWindow operator.
reshape(operand, new_sizes[, dimensions]) Wraps XLA’s Reshape operator.
rem(x, y) Elementwise remainder: \(x \bmod y\).
rev(operand, dimensions) Wraps XLA’s Rev operator.
round(x) Elementwise round.
rsqrt(x) Elementwise reciprocal square root: \(1 \over \sqrt{x}\).
scatter(operand, scatter_indices, updates, …) Scatter-update operator.
scatter_add(operand, scatter_indices, …) Scatter-add operator.
select(pred, on_true, on_false) Wraps XLA’s Select operator.
shift_left(x, y) Elementwise left shift: \(x \ll y\).
shift_right_arithmetic(x, y) Elementwise arithmetic right shift: \(x \gg y\).
shift_right_logical(x, y) Elementwise logical right shift: \(x \gg y\).
slice(operand, start_indices, limit_indices) Wraps XLA’s Slice operator.
slice_in_dim(operand, start_index, limit_index) Convenience wrapper around slice applying to only one dimension.
sign(x) Elementwise sign.
sin(x) Elementwise sine: \(\mathrm{sin}(x)\).
sinh(x) Elementwise hyperbolic sine: \(\mathrm{sinh}(x)\).
sort(operand[, dimension]) Wraps XLA’s Sort operator.
sort_key_val(keys, values[, dimension]) Sorts keys along dimension and applies the same permutation to values.
sqrt(x) Elementwise square root: \(\sqrt{x}\).
square(x) Elementwise square: \(x^2\).
sub(x, y) Elementwise subtraction: \(x - y\).
tan(x) Elementwise tangent: \(\mathrm{tan}(x)\).
tie_in(x, y) Gives y a fake data dependence on x.
transpose(operand, permutation) Wraps XLA’s Transpose operator.
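The sketch below exercises a handful of the operators listed above. It assumes a recent JAX release in which these signatures still hold, and the argument values are arbitrary illustrations. Two conventions are worth noting. dot_general takes dimension_numbers as ((lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch)), and each entry of pad's padding_config is a (low, high, interior) triple.

```python
import jax.numpy as jnp
from jax import lax

# atan2 uses the signs of both arguments: (1, -1) lies in the second quadrant.
angle = lax.atan2(jnp.float32(1.0), jnp.float32(-1.0))  # about 3*pi/4

# dot_general as a batched matrix multiply: contract lhs dim 2 with rhs dim 1,
# treating dim 0 of both operands as a batch dimension.
a = jnp.ones((2, 3, 4))
b = jnp.ones((2, 4, 5))
c = lax.dot_general(a, b, (((2,), (1,)), ((0,), (0,))))  # shape (2, 3, 5)

# broadcast_in_dim: operand dimension 0 becomes output dimension 1.
row = jnp.arange(3.0)
tiled = lax.broadcast_in_dim(row, (2, 3), (1,))

# dynamic_slice: start indices may be traced values; slice sizes are static.
x = jnp.arange(10, dtype=jnp.int32)
window = lax.dynamic_slice(x, (jnp.int32(3),), (4,))

# dynamic_update_slice_in_dim returns a new array with the update written in.
y = lax.dynamic_update_slice_in_dim(x, jnp.zeros(2, dtype=x.dtype), 8, axis=0)

# select: elementwise choice between two arrays of the same shape.
evens_or_neg = lax.select(x % 2 == 0, x, -x)

# reduce: generic reduction with an explicit init value and combiner.
total = lax.reduce(x, jnp.int32(0), lax.add, (0,))

# pad: one (low, high, interior) triple per dimension.
padded = lax.pad(jnp.array([1, 2, 3], dtype=jnp.int32), jnp.int32(0), [(1, 2, 1)])
```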

Control flow operators

cond(pred, true_operand, true_fun, …) Conditionally apply true_fun or false_fun.
fori_loop(lower, upper, body_fun, init_val) Loop from lower to upper by reduction to while_loop.
map(f, xs) Map a function over leading array axes.
scan(f, init, xs[, length]) Scan a function over leading array axes while carrying along state.
while_loop(cond_fun, body_fun, init_val) Call body_fun repeatedly in a loop while cond_fun is True.
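A brief sketch of each control flow operator follows, assuming a recent JAX release. It calls cond using the cond(pred, true_fun, false_fun, operand) convention of current releases, which differs from the signature listed above. Loop carries are given explicit dtypes so that the carry type stays the same on every iteration.

```python
import jax.numpy as jnp
from jax import lax

# fori_loop: sum of 0..9, carrying an int32 accumulator.
total = lax.fori_loop(0, 10, lambda i, acc: acc + i, jnp.int32(0))

# while_loop: keep doubling while the value is at most 100.
doubled = lax.while_loop(lambda v: v <= 100, lambda v: v * 2, jnp.int32(1))

# scan: a running sum; returns the final carry and the stacked per-step outputs.
def step(carry, x):
    carry = carry + x
    return carry, carry

final, cumsum = lax.scan(step, jnp.float32(0.0), jnp.arange(1.0, 5.0))

# cond: pick one of two branches based on a scalar predicate.
v = jnp.float32(3.0)
branch = lax.cond(v > 0, lambda u: u * 2.0, lambda u: -u, v)

# map: apply a function over the leading axis.
squares = lax.map(lambda x: x * x, jnp.arange(4))
```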

Custom gradient operators

stop_gradient(x) Stops gradient computation.
custom_linear_solve(matvec, b, solve[, …]) Perform a matrix-free linear solve with implicitly defined gradients.
custom_root(f, initial_guess, solve, …) Differentiably solve for the roots of a function.
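The sketch below illustrates stop_gradient and custom_root, assuming a recent JAX release. The helper newton_sqrt is hypothetical and finds the root of f(x) = x**2 - a with Newton's method. The derivative 2x is written out by hand because it is specific to this f. custom_root supplies the gradient with respect to a by implicit differentiation, so the Newton loop itself is never differentiated. The passed tangent_solve handles the scalar case by dividing by g(1.0).

```python
import jax
import jax.numpy as jnp
from jax import lax

# stop_gradient: the second factor is treated as a constant,
# so d/dx [x * stop_gradient(x)] = x, giving 3.0 rather than 6.0.
sg_grad = jax.grad(lambda x: x * lax.stop_gradient(x))(3.0)

def newton_sqrt(a, iters=20):
    # Hypothetical helper: the root of f(x) = x**2 - a is sqrt(a).
    f = lambda x: x ** 2 - a

    def solve(func, x0):
        # Newton's method using the hand-written derivative 2x of this f.
        def body(_, x):
            return x - func(x) / (2.0 * x)
        return lax.fori_loop(0, iters, body, x0)

    def tangent_solve(g, y):
        # Solve the scalar linear equation g(t) = y for t.
        return y / g(1.0)

    return lax.custom_root(f, jnp.float32(1.0), solve, tangent_solve)

root = newton_sqrt(2.0)
droot = jax.grad(newton_sqrt)(2.0)  # approximately 1 / (2 * sqrt(2))
```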

Parallel operators

Parallelism support is experimental.

all_gather(x, axis_name) Gather values of x across all replicas.
all_to_all(x, axis_name, split_axis, concat_axis) Materialize the mapped axis and map a different axis.
psum(x, axis_name) Compute an all-reduce sum on x over the pmapped axis axis_name.
pmax(x, axis_name) Compute an all-reduce max on x over the pmapped axis axis_name.
pmin(x, axis_name) Compute an all-reduce min on x over the pmapped axis axis_name.
ppermute(x, axis_name, perm) Perform a collective permutation according to the permutation perm.
pswapaxes(x, axis_name, axis) Swap the pmapped axis axis_name with the unmapped axis axis.
axis_index(axis_name) Return the index along the pmapped axis axis_name.
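On a single device, some of these collectives can be tried under jax.vmap with an axis_name, which names the mapped axis much as pmap does. The sketch below assumes a recent JAX release and uses a hypothetical normalize helper. Under pmap, the same calls would operate across devices.

```python
import jax
import jax.numpy as jnp
from jax import lax

xs = jnp.array([1.0, 2.0, 3.0, 4.0])

def normalize(x):
    # psum adds x across every position of the named axis "i".
    return x / lax.psum(x, axis_name="i")

shares = jax.vmap(normalize, axis_name="i")(xs)

# pmax and axis_index operate along the same named axis.
largest = jax.vmap(lambda x: lax.pmax(x, "i"), axis_name="i")(xs)
positions = jax.vmap(lambda x: lax.axis_index("i"), axis_name="i")(xs)
```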