Source code for jax._src.numpy.lax_numpy

# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pytype: skip-file
"""
Implements the NumPy API, using the primitives in :mod:`jax.lax`.

NumPy operations are implemented in Python in terms of the primitive operations
in :mod:`jax.lax`. Since NumPy operations are not primitive and instead are
implemented in terms of :mod:`jax.lax` operations, we do not need to define
transformation rules such as gradient or batching rules. Instead,
transformations for NumPy primitives can be derived from the transformation
rules for the underlying :code:`lax` primitives.
"""
from __future__ import annotations

import builtins
import collections
from collections.abc import Sequence
from functools import partial
import math
import operator
import types
from typing import (overload, Any, Callable, Literal, NamedTuple, Protocol,
                    TypeVar, Union)
from textwrap import dedent as _dedent
import warnings

import numpy as np
import opt_einsum

import jax
from jax import jit
from jax import errors
from jax import lax
from jax.sharding import Sharding, SingleDeviceSharding
from jax.tree_util import tree_leaves, tree_flatten, tree_map

from jax._src import api_util
from jax._src import config
from jax._src import core
from jax._src.custom_derivatives import custom_jvp
from jax._src import deprecations
from jax._src import dispatch
from jax._src import dtypes
from jax._src import xla_bridge
from jax._src.api_util import _ensure_index_tuple
from jax._src.array import ArrayImpl
from jax._src.core import ShapedArray, ConcreteArray
from jax._src.lax.lax import (_array_copy, _sort_lt_comparator,
                              _sort_le_comparator, PrecisionLike)
from jax._src.lax import lax as lax_internal
from jax._src.lib import xla_client as xc, xla_extension_version
from jax._src.numpy import reductions
from jax._src.numpy import ufuncs
from jax._src.numpy import util
from jax._src.numpy.vectorize import vectorize
from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray, DType, DTypeLike, Shape
from jax._src.util import (unzip2, subvals, safe_zip,
                           ceil_of_ratio, partition_list,
                           canonicalize_axis as _canonicalize_axis,
                           NumpyComplexWarning)

# Same value as ``np.newaxis``, which is simply ``None``.
newaxis = None
# Generic type variable available for annotations in this module.
T = TypeVar('T')


# Variant of core.canonicalize_shape that additionally accepts a single
# int-like (zero-dimensional, non-sequence) value for `shape`.
def canonicalize_shape(shape: Any, context: str="") -> core.Shape:
  """Canonicalize ``shape``, first wrapping a lone scalar into a 1-tuple.

  Args:
    shape: a tuple/list of dimensions, or a single zero-dimensional value.
    context: forwarded unchanged to ``core.canonicalize_shape``.

  Returns:
    Whatever ``core.canonicalize_shape`` returns for the (possibly wrapped)
    shape.
  """
  # Tuples and lists are forwarded untouched; ndim is never queried for them.
  if isinstance(shape, (tuple, list)):
    return core.canonicalize_shape(shape, context)  # type: ignore
  is_scalar = getattr(shape, 'ndim', None) == 0 or ndim(shape) == 0
  dims = (shape,) if is_scalar else shape
  return core.canonicalize_shape(dims, context)  # type: ignore

# Common docstring additions:

# Paragraph describing the extra ``precision`` keyword. Functions in this
# module hand it to ``util.implements`` through ``lax_description``.
_PRECISION_DOC = """\
In addition to the original NumPy arguments listed below, also supports
``precision`` for extra control over matrix-multiplication precision
on supported devices. ``precision`` may be set to ``None``, which means
default precision for the backend, a :class:`~jax.lax.Precision` enum value
(``Precision.DEFAULT``, ``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple
of two :class:`~jax.lax.Precision` enums indicating separate precision for each argument.
"""

# Some objects below rewrite their __module__ attribute to this name.
_PUBLIC_MODULE_NAME = "jax.numpy"

# NumPy constants
# (re-exported unchanged from NumPy)

pi = np.pi
e = np.e
euler_gamma = np.euler_gamma
inf = np.inf
nan = np.nan

# NumPy utility functions
# (NumPy's own print-option helpers, bound under the same names)

get_printoptions = np.get_printoptions
printoptions = np.printoptions
set_printoptions = np.set_printoptions

@util.implements(np.iscomplexobj)
def iscomplexobj(x: Any) -> bool:
  # Report whether ``x`` has a complex scalar type; ``None`` is never complex.
  if x is None:
    return False
  try:
    scalar_type = x.dtype.type
  except AttributeError:
    # No usable ``dtype`` attribute: inspect the dtype of the converted array.
    scalar_type = asarray(x).dtype.type
  return issubdtype(scalar_type, complexfloating)
# Shape/ndim/size queries are NumPy's own functions; the underscore aliases
# stay usable inside functions that take a parameter named `shape` or `ndim`.
shape = _shape = np.shape
ndim = _ndim = np.ndim
size = np.size

def _dtype(x: Any) -> DType:
  """Return ``dtypes.dtype(x, canonicalize=True)``."""
  return dtypes.dtype(x, canonicalize=True)

# At present JAX doesn't have a reason to distinguish between scalars and arrays
# in its object system. Further, we want JAX scalars to have the same type
# promotion behaviors as JAX arrays. Rather than introducing a new type of JAX
# scalar object with JAX promotion behaviors, instead we make the JAX scalar
# types return JAX arrays when instantiated.

class _ScalarMeta(type):
  """Metaclass for JAX scalar types such as ``float32``.

  Each class created with this metaclass carries a ``dtype`` attribute.
  Hashing, equality and ``isinstance`` checks are delegated to
  ``dtype.type``, and calling the class converts its argument with
  ``asarray(x, dtype=self.dtype)``.
  """
  dtype: np.dtype

  def __hash__(self) -> int:
    # Hash like the underlying NumPy scalar type so that it agrees with __eq__.
    return hash(self.dtype.type)

  def __eq__(self, other: Any) -> bool:
    # Equal to itself and to anything the NumPy scalar type compares equal to.
    return id(self) == id(other) or self.dtype.type == other

  def __ne__(self, other: Any) -> bool:
    return not (self == other)

  def __call__(self, x: Any) -> Array:
    # "Instantiating" a scalar type yields an array of that dtype.
    return asarray(x, dtype=self.dtype)

  def __instancecheck__(self, instance: Any) -> bool:
    # isinstance(v, jnp.float32) is answered by the NumPy scalar type.
    return isinstance(instance, self.dtype.type)

def _abstractify_scalar_meta(x):
  """Abstractification handler for scalar types; always raises TypeError."""
  raise TypeError(f"JAX scalar type {x} cannot be interpreted as a JAX array.")
# Register the handler so that scalar *types* are rejected with the error above.
api_util._shaped_abstractify_handlers[_ScalarMeta] = _abstractify_scalar_meta

def _make_scalar_type(np_scalar_type: type) -> _ScalarMeta:
  """Build the JAX scalar type for ``np_scalar_type``.

  The new class is named after the NumPy type, stores ``np.dtype(np_scalar_type)``
  as ``dtype`` and has its ``__module__`` set to ``_PUBLIC_MODULE_NAME``.
  """
  meta = _ScalarMeta(np_scalar_type.__name__, (object,),
                     {"dtype": np.dtype(np_scalar_type)})
  meta.__module__ = _PUBLIC_MODULE_NAME
  return meta

bool_ = _make_scalar_type(np.bool_)
uint4 = _make_scalar_type(dtypes.uint4)
uint8 = _make_scalar_type(np.uint8)
uint16 = _make_scalar_type(np.uint16)
uint32 = _make_scalar_type(np.uint32)
uint64 = _make_scalar_type(np.uint64)
int4 = _make_scalar_type(dtypes.int4)
int8 = _make_scalar_type(np.int8)
int16 = _make_scalar_type(np.int16)
int32 = _make_scalar_type(np.int32)
int64 = _make_scalar_type(np.int64)
float8_e4m3fn = _make_scalar_type(dtypes.float8_e4m3fn)
float8_e4m3fnuz = _make_scalar_type(dtypes.float8_e4m3fnuz)
float8_e5m2 = _make_scalar_type(dtypes.float8_e5m2)
float8_e5m2fnuz = _make_scalar_type(dtypes.float8_e5m2fnuz)
float8_e4m3b11fnuz = _make_scalar_type(dtypes.float8_e4m3b11fnuz)
bfloat16 = _make_scalar_type(dtypes.bfloat16)
float16 = _make_scalar_type(np.float16)
float32 = single = _make_scalar_type(np.float32)
float64 = double = _make_scalar_type(np.float64)
complex64 = csingle = _make_scalar_type(np.complex64)
complex128 = cdouble = _make_scalar_type(np.complex128)

# Default-width aliases follow whatever ``dtypes`` reports as the default.
int_ = int32 if dtypes.int_ == np.int32 else int64
uint = uint32 if dtypes.uint == np.uint32 else uint64
float_: Any = float32 if dtypes.float_ == np.float32 else float64
complex_ = complex64 if dtypes.complex_ == np.complex64 else complex128

# Abstract scalar categories are NumPy's own classes.
generic = np.generic
number = np.number
inexact = np.inexact
complexfloating = np.complexfloating
floating = np.floating
integer = np.integer
signedinteger = np.signedinteger
unsignedinteger = np.unsignedinteger
flexible = np.flexible
character = np.character
object_ = np.object_

iinfo = dtypes.iinfo
finfo = dtypes.finfo

dtype = np.dtype
can_cast = dtypes.can_cast
promote_types = dtypes.promote_types

ComplexWarning = NumpyComplexWarning

array_str = np.array_str
array_repr = np.array_repr

save = np.save
savez = np.savez

@util.implements(np.dtype)
def _jnp_dtype(obj: DTypeLike | None, *, align: bool = False,
               copy: bool = False) -> DType:
  """Similar to np.dtype, but respects JAX dtype defaults."""
  # Extended dtypes are returned as-is.
  if dtypes.issubdtype(obj, dtypes.extended):
    return obj  # type: ignore[return-value]
  if obj is None:
    obj = dtypes.float_
  elif isinstance(obj, type) and obj in dtypes.python_scalar_dtypes:
    # Python scalar types (bool/int/float/complex) map to the JAX defaults.
    obj = _DEFAULT_TYPEMAP[obj]
  return np.dtype(obj, align=align, copy=copy)

### utility functions

# Maps Python scalar types to the default JAX scalar type of the same kind.
_DEFAULT_TYPEMAP: dict[type, _ScalarMeta] = {
  bool: bool_,
  int: int_,
  float: float_,
  complex: complex_,
}

_lax_const = lax_internal._const

def _convert_and_clip_integer(val: ArrayLike, dtype: DType) -> Array:
  """
  Convert integer-typed val to specified integer dtype, clipping to dtype
  range rather than wrapping.

  Args:
    val: value to be converted
    dtype: dtype of output

  Returns:
    equivalent of val in new dtype

  Raises:
    TypeError: if either ``dtype`` or the dtype of ``val`` is not an integer
      dtype.

  Examples
  --------
  Normal integer type conversion will wrap:

  >>> val = jnp.uint32(0xFFFFFFFF)
  >>> val.astype('int32')
  Array(-1, dtype=int32)

  This function clips to the values representable in the new type:

  >>> _convert_and_clip_integer(val, 'int32')
  Array(2147483647, dtype=int32)
  """
  val = val if isinstance(val, Array) else asarray(val)
  dtype = dtypes.canonicalize_dtype(dtype)
  if not (issubdtype(dtype, integer) and issubdtype(val.dtype, integer)):
    raise TypeError("_convert_and_clip_integer only accepts integer dtypes.")

  val_dtype = dtypes.canonicalize_dtype(val.dtype)
  if val_dtype != val.dtype:
    # TODO(jakevdp): this is a weird corner case; need to figure out how to handle it.
    # This happens in X32 mode and can either come from a jax value created in another
    # context, or a Python integer converted to int64.
    pass
  # Clip to the intersection of both ranges, expressed in val's own dtype,
  # so the final cast cannot wrap.
  min_val = _lax_const(val, max(iinfo(dtype).min, iinfo(val_dtype).min))
  max_val = _lax_const(val, min(iinfo(dtype).max, iinfo(val_dtype).max))
  return clip(val, min_val, max_val).astype(dtype)
@util.implements(np.load, update_doc=False)
def load(*args: Any, **kwargs: Any) -> Array:
  # The main purpose of this wrapper is to recover bfloat16 data types.
  # Note: this will only work for files created via np.save(), not np.savez().
  loaded = np.load(*args, **kwargs)
  if not isinstance(loaded, np.ndarray):
    # Anything other than a plain ndarray is handed back untouched.
    return loaded
  # numpy does not recognize bfloat16, so arrays are serialized as void16
  if loaded.dtype == 'V2':
    loaded = loaded.view(bfloat16)
  try:
    return asarray(loaded)
  except (TypeError, AssertionError):
    # Unsupported dtype: fall back to the NumPy array.
    return loaded
### implementations of numpy functions in terms of lax
@util.implements(np.fmin, module='numpy')
@jit
def fmin(x1: ArrayLike, x2: ArrayLike) -> Array:
  # Pick x1 where it is smaller or where x2 is NaN; otherwise pick x2.
  pick_first = ufuncs.less(x1, x2) | ufuncs.isnan(x2)
  return where(pick_first, x1, x2)
@util.implements(np.fmax, module='numpy')
@jit
def fmax(x1: ArrayLike, x2: ArrayLike) -> Array:
  # Pick x1 where it is larger or where x2 is NaN; otherwise pick x2.
  pick_first = ufuncs.greater(x1, x2) | ufuncs.isnan(x2)
  return where(pick_first, x1, x2)
@util.implements(np.issubdtype)
def issubdtype(arg1: DTypeLike, arg2: DTypeLike) -> bool:
  # Thin public wrapper: the decision is made entirely by dtypes.issubdtype.
  is_sub = dtypes.issubdtype(arg1, arg2)
  return is_sub
@util.implements(np.isscalar)
def isscalar(element: Any) -> bool:
  # Objects exposing ``__jax_array__`` are judged by the value it returns.
  candidate = (element.__jax_array__() if hasattr(element, '__jax_array__')
               else element)
  # Python scalars are recognised first, then NumPy's own notion of scalar.
  return dtypes.is_python_scalar(candidate) or np.isscalar(candidate)
# ``iterable`` is NumPy's own function, re-exported under the same name.
iterable = np.iterable
@util.implements(np.result_type)
def result_type(*args: Any) -> DType:
  # Thin public wrapper: promotion is delegated to dtypes.result_type.
  promoted = dtypes.result_type(*args)
  return promoted
@util.implements(np.trunc, module='numpy')
@jit
def trunc(x: ArrayLike) -> Array:
  util.check_arraylike('trunc', x)
  # Round toward zero: ceil for negative entries, floor for everything else.
  zero = _lax_const(x, 0)
  is_negative = lax.lt(x, zero)
  return where(is_negative, ufuncs.ceil(x), ufuncs.floor(x))
# Extra parameter description passed as ``extra_params`` to the docs of
# convolve() and correlate().
_CONV_PREFERRED_ELEMENT_TYPE_DESCRIPTION = """ preferred_element_type : dtype, optional If specified, accumulate results and return a result of the given data type. If not specified, the function instead follows the numpy convention of always accumulating results and returning an inexact dtype. """

@partial(jit, static_argnames=['mode', 'op', 'precision', 'preferred_element_type'])
def _conv(x: Array, y: Array, mode: str, op: str, precision: PrecisionLike,
          preferred_element_type: DTypeLike | None = None) -> Array:
  """Shared 1-D implementation behind ``convolve`` and ``correlate``.

  Args:
    x, y: one-dimensional, non-empty inputs.
    mode: one of ``'full'``, ``'same'`` or ``'valid'``.
    op: ``'convolve'`` or ``'correlate'``; also used in error messages.
    precision: forwarded to ``lax.conv_general_dilated``.
    preferred_element_type: forwarded to ``lax.conv_general_dilated``; when
      ``None`` the inputs are promoted to an inexact dtype first.

  Raises:
    ValueError: for non-1D inputs, empty inputs, or an unknown ``mode``.
  """
  if ndim(x) != 1 or ndim(y) != 1:
    raise ValueError(f"{op}() only support 1-dimensional inputs.")
  if preferred_element_type is None:
    # if unspecified, promote to inexact following NumPy's default for convolutions.
    x, y = util.promote_dtypes_inexact(x, y)
  else:
    # otherwise cast to same type but otherwise preserve input dtypes
    x, y = util.promote_dtypes(x, y)
  if len(x) == 0 or len(y) == 0:
    raise ValueError(f"{op}: inputs cannot be empty, got shapes {x.shape} and {y.shape}.")

  # The longer input is always used as the signal and the shorter as the kernel.
  kernel_is_longer = len(x) < len(y)
  reverse_output = False
  if op == 'correlate':
    y = ufuncs.conj(y)
    if kernel_is_longer:
      x, y = y, x
      # Swapping the operands of a correlation reverses the result.
      reverse_output = True
  elif op == 'convolve':
    if kernel_is_longer:
      x, y = y, x
    y = flip(y)
  out_order = slice(None, None, -1) if reverse_output else slice(None)

  if mode == 'valid':
    padding = [(0, 0)]
  elif mode == 'same':
    padding = [(y.shape[0] // 2, y.shape[0] - y.shape[0] // 2 - 1)]
  elif mode == 'full':
    padding = [(y.shape[0] - 1, y.shape[0] - 1)]
  else:
    raise ValueError("mode must be one of ['full', 'same', 'valid']")

  # Add the leading batch and feature axes the lax convolution expects.
  signal = x[None, None, :]
  kernel = y[None, None, :]
  result = lax.conv_general_dilated(
      signal, kernel, (1,), padding, precision=precision,
      preferred_element_type=preferred_element_type)
  return result[0, 0, out_order]
@util.implements(np.convolve, lax_description=_PRECISION_DOC,
                 extra_params=_CONV_PREFERRED_ELEMENT_TYPE_DESCRIPTION)
@partial(jit, static_argnames=('mode', 'precision', 'preferred_element_type'))
def convolve(a: ArrayLike, v: ArrayLike, mode: str = 'full', *,
             precision: PrecisionLike = None,
             preferred_element_type: DTypeLike | None = None) -> Array:
  # Validate, convert both operands to arrays, then defer to the shared kernel.
  util.check_arraylike("convolve", a, v)
  signal = asarray(a)
  kernel = asarray(v)
  return _conv(signal, kernel, mode=mode, op='convolve',
               precision=precision,
               preferred_element_type=preferred_element_type)
@util.implements(np.correlate, lax_description=_PRECISION_DOC,
                 extra_params=_CONV_PREFERRED_ELEMENT_TYPE_DESCRIPTION)
@partial(jit, static_argnames=('mode', 'precision', 'preferred_element_type'))
def correlate(a: ArrayLike, v: ArrayLike, mode: str = 'valid', *,
              precision: PrecisionLike = None,
              preferred_element_type: DTypeLike | None = None) -> Array:
  # Validate, convert both operands to arrays, then defer to the shared kernel.
  util.check_arraylike("correlate", a, v)
  signal = asarray(a)
  kernel = asarray(v)
  return _conv(signal, kernel, mode=mode, op='correlate',
               precision=precision,
               preferred_element_type=preferred_element_type)
@util.implements(np.histogram_bin_edges)
def histogram_bin_edges(a: ArrayLike, bins: ArrayLike = 10,
                        range: None | Array | Sequence[ArrayLike] = None,
                        weights: ArrayLike | None = None) -> Array:
  del weights  # unused, because string bins is not supported.
  if isinstance(bins, str):
    raise NotImplementedError("string values for `bins` not implemented.")
  util.check_arraylike("histogram_bin_edges", a, bins)
  arr = asarray(a)
  dtype = dtypes.to_inexact_dtype(arr.dtype)
  # One-dimensional `bins` already are the edges; only the dtype is adjusted.
  if _ndim(bins) == 1:
    return asarray(bins, dtype=dtype)

  # Otherwise `bins` is a bin count and must be a concrete integer.
  bins_int = core.concrete_or_error(operator.index, bins,
                                    "bins argument of histogram_bin_edges")
  if range is None:
    range = [arr.min(), arr.max()]
  range = asarray(range, dtype=dtype)
  if shape(range) != (2,):
    raise ValueError(f"`range` must be either None or a sequence of scalars, got {range}")
  # A zero-width range is widened by 0.5 on each side.
  is_degenerate = reductions.ptp(range) == 0
  lower = where(is_degenerate, range[0] - 0.5, range[0])
  upper = where(is_degenerate, range[1] + 0.5, range[1])
  return linspace(lower, upper, bins_int + 1, dtype=dtype)
@util.implements(np.histogram)
def histogram(a: ArrayLike, bins: ArrayLike = 10,
              range: Sequence[ArrayLike] | None = None,
              weights: ArrayLike | None = None,
              density: bool | None = None) -> tuple[Array, Array]:
  if weights is not None:
    util.check_arraylike("histogram", a, bins, weights)
    if shape(a) != shape(weights):
      raise ValueError("weights should have the same shape as a.")
    a, weights = util.promote_dtypes_inexact(a, weights)
  else:
    # Unweighted: every sample contributes a weight of one.
    util.check_arraylike("histogram", a, bins)
    a, = util.promote_dtypes_inexact(a)
    weights = ones_like(a)

  bin_edges = histogram_bin_edges(a, bins, range, weights)
  num_edges = len(bin_edges)
  slot = searchsorted(bin_edges, a, side='right')
  # Samples equal to the last edge belong to the final bin.
  slot = where(a == bin_edges[-1], num_edges - 1, slot)
  # Slot 0 collects samples below the first edge and is dropped.
  accumulator = zeros(num_edges, weights.dtype).at[slot].add(weights)
  counts = accumulator[1:]
  if density:
    widths = diff(bin_edges)
    counts = counts / widths / counts.sum()
  return counts, bin_edges
@util.implements(np.histogram2d)
def histogram2d(x: ArrayLike, y: ArrayLike,
                bins: ArrayLike | list[ArrayLike] = 10,
                range: Sequence[None | Array | Sequence[ArrayLike]] | None = None,
                weights: ArrayLike | None = None,
                density: bool | None = None) -> tuple[Array, Array, Array]:
  util.check_arraylike("histogram2d", x, y)
  try:
    bins_len = len(bins)  # type: ignore[arg-type]
  except TypeError:
    bins_len = 1

  # A sequence of any other length is one edge array shared by both axes.
  if bins_len not in (1, 2):
    shared_edges = asarray(bins)
    bins = [shared_edges, shared_edges]

  # Stack the coordinates into an (N, 2) sample and reuse histogramdd.
  sample = transpose(asarray([x, y]))
  hist, edges = histogramdd(sample, bins, range, weights, density)
  x_edges, y_edges = edges[0], edges[1]
  return hist, x_edges, y_edges
@util.implements(np.histogramdd)
def histogramdd(sample: ArrayLike, bins: ArrayLike | list[ArrayLike] = 10,
                range: Sequence[None | Array | Sequence[ArrayLike]] | None = None,
                weights: ArrayLike | None = None,
                density: bool | None = None) -> tuple[Array, list[Array]]:
  if weights is not None:
    util.check_arraylike("histogramdd", sample, weights)
    if shape(weights) != shape(sample)[:1]:
      raise ValueError("should have one weight for each sample.")
    sample, weights = util.promote_dtypes_inexact(sample, weights)
  else:
    util.check_arraylike("histogramdd", sample)
    sample, = util.promote_dtypes_inexact(sample)
  N, D = shape(sample)

  if range is not None and (
      len(range) != D or any(r is not None and shape(r)[0] != 2 for r in range)):  # type: ignore[arg-type]
    raise ValueError(f"For sample.shape={(N, D)}, range must be a sequence "
                     f"of {D} pairs or Nones; got {range=}")

  try:
    num_bins = len(bins)  # type: ignore[arg-type]
  except TypeError:
    # when bin_size is integer, the same bin is used for each dimension
    bins_per_dimension: list[ArrayLike] = D * [bins]  # type: ignore[assignment]
  else:
    if num_bins != D:
      raise ValueError("should be a bin for each dimension.")
    bins_per_dimension = list(bins)  # type: ignore[arg-type]

  bin_idx_by_dim: list[Array] = []
  bin_edges_by_dim: list[Array] = []

  # `range` is a parameter here, so the builtin is reached through `builtins`.
  for dim in builtins.range(D):
    column = sample[:, dim]
    dim_range = None if range is None else range[dim]
    edges = histogram_bin_edges(column, bins_per_dimension[dim], dim_range, weights)
    idx = searchsorted(edges, column, side='right')
    # Samples sitting exactly on the last edge go into the final bin.
    idx = where(column == edges[-1], idx - 1, idx)
    bin_idx_by_dim.append(idx)
    bin_edges_by_dim.append(edges)

  # One extra slot on each side of every dimension catches the outliers.
  nbins = tuple(len(edges) + 1 for edges in bin_edges_by_dim)
  dedges = [diff(edges) for edges in bin_edges_by_dim]

  flat_idx = ravel_multi_index(tuple(bin_idx_by_dim), nbins, mode='clip')
  hist = bincount(flat_idx, weights, length=math.prod(nbins))
  hist = reshape(hist, nbins)
  # Strip the outlier slots again.
  interior = D * (slice(1, -1),)
  hist = hist[interior]

  if density:
    hist = hist.astype(sample.dtype)
    hist /= hist.sum()
    for norm in ix_(*dedges):
      hist /= norm

  return hist, bin_edges_by_dim
# Note passed as ``lax_description`` to the view-returning NumPy functions
# wrapped in this module (transpose, flip, reshape, ...).
_ARRAY_VIEW_DOC = """ The JAX version of this function may in some cases return a copy rather than a view of the input. """
@util.implements(np.transpose, lax_description=_ARRAY_VIEW_DOC)
def transpose(a: ArrayLike, axes: Sequence[int] | None = None) -> Array:
  util.check_arraylike("transpose", a)
  rank = ndim(a)
  # With no explicit axes the dimension order is simply reversed.
  requested = range(rank - 1, -1, -1) if axes is None else axes
  permutation = [_canonicalize_axis(ax, rank) for ax in requested]
  return lax.transpose(a, permutation)
@util.implements(getattr(np, "permute_dims", None))
def permute_dims(a: ArrayLike, /, axes: tuple[int, ...]) -> Array:
  # `axes` is handed to lax.transpose exactly as given.
  util.check_arraylike("permute_dims", a)
  permuted = lax.transpose(a, axes)
  return permuted
@util.implements(getattr(np, 'matrix_transpose', None))
def matrix_transpose(x: ArrayLike, /) -> Array:
  """Swap the final two axes of ``x``.

  Parameters
  ----------
  x : array_like
    Input array with ``x.ndim >= 2``.

  Returns
  -------
  xT : Array
    Array whose last two dimensions are exchanged.

  Raises
  ------
  ValueError
    If ``x`` has fewer than two dimensions.
  """
  util.check_arraylike("matrix_transpose", x)
  ndim = np.ndim(x)
  if ndim < 2:
    raise ValueError(f"x must be at least two-dimensional for matrix_transpose; got {ndim=}")
  # Leading axes keep their place; only the trailing pair is exchanged.
  leading = tuple(range(ndim - 2))
  axes = leading + (ndim - 1, ndim - 2)
  return lax.transpose(x, axes)
@util.implements(np.rot90, lax_description=_ARRAY_VIEW_DOC)
@partial(jit, static_argnames=('k', 'axes'))
def rot90(m: ArrayLike, k: int = 1, axes: tuple[int, int] = (0, 1)) -> Array:
  util.check_arraylike("rot90", m)
  if np.ndim(m) < 2:
    raise ValueError("rot90 requires its first argument to have ndim at least "
                     f"two, but got first argument of shape {np.shape(m)}, "
                     f"which has ndim {np.ndim(m)}")
  first, second = axes
  first = _canonicalize_axis(first, ndim(m))
  second = _canonicalize_axis(second, ndim(m))
  if first == second:
    raise ValueError("Axes must be different")  # same as numpy error

  # Only the number of quarter turns modulo four matters.
  turns = k % 4
  if turns == 0:
    return asarray(m)
  if turns == 2:
    # A half turn is a flip along both axes.
    return flip(flip(m, first), second)

  # Quarter turns combine one flip with a swap of the two rotation axes.
  perm = list(range(ndim(m)))
  perm[first], perm[second] = perm[second], perm[first]
  if turns == 1:
    return transpose(flip(m, second), perm)
  return flip(transpose(m, perm), second)
@util.implements(np.flip, lax_description=_ARRAY_VIEW_DOC)
def flip(m: ArrayLike, axis: int | Sequence[int] | None = None) -> Array:
  util.check_arraylike("flip", m)
  # Normalise `axis` before it reaches the jitted helper, where it is static.
  static_axis = reductions._ensure_optional_axes(axis)
  return _flip(asarray(m), static_axis)
@partial(jit, static_argnames=('axis',))
def _flip(m: Array, axis: int | tuple[int, ...] | None = None) -> Array:
  """Reverse ``m`` along ``axis``, or along every axis when ``axis`` is None."""
  if axis is None:
    dimensions = list(range(len(shape(m))))
  else:
    dimensions = [_canonicalize_axis(ax, ndim(m))
                  for ax in _ensure_index_tuple(axis)]
  return lax.rev(m, dimensions)
@util.implements(np.fliplr, lax_description=_ARRAY_VIEW_DOC)
def fliplr(m: ArrayLike) -> Array:
  # Left/right flip: reverse along axis 1.
  util.check_arraylike("fliplr", m)
  arr = asarray(m)
  return _flip(arr, 1)
@util.implements(np.flipud, lax_description=_ARRAY_VIEW_DOC)
def flipud(m: ArrayLike) -> Array:
  # Up/down flip: reverse along axis 0.
  util.check_arraylike("flipud", m)
  arr = asarray(m)
  return _flip(arr, 0)
@util.implements(np.iscomplex)
@jit
def iscomplex(x: ArrayLike) -> Array:
  # True wherever the imaginary part is non-zero.
  imag_part = ufuncs.imag(x)
  zero = _lax_const(imag_part, 0)
  return lax.ne(imag_part, zero)
@util.implements(np.isreal)
@jit
def isreal(x: ArrayLike) -> Array:
  # True wherever the imaginary part is exactly zero.
  imag_part = ufuncs.imag(x)
  zero = _lax_const(imag_part, 0)
  return lax.eq(imag_part, zero)
@util.implements(np.angle)
@partial(jit, static_argnames=['deg'])
def angle(z: ArrayLike, deg: bool = False) -> Array:
  real_part = ufuncs.real(z)
  imag_part = ufuncs.imag(z)
  dtype = _dtype(real_part)
  # Non-inexact inputs, and zero-dimensional floating inputs, are computed in
  # the default float dtype.
  is_float_scalar = issubdtype(_dtype(z), floating) and ndim(z) == 0
  if not issubdtype(dtype, inexact) or is_float_scalar:
    dtype = dtypes.canonicalize_dtype(float_)
    real_part = lax.convert_element_type(real_part, dtype)
    imag_part = lax.convert_element_type(imag_part, dtype)
  radians = lax.atan2(imag_part, real_part)
  if deg:
    return ufuncs.degrees(radians)
  return radians
@util.implements(np.diff)
@partial(jit, static_argnames=('n', 'axis'))
def diff(a: ArrayLike, n: int = 1, axis: int = -1,
         prepend: ArrayLike | None = None,
         append: ArrayLike | None = None) -> Array:
  util.check_arraylike("diff", a)
  arr = asarray(a)
  n = core.concrete_or_error(operator.index, n, "'n' argument of jnp.diff")
  axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.diff")
  if n == 0:
    return arr
  if n < 0:
    raise ValueError(f"order must be non-negative but got {n}")
  if arr.ndim == 0:
    raise ValueError(f"diff requires input that is at least one dimensional; got {a}")

  nd = arr.ndim
  axis = _canonicalize_axis(axis, nd)

  def _edge_block(value: ArrayLike) -> Array:
    # Validate a prepend/append value; scalars are broadcast to a slab of
    # length one along `axis`.
    util.check_arraylike("diff", value)
    if not ndim(value):
      slab_shape = list(arr.shape)
      slab_shape[axis] = 1
      value = broadcast_to(value, tuple(slab_shape))
    return asarray(value)

  pieces: list[Array] = []
  if prepend is not None:
    pieces.append(_edge_block(prepend))
  pieces.append(arr)
  if append is not None:
    pieces.append(_edge_block(append))
  if len(pieces) > 1:
    arr = concatenate(pieces, axis)

  # Index tuples selecting "all but the first" and "all but the last" entry
  # along `axis`.
  tail = tuple(slice(1, None) if d == axis else slice(None) for d in range(nd))
  head = tuple(slice(None, -1) if d == axis else slice(None) for d in range(nd))

  # Booleans cannot be subtracted; their difference is "not equal".
  op = ufuncs.not_equal if arr.dtype == np.bool_ else ufuncs.subtract
  for _ in range(n):
    arr = op(arr[tail], arr[head])
  return arr
_EDIFF1D_DOC = """\ Unlike NumPy's implementation of ediff1d, :py:func:`jax.numpy.ediff1d` will not issue an error if casting ``to_end`` or ``to_begin`` to the type of ``ary`` loses precision. """
@util.implements(np.ediff1d, lax_description=_EDIFF1D_DOC)
@jit
def ediff1d(ary: ArrayLike, to_end: ArrayLike | None = None,
            to_begin: ArrayLike | None = None) -> Array:
  util.check_arraylike("ediff1d", ary)
  flat = ravel(ary)
  deltas = lax.sub(flat[1:], flat[:-1])
  # Optional edge values are cast to the input dtype and flattened before
  # being attached.
  if to_begin is not None:
    util.check_arraylike("ediff1d", to_begin)
    leading = ravel(asarray(to_begin, dtype=flat.dtype))
    deltas = concatenate((leading, deltas))
  if to_end is not None:
    util.check_arraylike("ediff1d", to_end)
    trailing = ravel(asarray(to_end, dtype=flat.dtype))
    deltas = concatenate((deltas, trailing))
  return deltas
@util.implements(np.gradient, skip_params=['edge_order'])
@partial(jit, static_argnames=('axis', 'edge_order'))
def gradient(f: ArrayLike, *varargs: ArrayLike,
             axis: int | Sequence[int] | None = None,
             edge_order: int | None = None) -> Array | list[Array]:
  # Numerical gradient of `f` along the requested axes.
  #
  # Args:
  #   f: input array; promoted to an inexact dtype together with `varargs`.
  #   *varargs: scalar spacings: none (unit spacing), one (shared by all axes)
  #     or one per axis.
  #   axis: axis or axes to differentiate along; None means all axes.
  #   edge_order: not supported; must be None.
  #
  # Returns:
  #   A single array when exactly one axis is differentiated, otherwise a list
  #   with one array per axis (an empty list for an empty `axis`).
  #
  # Raises:
  #   NotImplementedError: if `edge_order` is given or the spacing is not scalar.
  #   ValueError: if a differentiated axis has fewer than 2 elements.
  #   TypeError: if the number of spacing arguments matches neither 0, 1 nor
  #     the number of axes.
  if edge_order is not None:
    raise NotImplementedError("The 'edge_order' argument to jnp.gradient is not supported.")
  a, *spacing = util.promote_args_inexact("gradient", f, *varargs)

  def gradient_along_axis(a, h, axis):
    # One-sided differences at both ends, central differences in between.
    sliced = partial(lax.slice_in_dim, a, axis=axis)
    a_grad = concatenate((
      (sliced(1, 2) - sliced(0, 1)),  # upper edge
      (sliced(2, None) - sliced(None, -2)) * 0.5,  # inner
      (sliced(-1, None) - sliced(-2, -1)),  # lower edge
    ), axis)
    return a_grad / h

  if axis is None:
    axis_tuple = tuple(range(a.ndim))
  else:
    axis_tuple = tuple(_canonicalize_axis(i, a.ndim) for i in _ensure_index_tuple(axis))
  if len(axis_tuple) == 0:
    return []

  if min(s for i, s in enumerate(a.shape) if i in axis_tuple) < 2:
    raise ValueError("Shape of array too small to calculate "
                     "a numerical gradient, "
                     "at least 2 elements are required.")
  if len(spacing) == 0:
    dx: Sequence[ArrayLike] = [1.0] * len(axis_tuple)
  elif len(spacing) == 1:
    dx = list(spacing) * len(axis_tuple)
  elif len(spacing) == len(axis_tuple):
    dx = list(spacing)
  else:
    # This exception was previously constructed but never raised, so execution
    # fell through to `dx[0]` with `dx` unbound.
    raise TypeError(f"Invalid number of spacing arguments {len(spacing)} for {axis=}")

  if ndim(dx[0]) != 0:
    raise NotImplementedError("Non-constant spacing not implemented")

  a_grad = [gradient_along_axis(a, h, ax) for ax, h in zip(axis_tuple, dx)]
  return a_grad[0] if len(axis_tuple) == 1 else a_grad
@util.implements(np.isrealobj)
def isrealobj(x: Any) -> bool:
  # Exact complement of iscomplexobj.
  if iscomplexobj(x):
    return False
  return True
@util.implements(np.reshape, lax_description=_ARRAY_VIEW_DOC)
def reshape(a: ArrayLike, newshape: DimSize | Shape, order: str = "C") -> Array:
  # Reshape `a` to `newshape` by delegating to a `.reshape` method.
  #
  # The object's own method is tried first. An AttributeError (for example
  # because `a` has no `.reshape`) falls through to converting `a` with
  # asarray and reshaping the result.
  __tracebackhide__ = True
  util.check_arraylike("reshape", a)
  try:
    # forward to method for ndarrays
    return a.reshape(newshape, order=order)  # type: ignore[call-overload,union-attr]
  except AttributeError:
    pass
  return asarray(a).reshape(newshape, order=order)
@util.implements(np.ravel, lax_description=_ARRAY_VIEW_DOC)
@partial(jit, static_argnames=('order',), inline=True)
def ravel(a: ArrayLike, order: str = "C") -> Array:
  util.check_arraylike("ravel", a)
  if order == "K":
    raise NotImplementedError("Ravel not implemented for order='K'.")
  # Flattening is a reshape to a single axis holding every element.
  flat_shape = (size(a),)
  return reshape(a, flat_shape, order)
@util.implements(np.ravel_multi_index)
def ravel_multi_index(multi_index: Sequence[ArrayLike], dims: Sequence[int],
                      mode: str = 'raise', order: str = 'C') -> Array:
  assert len(multi_index) == len(dims), f"len(multi_index)={len(multi_index)} != len(dims)={len(dims)}"
  dims = tuple(core.concrete_or_error(operator.index, d, "in `dims` argument of ravel_multi_index().") for d in dims)
  util.check_arraylike("ravel_multi_index", *multi_index)
  multi_index_arr = [asarray(i) for i in multi_index]
  must_validate = mode == 'raise'
  for index in multi_index_arr:
    if must_validate:
      # Bounds checking needs concrete values, so traced indices are rejected.
      core.concrete_or_error(array, index,
        "The error occurred because ravel_multi_index was jit-compiled"
        " with mode='raise'. Use mode='wrap' or mode='clip' instead.")
    if not issubdtype(_dtype(index), integer):
      raise TypeError("only int indices permitted")

  paired = list(zip(multi_index_arr, dims))
  if mode == "raise":
    if any(reductions.any((i < 0) | (i >= d)) for i, d in paired):
      raise ValueError("invalid entry in coordinates array")
  elif mode == "clip":
    multi_index_arr = [clip(i, 0, d - 1) for i, d in paired]
  elif mode == "wrap":
    multi_index_arr = [i % d for i, d in paired]
  else:
    raise ValueError(f"invalid mode={mode!r}. Expected 'raise', 'wrap', or 'clip'")

  # Stride of a dimension = product of the sizes of all faster-varying ones.
  if order == "F":
    strides = np.cumprod((1,) + dims[:-1])
  elif order == "C":
    strides = np.cumprod((1,) + dims[1:][::-1])[::-1]
  else:
    raise ValueError(f"invalid order={order!r}. Expected 'C' or 'F'")

  if multi_index_arr:
    result_dtype = multi_index_arr[0].dtype
  else:
    result_dtype = dtypes.canonicalize_dtype(int_)
  result = array(0, dtype=result_dtype)
  for coords, stride in zip(multi_index_arr, strides):
    result = result + coords * int(stride)
  return result
_UNRAVEL_INDEX_DOC = """\ Unlike numpy's implementation of unravel_index, negative indices are accepted and out-of-bounds indices are clipped into the valid range. """
@util.implements(np.unravel_index, lax_description=_UNRAVEL_INDEX_DOC)
def unravel_index(indices: ArrayLike, shape: Shape) -> tuple[Array, ...]:
  util.check_arraylike("unravel_index", indices)
  remaining = asarray(indices)
  # Note: we do not convert shape to an array, because it may be passed as a
  # tuple of weakly-typed values, and asarray() would strip these weak types.
  try:
    dims = list(shape)
  except TypeError:
    dims = [shape]
  if any(ndim(d) != 0 for d in dims):
    raise ValueError("unravel_index: shape should be a scalar or 1D sequence.")

  # Peel off coordinates from the last dimension to the first.
  coords = [0] * len(dims)
  for pos in reversed(range(len(dims))):
    remaining, coords[pos] = ufuncs.divmod(remaining, dims[pos])

  # Whatever quotient is left tells whether the flat index was out of bounds:
  # positive means too large, below -1 means too negative.
  too_large = remaining > 0
  too_small = remaining < -1
  return tuple(where(too_large, d - 1, where(too_small, 0, c))
               for d, c in safe_zip(dims, coords))
@util.implements(np.resize)
@partial(jit, static_argnames=('new_shape',))
def resize(a: ArrayLike, new_shape: Shape) -> Array:
  util.check_arraylike("resize", a)
  new_shape = _ensure_index_tuple(new_shape)

  if any(dim_length < 0 for dim_length in new_shape):
    raise ValueError("all elements of `new_shape` must be non-negative")

  flat = ravel(a)
  target_size = math.prod(new_shape)
  # Nothing to repeat, or nothing to fill: the result is all zeros.
  if flat.size == 0 or target_size == 0:
    return zeros_like(flat, shape=new_shape)

  # Repeat the flattened data often enough to cover the target, then trim.
  copies = ceil_of_ratio(target_size, flat.size)
  filled = tile(flat, copies)[:target_size]
  return reshape(filled, new_shape)
@util.implements(np.squeeze, lax_description=_ARRAY_VIEW_DOC)
def squeeze(a: ArrayLike, axis: int | Sequence[int] | None = None) -> Array:
  util.check_arraylike("squeeze", a)
  # `axis` is static in the jitted helper, so it is normalised to a tuple
  # (or left as None) first.
  static_axis = None if axis is None else _ensure_index_tuple(axis)
  return _squeeze(asarray(a), static_axis)
@partial(jit, static_argnames=('axis',), inline=True)
def _squeeze(a: Array, axis: tuple[int]) -> Array:
  """Remove the given axes, or every size-1 axis when ``axis`` is None."""
  if axis is not None:
    return lax.squeeze(a, axis)
  a_shape = shape(a)
  if not core.is_constant_shape(a_shape):
    # We do not even know the rank of the output if the input shape is not known
    raise ValueError("jnp.squeeze with axis=None is not supported with shape polymorphism")
  unit_axes = tuple(i for i, d in enumerate(a_shape) if d == 1)
  return lax.squeeze(a, unit_axes)
@util.implements(np.expand_dims)
def expand_dims(a: ArrayLike, axis: int | Sequence[int]) -> Array:
  util.check_arraylike("expand_dims", a)
  # Accept a single axis or a sequence; lax receives a tuple either way.
  new_axes = _ensure_index_tuple(axis)
  return lax.expand_dims(a, new_axes)
@util.implements(np.swapaxes, lax_description=_ARRAY_VIEW_DOC)
@partial(jit, static_argnames=('axis1', 'axis2'), inline=True)
def swapaxes(a: ArrayLike, axis1: int, axis2: int) -> Array:
  util.check_arraylike("swapaxes", a)
  # Start from the identity permutation and exchange the two entries; NumPy
  # indexing resolves negative axes.
  perm = np.arange(ndim(a))
  at_axis2, at_axis1 = perm[axis2], perm[axis1]
  perm[axis1] = at_axis2
  perm[axis2] = at_axis1
  return lax.transpose(a, list(perm))
@util.implements(np.moveaxis, lax_description=_ARRAY_VIEW_DOC)
def moveaxis(a: ArrayLike, source: int | Sequence[int],
             destination: int | Sequence[int]) -> Array:
  util.check_arraylike("moveaxis", a)
  # Both axis arguments are static in the jitted helper, so they are
  # normalised to tuples first.
  src_axes = _ensure_index_tuple(source)
  dst_axes = _ensure_index_tuple(destination)
  return _moveaxis(asarray(a), src_axes, dst_axes)
@partial(jit, static_argnames=('source', 'destination'), inline=True)
def _moveaxis(a: Array, source: tuple[int, ...], destination: tuple[int, ...]) -> Array:
  """Move the ``source`` axes of ``a`` to the ``destination`` positions.

  Raises:
    ValueError: if ``source`` and ``destination`` differ in length.
  """
  rank = ndim(a)
  source = tuple(_canonicalize_axis(ax, rank) for ax in source)
  destination = tuple(_canonicalize_axis(ax, rank) for ax in destination)
  if len(source) != len(destination):
    raise ValueError("Inconsistent number of elements: {} vs {}"
                     .format(len(source), len(destination)))
  # Keep the untouched axes in order, then slot every moved axis into its
  # destination, lowest destination first.
  perm = [ax for ax in range(rank) if ax not in source]
  for dest, src in sorted(zip(destination, source)):
    perm.insert(dest, src)
  return lax.transpose(a, perm)
@util.implements(np.isclose)
@partial(jit, static_argnames=('equal_nan',))
def isclose(a: ArrayLike, b: ArrayLike, rtol: ArrayLike = 1e-05, atol: ArrayLike = 1e-08,
            equal_nan: bool = False) -> Array:
  a, b = util.promote_args("isclose", a, b)
  dtype = _dtype(a)
  # Exact dtypes have no tolerance to apply: plain equality.
  if not issubdtype(dtype, inexact):
    return lax.eq(a, b)

  # Tolerances are real-valued, even for complex inputs.
  if issubdtype(dtype, complexfloating):
    dtype = util._complex_elem_type(dtype)
  rtol = lax.convert_element_type(rtol, dtype)
  atol = lax.convert_element_type(atol, dtype)
  distance = lax.abs(lax.sub(a, b))
  tolerance = lax.add(atol, lax.mul(rtol, lax.abs(b)))
  close = lax.le(distance, tolerance)

  # This corrects the comparisons for infinite and nan values
  a_inf = ufuncs.isinf(a)
  b_inf = ufuncs.isinf(b)
  either_inf = ufuncs.logical_or(a_inf, b_inf)
  both_inf = ufuncs.logical_and(a_inf, b_inf)
  # Anything involving an infinity starts out as "not close" ...
  close = ufuncs.logical_and(close, ufuncs.logical_not(either_inf))
  # ... unless both sides are the very same infinity.
  identical = lax.eq(a, b)
  matching_inf = ufuncs.logical_and(both_inf, identical)
  close = ufuncs.logical_or(close, matching_inf)

  # NaN on either side is "not close" ...
  a_nan = ufuncs.isnan(a)
  b_nan = ufuncs.isnan(b)
  either_nan = ufuncs.logical_or(a_nan, b_nan)
  close = ufuncs.logical_and(close, ufuncs.logical_not(either_nan))
  if equal_nan:
    # ... except that two NaNs match when the caller asks for it.
    both_nan = ufuncs.logical_and(a_nan, b_nan)
    close = ufuncs.logical_or(close, both_nan)
  return close
def _interp(x: ArrayLike, xp: ArrayLike, fp: ArrayLike,
            left: ArrayLike | str | None = None,
            right: ArrayLike | str | None = None,
            period: ArrayLike | None = None) -> Array:
  """One-dimensional piecewise-linear interpolation of ``fp`` sampled at ``xp``.

  ``left`` / ``right`` may be a fill value, ``None`` (use the first / last
  entry of ``fp``) or the string ``'extrapolate'`` (keep the linear formula
  outside the sample range). When ``period`` is given, ``x`` and ``xp`` are
  reduced modulo ``abs(period)`` and the out-of-range fills are not applied.

  Raises:
    ValueError: if ``xp``/``fp`` are not 1-D of equal shape, if a string other
      than ``'extrapolate'`` is passed for ``left``/``right``, if ``x`` is
      complex, or if ``period`` is not a scalar.
  """
  util.check_arraylike("interp", x, xp, fp)
  if shape(xp) != shape(fp) or ndim(xp) != 1:
    raise ValueError("xp and fp must be one-dimensional arrays of equal size")
  x_arr, xp_arr = util.promote_dtypes_inexact(x, xp)
  (fp_arr,) = util.promote_dtypes_inexact(fp)
  del x, xp, fp

  # A string argument selects extrapolation; only one spelling is accepted.
  extrapolate_left = isinstance(left, str)
  if extrapolate_left and left != 'extrapolate':
    raise ValueError("the only valid string value of `left` is "
                     f"'extrapolate', but got: {left!r}")
  extrapolate_right = isinstance(right, str)
  if extrapolate_right and right != 'extrapolate':
    raise ValueError("the only valid string value of `right` is "
                     f"'extrapolate', but got: {right!r}")

  if dtypes.issubdtype(x_arr.dtype, np.complexfloating):
    raise ValueError("jnp.interp: complex x values not supported.")

  periodic = period is not None
  if periodic:
    if ndim(period) != 0:
      raise ValueError(f"period must be a scalar; got {period}")
    period = ufuncs.abs(period)
    x_arr = x_arr % period
    xp_arr = xp_arr % period
    xp_arr, fp_arr = lax.sort_key_val(xp_arr, fp_arr)
    # Extend the samples by one wrapped point on each side.
    xp_arr = concatenate([xp_arr[-1:] - period, xp_arr, xp_arr[:1] + period])
    fp_arr = concatenate([fp_arr[-1:], fp_arr, fp_arr[:1]])

  # Index of the right-hand sample of each bracketing interval, clamped so
  # that `upper - 1` is always a valid index.
  upper = clip(searchsorted(xp_arr, x_arr, side='right'), 1, len(xp_arr) - 1)
  lower = upper - 1
  f_lower = fp_arr[lower]
  x_lower = xp_arr[lower]
  df = fp_arr[upper] - f_lower
  dx = xp_arr[upper] - x_lower
  delta = x_arr - x_lower

  epsilon = np.spacing(np.finfo(xp_arr.dtype).eps)
  # Prevent NaN gradients when `dx` is small.
  degenerate = lax.abs(dx) <= epsilon
  safe_dx = where(degenerate, 1, dx)
  f = where(degenerate, f_lower, f_lower + (delta / safe_dx) * df)

  if not extrapolate_left:
    assert not isinstance(left, str)
    left_arr: ArrayLike = fp_arr[0] if left is None else left
    if not periodic:
      f = where(x_arr < xp_arr[0], left_arr, f)
  if not extrapolate_right:
    assert not isinstance(right, str)
    right_arr: ArrayLike = fp_arr[-1] if right is None else right
    if not periodic:
      f = where(x_arr > xp_arr[-1], right_arr, f)
  return f
@util.implements(np.interp,
  lax_description=_dedent("""
    In addition to constant interpolation supported by NumPy, jnp.interp also
    supports left='extrapolate' and right='extrapolate' to indicate linear
    extrapolation instead."""))
def interp(x: ArrayLike, xp: ArrayLike, fp: ArrayLike,
           left: ArrayLike | str | None = None,
           right: ArrayLike | str | None = None,
           period: ArrayLike | None = None) -> Array:
  # `left` / `right` are marked static when they are a string or None, and
  # `period` when it is None; any other value is passed to the jitted
  # `_interp` as a regular (traced) argument.
  static_candidates = (
      ('left', left is None or isinstance(left, str)),
      ('right', right is None or isinstance(right, str)),
      ('period', period is None),
  )
  static_argnames = [name for name, is_static in static_candidates if is_static]
  jitted_interp = jit(_interp, static_argnames=static_argnames)
  return jitted_interp(x, xp, fp, left, right, period)
# Typing-only overloads for `where`; the runtime implementation follows below.

# One-argument form: only `condition` is given, result is a tuple of arrays.
@overload  # type: ignore[no-overload-impl]
def where(condition: ArrayLike, x: Literal[None] = None, y: Literal[None] = None, /, *,
          size: int | None = None,
          fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None
    ) -> tuple[Array, ...]: ...

# Three-argument form: `x` and `y` are both given, result is a single array.
@overload
def where(condition: ArrayLike, x: ArrayLike, y: ArrayLike, / ,*,
          size: int | None = None,
          fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None
          ) -> Array: ...

# Catch-all form for call sites where `x` / `y` may or may not be None.
@overload
def where(condition: ArrayLike, x: ArrayLike | None = None, y: ArrayLike | None = None, /, *,
          size: int | None = None,
          fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None
          ) -> Array | tuple[Array, ...]: ...

# Unique sentinel used as the default of the deprecated keyword-only
# `condition` / `x` / `y` parameters, so that an explicitly passed `None`
# can be told apart from "not passed".
_DEPRECATED_WHERE_ARG = object()
@util.implements(np.where,  # type: ignore[no-redef]
  lax_description=_dedent("""
    At present, JAX does not support JIT-compilation of the single-argument form
    of :py:func:`jax.numpy.where` because its output shape is data-dependent. The
    three-argument form does not have a data-dependent shape and can be JIT-compiled
    successfully. Alternatively, you can use the optional ``size`` keyword to
    statically specify the expected size of the output.\n\n

    Special care is needed when the ``x`` or ``y`` input to
    :py:func:`jax.numpy.where` could have a value of NaN.
    Specifically, when a gradient is taken
    with :py:func:`jax.grad` (reverse-mode differentiation), a NaN in either
    ``x`` or ``y`` will propagate into the gradient, regardless of the value
    of ``condition``.  More information on this behavior and workarounds
    is available in the JAX FAQ:
    https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where"""),
  extra_params=_dedent("""
    size : int, optional
        Only referenced when ``x`` and ``y`` are ``None``. If specified, the indices of the first
        ``size`` elements of the result will be returned. If there are fewer elements than ``size``
        indicates, the return value will be padded with ``fill_value``.
    fill_value : array_like, optional
        When ``size`` is specified and there are fewer than the indicated number of elements, the
        remaining elements will be filled with ``fill_value``, which defaults to zero."""))
def where(
    acondition = None, if_true = None, if_false = None, /, *,
    size=None, fill_value=None,
    # Deprecated keyword-only names.
    condition = _DEPRECATED_WHERE_ARG, x = _DEPRECATED_WHERE_ARG,
    y = _DEPRECATED_WHERE_ARG
) -> Array | tuple[Array, ...]:
  # Detect use of the deprecated keyword spellings of the positional arguments.
  condition_by_keyword = condition is not _DEPRECATED_WHERE_ARG
  x_by_keyword = x is not _DEPRECATED_WHERE_ARG
  y_by_keyword = y is not _DEPRECATED_WHERE_ARG

  if condition_by_keyword or x_by_keyword or y_by_keyword:
    # TODO(phawkins): deprecated Nov 17 2023, remove after deprecation expires.
    warnings.warn(
        "Passing condition, x, or y to jax.numpy.where via keyword arguments "
        "is deprecated.",
        DeprecationWarning,
        stacklevel=2,
    )

  # Fold each keyword value into its positional slot, rejecting double
  # specification of the same argument.
  if condition_by_keyword:
    if acondition is not None:
      raise ValueError("condition should be a positional-only argument")
    acondition = condition
  if x_by_keyword:
    if if_true is not None:
      raise ValueError("x should be a positional-only argument")
    if_true = x
  if y_by_keyword:
    if if_false is not None:
      raise ValueError("y should be a positional-only argument")
    if_false = y

  if if_true is not None or if_false is not None:
    # Three-term form: element-wise selection.
    util.check_arraylike("where", acondition, if_true, if_false)
    if size is not None or fill_value is not None:
      raise ValueError("size and fill_value arguments cannot be used in three-term where function.")
    return util._where(acondition, if_true, if_false)

  # One-term form: delegate to nonzero.
  util.check_arraylike("where", acondition)
  return nonzero(acondition, size=size, fill_value=fill_value)
@util.implements(np.select)
def select(
    condlist: Sequence[ArrayLike],
    choicelist: Sequence[ArrayLike],
    default: ArrayLike = 0,
) -> Array:
  n_conds, n_choices = len(condlist), len(choicelist)
  if n_conds != n_choices:
    msg = "condlist must have length equal to choicelist ({} vs {})"
    raise ValueError(msg.format(n_conds, n_choices))
  if n_conds == 0:
    raise ValueError("condlist must be non-empty")
  # Promote the default together with all choices to a common dtype.
  promoted = util.promote_dtypes(default, *choicelist)
  output, promoted_choices = promoted[0], promoted[1:]
  # Apply the conditions last-to-first so that, where several conditions
  # hold, the earliest one in `condlist` is applied last and wins.
  for cond, choice in zip(condlist[::-1], promoted_choices[::-1]):
    output = where(cond, choice, output)
  return output
@util.implements(np.bincount, lax_description="""\
Jax adds the optional `length` parameter which specifies the output length, and
defaults to ``x.max() + 1``. It must be specified for bincount to be compiled
with non-static operands. Values larger than the specified length will be discarded.
If `length` is specified, `minlength` will be ignored.

Additionally, while ``np.bincount`` raises an error if the input array contains
negative values, ``jax.numpy.bincount`` clips negative values to zero.
""")
def bincount(x: ArrayLike, weights: ArrayLike | None = None,
             minlength: int = 0, *, length: int | None = None
             ) -> Array:
  util.check_arraylike("bincount", x)
  x_dtype = _dtype(x)
  if not issubdtype(x_dtype, integer):
    raise TypeError(f"x argument to bincount must have an integer type; got {x_dtype}")
  if ndim(x) != 1:
    raise ValueError("only 1-dimensional input supported.")
  minlength = core.concrete_or_error(operator.index, minlength,
      "The error occurred because of argument 'minlength' of jnp.bincount.")

  if length is not None:
    length = core.concrete_dim_or_error(length,
        "The error occurred because of argument 'length' of jnp.bincount.")
  else:
    # Without a static `length` the output size comes from the data, so `x`
    # must be concrete.
    x_arr = core.concrete_or_error(asarray, x,
      "The error occurred because of argument 'x' of jnp.bincount. "
      "To avoid this error, pass a static `length` argument.")
    # An empty input contributes a data-derived length of 0.
    data_length = int(x_arr.max()) + 1 if x_arr.size else 0
    length = max(minlength, data_length)

  if weights is None:
    weights = np.array(1, dtype=int_)
  elif shape(x) != shape(weights):
    raise ValueError("shape of weights must match shape of x.")

  # Negative indices are clipped to zero before the scatter-add.
  counts = zeros(length, _dtype(weights))
  return counts.at[clip(x, 0)].add(weights)
# Typing-only overloads for `broadcast_shapes`; the implementation follows.

# All dimensions are plain ints: the result is a tuple of ints.
@overload
def broadcast_shapes(*shapes: Sequence[int]) -> tuple[int, ...]: ...

# Dimensions may be tracers: the result may contain tracers as well.
@overload
def broadcast_shapes(*shapes: Sequence[int | core.Tracer]
                     ) -> tuple[int | core.Tracer, ...]: ...
@util.implements(getattr(np, "broadcast_shapes", None))
def broadcast_shapes(*shapes):
  if not shapes:
    return ()
  # Normalize every argument to a tuple: a scalar argument becomes a
  # one-element shape, anything else is converted with tuple().
  normalized = []
  for shp in shapes:
    normalized.append(tuple(shp) if np.ndim(shp) != 0 else (shp,))
  return lax.broadcast_shapes(*normalized)
@util.implements(np.broadcast_arrays, lax_description="""\
The JAX version does not necessarily return a view of the input.
""")
def broadcast_arrays(*args: ArrayLike) -> list[Array]:
  # Public entry point; all work is delegated to util._broadcast_arrays.
  return util._broadcast_arrays(*args)
@util.implements(np.broadcast_to, lax_description="""\
The JAX version does not necessarily return a view of the input.
""")
def broadcast_to(array: ArrayLike, shape: DimSize | Shape) -> Array:
  # Public entry point; all work is delegated to util._broadcast_to.
  return util._broadcast_to(array, shape)
def _split(op: str, ary: ArrayLike,
           indices_or_sections: int | Sequence[int] | ArrayLike,
           axis: int = 0) -> list[Array]:
  """Shared implementation behind ``split``, ``array_split`` and ``*split``.

  ``op`` is the public function name; it is used in error messages and it
  enables uneven division when equal to ``"array_split"``.
  ``indices_or_sections`` is either a sequence / non-scalar array of explicit
  split points, or a section count.

  Raises:
    ValueError: for a symbolic section count, or when the axis does not
      divide evenly and ``op`` is not ``"array_split"``.
  """
  util.check_arraylike(op, ary)
  ary = asarray(ary)
  axis = core.concrete_or_error(operator.index, axis, f"in jax.numpy.{op} argument `axis`")
  size = ary.shape[axis]

  given_as_sequence = isinstance(indices_or_sections, (tuple, list))
  given_as_array = (isinstance(indices_or_sections, (np.ndarray, Array))
                    and indices_or_sections.ndim > 0)
  if given_as_sequence or given_as_array:
    # Explicit split points: bracket them with the start and end of the axis.
    interior = [
        core.concrete_dim_or_error(i_s, f"in jax.numpy.{op} argument 1")
        for i_s in indices_or_sections]
    split_indices = [0, *interior, size]
  else:
    if core.is_symbolic_dim(indices_or_sections):
      raise ValueError(f"jax.numpy.{op} with a symbolic number of sections is "
                       "not supported")
    num_sections: int = core.concrete_or_error(int, indices_or_sections,
                                               f"in jax.numpy.{op} argument 1")
    part_size, r = divmod(size, num_sections)
    if r == 0:
      split_indices = [k * part_size for k in range(num_sections + 1)]
    elif op == "array_split":
      # The first `r` sections get one extra element each.
      long_parts = [k * (part_size + 1) for k in range(r + 1)]
      short_base = (r + 1) * (part_size + 1) - 1
      short_parts = [k * part_size + short_base for k in range(num_sections - r)]
      split_indices = long_parts + short_parts
    else:
      raise ValueError(f"array split does not result in an equal division: rest is {r}")

  split_indices = [idx if core.is_symbolic_dim(idx) else np.int64(idx)  # type: ignore[misc]
                   for idx in split_indices]
  base_starts, base_ends = [0] * ndim(ary), shape(ary)

  # Each piece is a slice that differs from the full array only along `axis`.
  pieces = []
  for begin, end in zip(split_indices[:-1], split_indices[1:]):
    piece_starts = subvals(base_starts, [(axis, begin)])
    piece_ends = subvals(base_ends, [(axis, end)])
    pieces.append(lax.slice(ary, piece_starts, piece_ends))
  return pieces
@util.implements(np.split, lax_description=_ARRAY_VIEW_DOC)
def split(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike,
          axis: int = 0) -> list[Array]:
  # Equal-division split: `_split` raises for uneven section counts under
  # the "split" op name.
  op_name = "split"
  return _split(op_name, ary, indices_or_sections, axis)
def _split_on_axis(op: str, axis: int) -> Callable[[ArrayLike, int | ArrayLike], list[Array]]:
  """Build the fixed-axis split function named ``op``.

  The returned function splits its input along ``axis``, except that for
  ``axis == 1`` a one-dimensional input is split along axis 0 instead.
  """
  @util.implements(getattr(np, op), update_doc=False)
  def f(ary: ArrayLike,
        indices_or_sections: int | Sequence[int] | ArrayLike) -> list[Array]:
    util.check_arraylike(op, ary)
    a = asarray(ary)
    # For a 1-D array, hsplit becomes vsplit. The effective axis is computed
    # per call in a local: rebinding the enclosing `axis` (via `nonlocal`)
    # would make the first 1-D call permanently switch this function to
    # axis 0 for all later N-D inputs.
    split_axis = 0 if (axis == 1 and len(a.shape) == 1) else axis
    return _split(op, ary, indices_or_sections, axis=split_axis)
  return f

vsplit = _split_on_axis("vsplit", axis=0)
hsplit = _split_on_axis("hsplit", axis=1)
dsplit = _split_on_axis("dsplit", axis=2)
@util.implements(np.array_split)
def array_split(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike,
                axis: int = 0) -> list[Array]:
  # The "array_split" op name is what lets `_split` accept a section count
  # that does not divide the axis evenly.
  op_name = "array_split"
  return _split(op_name, ary, indices_or_sections, axis)
@util.implements(np.clip, skip_params=['out'])
@jit
def clip(a: ArrayLike, a_min: ArrayLike | None = None,
         a_max: ArrayLike | None = None, out: None = None) -> Array:
  util.check_arraylike("clip", a)
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.clip is not supported.")
  has_lower = a_min is not None
  has_upper = a_max is not None
  if not (has_lower or has_upper):
    raise ValueError("At most one of a_min and a_max may be None")
  clipped = a
  # The lower bound is applied first, then the upper bound.
  if has_lower:
    clipped = ufuncs.maximum(a_min, clipped)
  if has_upper:
    clipped = ufuncs.minimum(a_max, clipped)
  return asarray(clipped)
@util.implements(np.around, skip_params=['out'])
@partial(jit, static_argnames=('decimals',))
def round(a: ArrayLike, decimals: int = 0, out: None = None) -> Array:
  util.check_arraylike("round", a)
  decimals = core.concrete_or_error(operator.index, decimals, "'decimals' argument of jnp.round")
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.round is not supported.")
  dtype = _dtype(a)
  if issubdtype(dtype, integer):
    if decimals < 0:
      raise NotImplementedError(
        "integer np.round not implemented for decimals < 0")
    return asarray(a)  # no-op on integer types

  is_half = dtype == np.float16

  def _round_float(x: ArrayLike) -> Array:
    # Round a real-valued array to `decimals` places, ties to even.
    if decimals == 0:
      return lax.round(x, lax.RoundingMethod.TO_NEAREST_EVEN)

    # TODO(phawkins): the strategy of rescaling the value isn't necessarily a
    # good one since we may be left with an incorrectly rounded value at the
    # end due to precision problems. As a workaround for float16, convert to
    # float32,
    if is_half:
      x = lax.convert_element_type(x, np.float32)
    factor = _lax_const(x, 10 ** decimals)
    scaled = lax.round(lax.mul(x, factor), lax.RoundingMethod.TO_NEAREST_EVEN)
    result = lax.div(scaled, factor)
    if is_half:
      result = lax.convert_element_type(result, dtype)
    return result

  if not issubdtype(dtype, complexfloating):
    return _round_float(a)
  # Complex input: real and imaginary parts are rounded independently.
  return lax.complex(_round_float(lax.real(a)), _round_float(lax.imag(a)))
# Aliases: both names are bound to the same function object as `round`.
around = round
round_ = round
@util.implements(np.fix, skip_params=['out'])
@jit
def fix(x: ArrayLike, out: None = None) -> Array:
  util.check_arraylike("fix", x)
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.fix is not supported.")
  # Round toward zero: floor where x >= 0, ceil elsewhere.
  is_nonnegative = lax.ge(x, _lax_const(x, 0))
  rounded_down = ufuncs.floor(x)
  rounded_up = ufuncs.ceil(x)
  return where(is_nonnegative, rounded_down, rounded_up)
@util.implements(np.nan_to_num)
@jit
def nan_to_num(x: ArrayLike, copy: bool = True, nan: ArrayLike = 0.0,
               posinf: ArrayLike | None = None,
               neginf: ArrayLike | None = None) -> Array:
  del copy  # Accepted in the signature but unused.
  util.check_arraylike("nan_to_num", x)
  dtype = _dtype(x)
  if not issubdtype(dtype, inexact):
    # Non-inexact dtypes are returned unchanged.
    return asarray(x)
  if issubdtype(dtype, complexfloating):
    # Complex input: clean real and imaginary parts separately.
    fills = dict(nan=nan, posinf=posinf, neginf=neginf)
    return lax.complex(
      nan_to_num(lax.real(x), **fills),
      nan_to_num(lax.imag(x), **fills))
  # Default replacements for the infinities are the dtype's extreme values.
  info = finfo(dtypes.canonicalize_dtype(dtype))
  if posinf is None:
    posinf = info.max
  if neginf is None:
    neginf = info.min
  # Apply the three substitutions in order: NaN, +inf, -inf.
  substitutions = (
      (ufuncs.isnan, nan),
      (ufuncs.isposinf, posinf),
      (ufuncs.isneginf, neginf),
  )
  out = x
  for is_special, replacement in substitutions:
    out = where(is_special(out), asarray(replacement, dtype=dtype), out)
  return out
@util.implements(np.allclose)
@partial(jit, static_argnames=('equal_nan',))
def allclose(a: ArrayLike, b: ArrayLike, rtol: ArrayLike = 1e-05,
             atol: ArrayLike = 1e-08, equal_nan: bool = False) -> Array:
  util.check_arraylike("allclose", a, b)
  # Reduce the element-wise `isclose` result with a logical all().
  elementwise_close = isclose(a, b, rtol, atol, equal_nan)
  return reductions.all(elementwise_close)
# Documentation fragments passed to `util.implements` (as `lax_description`
# and `extra_params`) by `nonzero` and `flatnonzero` below.
_NONZERO_DOC = """\
Because the size of the output of ``nonzero`` is data-dependent, the function is not
typically compatible with JIT. The JAX version adds the optional ``size`` argument which
must be specified statically for ``jnp.nonzero`` to be used within some of JAX's
transformations.
"""
# NOTE(review): "fewer unique elements" below reads like a slip for "fewer
# nonzero elements" -- confirm before changing this user-facing text.
_NONZERO_EXTRA_PARAMS = """
size : int, optional
    If specified, the indices of the first ``size`` True elements will be returned. If there are
    fewer unique elements than ``size`` indicates, the return value will be padded with ``fill_value``.
fill_value : array_like, optional
    When ``size`` is specified and there are fewer than the indicated number of elements, the
    remaining elements will be filled with ``fill_value``, which defaults to zero.
"""
@util.implements(np.nonzero, lax_description=_NONZERO_DOC, extra_params=_NONZERO_EXTRA_PARAMS)
def nonzero(a: ArrayLike, *, size: int | None = None,
            fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None
    ) -> tuple[Array, ...]:
  # Returns one index array per dimension of `a` locating its nonzero
  # entries. `size`, when given, fixes the output length statically;
  # `fill_value` (a scalar, or a tuple with one scalar per dimension)
  # replaces the entries beyond the actual number of nonzeros.
  #
  # Raises ValueError if `fill_value` is a tuple of the wrong length or
  # contains non-scalar entries.
  util.check_arraylike("nonzero", a)
  arr = asarray(a)
  del a
  if ndim(arr) == 0:
    # Added 2023 Dec 6
    warnings.warn("Calling nonzero on 0d arrays is deprecated. Use `atleast_1d(arr).nonzero()`",
                  DeprecationWarning, stacklevel=2)
    arr = atleast_1d(arr)
  mask = arr if arr.dtype == bool else (arr != 0)
  calculated_size = mask.sum() if size is None else size
  calculated_size = core.concrete_dim_or_error(calculated_size,
    "The size argument of jnp.nonzero must be statically specified "
    "to use jnp.nonzero within JAX transformations.")
  if arr.size == 0 or calculated_size == 0:
    return tuple(zeros(calculated_size, int) for dim in arr.shape)

  # cumsum(mask) gives, per flat position, the number of nonzeros seen so
  # far. Counting how often each running total occurs and accumulating those
  # counts yields the flat position of each successive nonzero entry.
  flat_indices = reductions.cumsum(
      bincount(reductions.cumsum(mask), length=calculated_size))  # type: ignore[arg-type]
  # Per-dimension strides of the flattened array, used to unravel the
  # flat positions into one index array per dimension.
  strides: np.ndarray = (np.cumprod(arr.shape[::-1])[::-1] // arr.shape).astype(int_)
  out = tuple((flat_indices // stride) % dim_size
              for stride, dim_size in zip(strides, arr.shape))

  if fill_value is not None:
    fill_value_tup = fill_value if isinstance(fill_value, tuple) else arr.ndim * (fill_value,)
    # Check the tuple length explicitly so that a mismatch produces the
    # documented error instead of a failure inside safe_zip.
    if (len(fill_value_tup) != arr.ndim
        or any(_shape(val) != () for val in fill_value_tup)):
      raise ValueError(f"fill_value must be a scalar or a tuple of length {arr.ndim}; got {fill_value}")
    fill_mask = arange(calculated_size) >= mask.sum()
    out = tuple(where(fill_mask, fval, entry) for fval, entry in safe_zip(fill_value_tup, out))
  return out
@util.implements(np.flatnonzero, lax_description=_NONZERO_DOC, extra_params=_NONZERO_EXTRA_PARAMS)
def flatnonzero(a: ArrayLike, *, size: int | None = None,
                fill_value: None | ArrayLike | tuple[ArrayLike] = None) -> Array:
  # Flatten first, then take the first index array returned by `nonzero`.
  flattened = ravel(a)
  index_arrays = nonzero(flattened, size=size, fill_value=fill_value)
  return index_arrays[0]
@util.implements(np.unwrap)
@partial(jit, static_argnames=('axis',))
def unwrap(p: ArrayLike, discont: ArrayLike | None = None,
           axis: int = -1, period: ArrayLike = 2 * pi) -> Array:
  util.check_arraylike("unwrap", p)
  p = asarray(p)
  if issubdtype(p.dtype, np.complexfloating):
    raise ValueError("jnp.unwrap does not support complex inputs.")
  if p.shape[axis] == 0:
    # Nothing to unwrap along an empty axis; only promote the dtype.
    return util.promote_dtypes_inexact(p)[0]

  interval = period / 2
  if discont is None:
    discont = interval

  # Successive differences, wrapped into [-interval, interval).
  dd = diff(p, axis=axis)
  wrapped = ufuncs.mod(dd + interval, period) - interval
  # Map -interval to +interval where the original step was positive.
  at_lower_edge = (wrapped == -interval) & (dd > 0)
  wrapped = where(at_lower_edge, interval, wrapped)
  # Steps smaller than `discont` receive no correction.
  ph_correct = where(ufuncs.abs(dd) < discont, 0, wrapped - dd)

  # The first element along `axis` is kept; the rest get the accumulated
  # correction added.
  head = lax.slice_in_dim(p, 0, 1, axis=axis)
  tail = lax.slice_in_dim(p, 1, None, axis=axis)
  corrected_tail = tail + reductions.cumsum(ph_correct, axis=axis)
  return concatenate((head, corrected_tail), axis=axis)
### Padding

# A pad-style argument as accepted from the user: a single value, a sequence
# of values, or a sequence of sequences.
PadValueLike = Union[T, Sequence[T], Sequence[Sequence[T]]]
# The normalized form: a tuple of (before, after) pairs.
PadValue = tuple[tuple[T, T], ...]

class PadStatFunc(Protocol):
  """Call signature of the reductions used by the statistic padding modes."""
  def __call__(self, array: ArrayLike, /, *,
               axis: int | None = None,
               keepdims: bool = False) -> Array: ...

def _broadcast_to_pairs(nvals: PadValueLike, nd: int, name: str) -> PadValue:
  """Normalize ``nvals`` to a tuple of ``nd`` ``(before, after)`` pairs.

  Accepted shapes of ``nvals`` (after conversion with ``np.asarray``) are
  ``(nd, 2)``, ``(1, 2)``, ``(2,)``, ``(1,)`` and ``()``. ``name`` is the
  argument name used in error messages.

  Raises:
    TypeError: if the entries of ``nvals`` are not all the same shape.
    ValueError: if ``nvals`` has any other shape.
  """
  try:
    nvals = np.asarray(tree_map(
      lambda x: core.concrete_or_error(None, x, context=f"{name} argument of jnp.pad"),
      nvals))
  except ValueError as e:
    # In numpy 1.24
    if "array has an inhomogeneous shape" in str(e):
      raise TypeError(f'`{name}` entries must be the same shape: {nvals}') from e
    raise

  def as_scalar_dim(v):
    # Accept dimension-like values and scalars; reject anything with a shape.
    if core.is_dim(v) or not np.shape(v):
      return v
    else:
      raise TypeError(f'`{name}` entries must be the same shape: {nvals}')

  if nvals.shape == (nd, 2):
    # ((before_1, after_1), ..., (before_N, after_N))
    return tuple((as_scalar_dim(nval[0]), as_scalar_dim(nval[1])) for nval in nvals)
  elif nvals.shape == (1, 2):
    # ((before, after),)
    v1_2 = as_scalar_dim(nvals[0, 0]), as_scalar_dim(nvals[0, 1])
    return tuple(v1_2 for i in range(nd))
  elif nvals.shape == (2,):
    # (before, after)  (not in the numpy docstring but works anyway)
    v1_2 = as_scalar_dim(nvals[0]), as_scalar_dim(nvals[1])
    return tuple(v1_2 for i in range(nd))
  elif nvals.shape == (1,):
    # (pad,)
    v = as_scalar_dim(nvals[0])
    return tuple((v, v) for i in range(nd))
  elif nvals.shape == ():
    # pad
    v = as_scalar_dim(nvals.flat[0])
    return tuple((v, v) for i in range(nd))
  else:
    raise ValueError(f"jnp.pad: {name} with {nd=} has unsupported shape {nvals.shape}. "
                     f"Valid shapes are ({nd}, 2), (1, 2), (2,), (1,), or ().")


def _check_no_padding(axis_padding: tuple[Any, Any], mode: str):
  """Raise ``ValueError`` if either entry of ``axis_padding`` is positive.

  Called for axes of length zero, which the content-based modes cannot pad.
  """
  if (axis_padding[0] > 0 or axis_padding[1] > 0):
    msg = "Cannot apply '{}' padding to empty axis"
    raise ValueError(msg.format(mode))


def _pad_constant(array: Array, pad_width: PadValue[int], constant_values: Array) -> Array:
  """Pad every axis of ``array`` with constants.

  ``constant_values`` is broadcast to shape ``(ndim, 2)`` and converted to the
  dtype (and weak-type flag) of ``array``; row ``i`` holds the values used
  before and after axis ``i``.
  """
  nd = ndim(array)
  constant_values = broadcast_to(constant_values, (nd, 2))
  constant_values = lax_internal._convert_element_type(
      constant_values, array.dtype, dtypes.is_weakly_typed(array))
  for i in range(nd):
    # lax.pad takes one (low, high, interior) triple per axis and a single
    # fill value, so the two sides are padded in separate calls.
    widths = [(0, 0, 0)] * nd
    widths[i] = (pad_width[i][0], 0, 0)
    array = lax.pad(array, constant_values[i, 0], widths)
    widths[i] = (0, pad_width[i][1], 0)
    array = lax.pad(array, constant_values[i, 1], widths)
  return array


def _pad_wrap(array: Array, pad_width: PadValue[int]) -> Array:
  """Pad every axis of ``array`` by wrapping its contents around."""
  for i in range(ndim(array)):
    if array.shape[i] == 0:
      _check_no_padding(pad_width[i], "wrap")
      continue
    size = array.shape[i]
    # Split each side's width into whole copies of the axis plus a remainder.
    repeats, (left_remainder, right_remainder) = np.divmod(pad_width[i], size)
    # Whole copies needed on both sides, plus the array itself.
    total_repeats = repeats.sum() + 1
    parts = []
    if left_remainder:
      parts += [lax.slice_in_dim(array, size - left_remainder, size, axis=i)]
    parts += total_repeats * [array]
    if right_remainder:
      parts += [lax.slice_in_dim(array, 0, right_remainder, axis=i)]
    array = lax.concatenate(parts, dimension=i)
  return array


def _pad_symmetric_or_reflect(array: Array, pad_width: PadValue[int],
                              mode: str, reflect_type: str) -> Array:
  """Pad every axis of ``array`` with flipped copies of its own contents.

  ``mode`` is ``"symmetric"`` or ``"reflect"``; in reflect mode the edge
  element itself is skipped (``offset`` of 1) when the axis has more than one
  element. With ``reflect_type == "odd"`` each flipped chunk is replaced by
  ``2 * edge - chunk``.
  """
  assert mode in ("symmetric", "reflect")
  assert reflect_type in ("even", "odd")

  for i in range(ndim(array)):
    if array.shape[i] == 0:
      _check_no_padding(pad_width[i], mode)
      continue

    n = array.shape[i]
    offset = 1 if (mode == "reflect" and n > 1) else 0

    def build_padding(array, padding, before):
      # Grow `array` along axis `i` on one side, in chunks of at most
      # `n - offset` elements, until `padding` elements have been added.
      if before:
        edge = lax.slice_in_dim(array, 0, 1, axis=i)
      else:
        edge = lax.slice_in_dim(array, -1, None, axis=i)

      while padding > 0:
        curr_pad = min(padding, n - offset)
        padding -= curr_pad

        if before:
          start = offset
          stop = offset + curr_pad
        else:
          start = -(curr_pad + offset)
          stop = None if (mode == "symmetric" or n == 1) else -1

        x = lax.slice_in_dim(array, start, stop, axis=i)
        x = flip(x, axis=i)

        if reflect_type == 'odd':
          x = 2 * edge - x
          if n > 1:
            # The next chunk is taken relative to the new outermost element.
            if before:
              edge = lax.slice_in_dim(x, 0, 1, axis=i)
            else:
              edge = lax.slice_in_dim(x, -1, None, axis=i)

        if before:
          array = lax.concatenate([x, array], dimension=i)
        else:
          array = lax.concatenate([array, x], dimension=i)
      return array

    array = build_padding(array, pad_width[i][0], before=True)
    array = build_padding(array, pad_width[i][1], before=False)
  return array


def _pad_edge(array: Array, pad_width: PadValue[int]) -> Array:
  """Pad every axis of ``array`` by repeating its first and last elements."""
  nd = ndim(array)
  for i in range(nd):
    if array.shape[i] == 0:
      _check_no_padding(pad_width[i], "edge")
      continue

    n = array.shape[i]
    npad_before, npad_after = pad_width[i]

    edge_before = lax.slice_in_dim(array, 0, 1, axis=i)
    pad_before = repeat(edge_before, npad_before, axis=i)

    edge_after = lax.slice_in_dim(array, n-1, n, axis=i)
    pad_after = repeat(edge_after, npad_after, axis=i)

    array = lax.concatenate([pad_before, array, pad_after], dimension=i)
  return array


def _pad_linear_ramp(array: Array, pad_width: PadValue[int],
                     end_values: PadValue[ArrayLike]) -> Array:
  """Pad every axis of ``array`` with a linear ramp.

  Each ramp is a ``linspace`` that starts at the given end value and stops
  (exclusive) at the array's edge along that axis.
  """
  for axis in range(ndim(array)):
    edge_before = lax.slice_in_dim(array, 0, 1, axis=axis)
    edge_after = lax.slice_in_dim(array, -1, None, axis=axis)
    ramp_before = linspace(
        start=end_values[axis][0],
        stop=edge_before.squeeze(axis),  # Dimension is replaced by linspace
        num=pad_width[axis][0],
        endpoint=False,
        dtype=array.dtype,
        axis=axis
    )
    ramp_before = lax_internal._convert_element_type(
        ramp_before, weak_type=dtypes.is_weakly_typed(array))
    ramp_after = linspace(
        start=end_values[axis][1],
        stop=edge_after.squeeze(axis),  # Dimension is replaced by linspace
        num=pad_width[axis][1],
        endpoint=False,
        dtype=array.dtype,
        axis=axis
    )
    ramp_after = lax_internal._convert_element_type(
        ramp_after, weak_type=dtypes.is_weakly_typed(array))

    # Reverse linear space in appropriate dimension
    ramp_after = flip(ramp_after, axis)

    array = lax.concatenate([ramp_before, array, ramp_after], dimension=axis)
  return array


def _pad_stats(array: Array, pad_width: PadValue[int],
               stat_length: PadValue[int] | None,
               stat_func: PadStatFunc) -> Array:
  """Pad every axis of ``array`` with a statistic computed along that axis.

  With ``stat_length`` of ``None`` the statistic covers the whole axis;
  otherwise it covers the first / last ``stat_length`` elements (clamped to
  the axis length) for the before / after padding respectively.

  Raises:
    ValueError: if a ``stat_length`` entry is 0.
  """
  nd = ndim(array)
  for i in range(nd):
    if stat_length is None:
      stat_before = stat_func(array, axis=i, keepdims=True)
      stat_after = stat_before
    else:
      array_length = array.shape[i]
      length_before, length_after = stat_length[i]
      if length_before == 0 or length_after == 0:
        raise ValueError("stat_length of 0 yields no value for padding")

      # Limit stat_length to length of array.
      length_before = min(length_before, array_length)
      length_after = min(length_after, array_length)

      slice_before = lax.slice_in_dim(array, 0, length_before, axis=i)
      slice_after = lax.slice_in_dim(array, -length_after, None, axis=i)
      stat_before = stat_func(slice_before, axis=i, keepdims=True)
      stat_after = stat_func(slice_after, axis=i, keepdims=True)

    # Integer arrays: round the statistic before converting back to the
    # array's dtype.
    if np.issubdtype(array.dtype, np.integer):
      stat_before = round(stat_before)
      stat_after = round(stat_after)

    stat_before = lax_internal._convert_element_type(
        stat_before, array.dtype, dtypes.is_weakly_typed(array))
    stat_after = lax_internal._convert_element_type(
        stat_after, array.dtype, dtypes.is_weakly_typed(array))

    npad_before, npad_after = pad_width[i]
    pad_before = repeat(stat_before, npad_before, axis=i)
    pad_after = repeat(stat_after, npad_after, axis=i)

    array = lax.concatenate([pad_before, array, pad_after], dimension=i)
  return array


def _pad_empty(array: Array, pad_width: PadValue[int]) -> Array:
  """Pad every axis of ``array`` with blocks created by ``empty_like``."""
  # Note: jax.numpy.empty = jax.numpy.zeros
  for i in range(ndim(array)):
    shape_before = array.shape[:i] + (pad_width[i][0],) + array.shape[i + 1:]
    pad_before = empty_like(array, shape=shape_before)

    shape_after = array.shape[:i] + (pad_width[i][1],) + array.shape[i + 1:]
    pad_after = empty_like(array, shape=shape_after)
    array = lax.concatenate([pad_before, array, pad_after], dimension=i)
  return array


def _pad_func(array: Array, pad_width: PadValue[int], func: Callable[..., Any], **kwargs) -> Array:
  """Pad ``array`` using a user-supplied function.

  The array is first zero-padded to its final shape; ``func`` is then applied
  along each axis in turn via ``apply_along_axis`` and receives, after the
  vector, that axis's pad widths, the axis index and the ``kwargs`` dict.
  """
  pad_width = _broadcast_to_pairs(pad_width, ndim(array), "pad_width")
  padded = _pad_constant(array, pad_width, asarray(0))
  for axis in range(ndim(padded)):
    padded = apply_along_axis(func, axis, padded, pad_width[axis], axis, kwargs)
  return padded


@partial(jit, static_argnums=(1, 2, 4, 5, 6))
def _pad(array: ArrayLike, pad_width: PadValueLike[int], mode: str,
         constant_values: ArrayLike, stat_length: PadValueLike[int],
         end_values: PadValueLike[ArrayLike], reflect_type: str):
  """Jitted dispatcher from a padding ``mode`` string to its implementation.

  Every argument except ``array`` and ``constant_values`` is static under
  ``jit``. A zero-dimensional ``array`` is returned unchanged.

  Raises:
    ValueError: if the normalized ``pad_width`` does not have shape
      ``(ndim, 2)`` or contains a negative entry.
    AssertionError: if ``mode`` is not one of the handled modes (callers are
      expected to have validated it already).
  """
  array = asarray(array)
  nd = ndim(array)

  if nd == 0:
    return array

  stat_funcs: dict[str, PadStatFunc] = {
      "maximum": reductions.amax,
      "minimum": reductions.amin,
      "mean": reductions.mean,
      "median": reductions.median
  }

  pad_width = _broadcast_to_pairs(pad_width, nd, "pad_width")
  pad_width_arr = np.array(pad_width)
  if pad_width_arr.shape != (nd, 2):
    raise ValueError(f"Expected pad_width to have shape {(nd, 2)}; got {pad_width_arr.shape}.")

  if np.any(pad_width_arr < 0):
    raise ValueError("index can't contain negative values")

  if mode == "constant":
    return _pad_constant(array, pad_width, asarray(constant_values))

  elif mode == "wrap":
    return _pad_wrap(array, pad_width)

  elif mode in ("symmetric", "reflect"):
    return _pad_symmetric_or_reflect(array, pad_width, str(mode), reflect_type)

  elif mode == "edge":
    return _pad_edge(array, pad_width)

  elif mode == "linear_ramp":
    end_values = _broadcast_to_pairs(end_values, nd, "end_values")
    return _pad_linear_ramp(array, pad_width, end_values)

  elif mode in stat_funcs:
    if stat_length is not None:
      stat_length = _broadcast_to_pairs(stat_length, nd, "stat_length")
    return _pad_stats(array, pad_width, stat_length, stat_funcs[str(mode)])

  elif mode == "empty":
    return _pad_empty(array, pad_width)

  else:
    # Raised explicitly instead of `assert False`: assert statements are
    # removed under `python -O`, which would turn this branch into a silent
    # `return None`.
    raise AssertionError("Should not be reached since pad already handled unsupported and "
                         "not implemented modes")
@util.implements(np.pad, lax_description="""\
Unlike numpy, JAX "function" mode's argument (which is another function) should return
the modified array. This is because Jax arrays are immutable.
(In numpy, "function" mode's argument should modify a rank 1 array in-place.)
""")
def pad(array: ArrayLike, pad_width: PadValueLike[int | Array | np.ndarray],
        mode: str | Callable[..., Any] = "constant", **kwargs) -> Array:
  util.check_arraylike("pad", array)
  pad_width = _broadcast_to_pairs(pad_width, ndim(array), "pad_width")
  if pad_width and any(not (core.is_dim(before) and core.is_dim(after))
                       for before, after in pad_width):
    raise TypeError('`pad_width` must be of integral type.')

  if callable(mode):
    # User-supplied padding function: handled outside the jitted dispatcher.
    return _pad_func(asarray(array), pad_width, mode, **kwargs)

  # Keyword arguments each string mode accepts.
  kwargs_by_mode = {
      'empty': [],
      'edge': [],
      'wrap': [],
      'constant': ['constant_values'],
      'linear_ramp': ['end_values'],
      'maximum': ['stat_length'],
      'mean': ['stat_length'],
      'median': ['stat_length'],
      'minimum': ['stat_length'],
      'reflect': ['reflect_type'],
      'symmetric': ['reflect_type'],
  }
  try:
    accepted = kwargs_by_mode[mode]  # type: ignore[call-overload]
  except KeyError:
    msg = "Unimplemented padding mode '{}' for np.pad."
    raise NotImplementedError(msg.format(mode))
  unsupported_kwargs = set(kwargs) - set(accepted)
  if unsupported_kwargs:
    raise ValueError("unsupported keyword arguments for mode '{}': {}"
                     .format(mode, unsupported_kwargs))

  # Fill in the default for every option that was not supplied.
  option_defaults = (
      ('constant_values', 0),
      ('stat_length', None),
      ('end_values', 0),
      ('reflect_type', "even"),
  )
  constant_values, stat_length, end_values, reflect_type = (
      kwargs.get(key, fallback) for key, fallback in option_defaults)

  return _pad(array, pad_width, mode, constant_values, stat_length, end_values, reflect_type)
### Array-creation functions
@util.implements(np.stack, skip_params=['out'])
def stack(arrays: np.ndarray | Array | Sequence[ArrayLike],
          axis: int = 0, out: None = None, dtype: DTypeLike | None = None) -> Array:
  if len(arrays) == 0:
    raise ValueError("Need at least one array to stack.")
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.stack is not supported.")

  if isinstance(arrays, (np.ndarray, Array)):
    # A single array is treated as a stack along its leading axis: insert
    # the new axis into every slice at once, then merge along it.
    axis = _canonicalize_axis(axis, arrays.ndim)
    expanded = expand_dims(arrays, axis + 1)
    return concatenate(expanded, axis=axis, dtype=dtype)

  util.check_arraylike("stack", *arrays)
  reference_shape = shape(arrays[0])
  axis = _canonicalize_axis(axis, len(reference_shape) + 1)
  if any(shape(a) != reference_shape for a in arrays):
    raise ValueError("All input arrays must have the same shape.")
  expanded_arrays = [expand_dims(a, axis) for a in arrays]
  return concatenate(expanded_arrays, axis=axis, dtype=dtype)
@util.implements(np.tile)
def tile(A: ArrayLike, reps: DimSize | Sequence[DimSize]) -> Array:
  util.check_arraylike("tile", A)
  # A non-iterable `reps` is treated as a one-element tuple.
  try:
    iter(reps)  # type: ignore[arg-type]
  except TypeError:
    reps_tup: tuple[DimSize, ...] = (reps,)
  else:
    reps_tup = tuple(reps)  # type: ignore[assignment,arg-type]
  reps_tup = tuple(rep if not core.is_constant_dim(rep) else operator.index(rep)
                   for rep in reps_tup)

  # Left-pad the shorter of (shape, reps) with ones so both have equal length.
  A_shape = (1,) * (len(reps_tup) - ndim(A)) + shape(A)
  reps_tup = (1,) * (len(A_shape) - len(reps_tup)) + reps_tup

  # Give every dimension a leading unit axis, broadcast that axis to the
  # repeat count, then collapse each (rep, dim) pair into one dimension.
  unit_interleaved = []
  reps_interleaved = []
  for rep, dim in zip(reps_tup, A_shape):
    unit_interleaved += [1, dim]
    reps_interleaved += [rep, dim]
  result = broadcast_to(reshape(A, unit_interleaved), reps_interleaved)
  return reshape(result, tuple(np.multiply(A_shape, reps_tup)))
def _concatenate_array(arr: ArrayLike, axis: int | None,
                       dtype: DTypeLike | None = None) -> Array:
  """Concatenate the items along the leading axis of a single array.

  Fast path for concatenation when the input is an ndarray rather than a list:
  the result is produced with one ``lax.reshape`` call.
  """
  arr = asarray(arr, dtype=dtype)
  if arr.ndim == 0 or arr.shape[0] == 0:
    raise ValueError("Need at least one array to concatenate.")
  if axis is None:
    return lax.reshape(arr, (arr.size,))
  if arr.ndim == 1:
    raise ValueError("Zero-dimensional arrays cannot be concatenated.")
  axis = _canonicalize_axis(axis, arr.ndim - 1)

  num_items, *item_shape = arr.shape
  # The item axis is folded into item axis `axis`.
  out_shape = (*item_shape[:axis],
               num_items * item_shape[axis],
               *item_shape[axis + 1:])
  # Visit input dimensions so that the item axis comes right before the
  # dimension it is merged into.
  dimensions = list(range(1, axis + 1)) + [0] + list(range(axis + 1, arr.ndim))
  return lax.reshape(arr, out_shape, dimensions)
@util.implements(np.concatenate)
def concatenate(arrays: np.ndarray | Array | Sequence[ArrayLike],
                axis: int | None = 0, dtype: DTypeLike | None = None) -> Array:
  # Join a sequence of arrays along an existing axis.
  if isinstance(arrays, (np.ndarray, Array)):
    return _concatenate_array(arrays, axis, dtype=dtype)
  util.check_arraylike("concatenate", *arrays)
  if len(arrays) == 0:
    raise ValueError("Need at least one array to concatenate.")
  if axis is None:
    # axis=None means: flatten every input, then join end to end.
    return concatenate(list(map(ravel, arrays)), axis=0, dtype=dtype)
  if ndim(arrays[0]) == 0:
    raise ValueError("Zero-dimensional arrays cannot be concatenated.")
  axis = _canonicalize_axis(axis, ndim(arrays[0]))
  if dtype is None:
    operands = util.promote_dtypes(*arrays)
  else:
    operands = [asarray(arr, dtype=dtype) for arr in arrays]
  # lax.concatenate can be slow to compile for wide concatenations, so form a
  # tree of concatenations as a workaround especially for op-by-op mode.
  # (https://github.com/google/jax/issues/653).
  fan_in = 16
  while len(operands) > 1:
    operands = [lax.concatenate(operands[lo:lo + fan_in], axis)
                for lo in range(0, len(operands), fan_in)]
  return operands[0]
@util.implements(getattr(np, "concat", None))
def concat(arrays: Sequence[ArrayLike], /, *, axis: int | None = 0) -> Array:
  # Validates the inputs, then delegates to jax.numpy.concatenate.
  util.check_arraylike("concat", *arrays)
  joined = jax.numpy.concatenate(arrays, axis=axis)
  return joined
@util.implements(np.vstack)
def vstack(tup: np.ndarray | Array | Sequence[ArrayLike],
           dtype: DTypeLike | None = None) -> Array:
  # Promote every item to at least 2D and join along axis 0.
  if isinstance(tup, (np.ndarray, Array)):
    # Array input: promote each item along the leading axis with vmap.
    return concatenate(jax.vmap(atleast_2d)(tup), axis=0, dtype=dtype)
  # TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error.
  util.check_arraylike("vstack", *tup, emit_warning=True)
  rows = [atleast_2d(item) for item in tup]
  return concatenate(rows, axis=0, dtype=dtype)
@util.implements(np.hstack)
def hstack(tup: np.ndarray | Array | Sequence[ArrayLike],
           dtype: DTypeLike | None = None) -> Array:
  # Promote every item to at least 1D, then join along axis 0 when the items
  # are 1D and along axis 1 otherwise.
  #
  # Raises ValueError when given an empty sequence.
  arrs: Array | list[Array]
  if isinstance(tup, (np.ndarray, Array)):
    # Array input: promote each item along the leading axis with vmap; the
    # item rank is the promoted array's rank minus that leading axis.
    arrs = jax.vmap(atleast_1d)(tup)
    arr0_ndim = arrs.ndim - 1
  else:
    # TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error.
    util.check_arraylike("hstack", *tup, emit_warning=True)
    arrs = [atleast_1d(m) for m in tup]
    if not arrs:
      # Without this guard `arrs[0]` fails with an IndexError; report the
      # empty input with the same ValueError that concatenate() raises.
      raise ValueError("Need at least one array to concatenate.")
    arr0_ndim = arrs[0].ndim
  return concatenate(arrs, axis=0 if arr0_ndim == 1 else 1, dtype=dtype)
@util.implements(np.dstack)
def dstack(tup: np.ndarray | Array | Sequence[ArrayLike],
           dtype: DTypeLike | None = None) -> Array:
  # Promote every item to at least 3D and join along axis 2.
  if isinstance(tup, (np.ndarray, Array)):
    # Array input: promote each item along the leading axis with vmap.
    return concatenate(jax.vmap(atleast_3d)(tup), axis=2, dtype=dtype)
  # TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error.
  util.check_arraylike("dstack", *tup, emit_warning=True)
  slabs = [atleast_3d(item) for item in tup]
  return concatenate(slabs, axis=2, dtype=dtype)
@util.implements(np.column_stack)
def column_stack(tup: np.ndarray | Array | Sequence[ArrayLike]) -> Array:
  # Turn items with fewer than two dimensions into columns, then join along axis 1.
  arrs: Array | list[Array] | np.ndarray
  if isinstance(tup, (np.ndarray, Array)):
    if tup.ndim < 3:
      # Items are below 2D: make each a column via a vmapped transpose.
      arrs = jax.vmap(lambda item: atleast_2d(item).T)(tup)
    else:
      arrs = tup
  else:
    # TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error.
    util.check_arraylike("column_stack", *tup, emit_warning=True)
    arrs = []
    for item in map(asarray, tup):
      arrs.append(atleast_2d(item).T if item.ndim < 2 else item)
  return concatenate(arrs, 1)
@util.implements(np.choose, skip_params=['out'])
def choose(a: ArrayLike, choices: Sequence[ArrayLike],
           out: None = None, mode: str = 'raise') -> Array:
  # Build an array by picking, element-wise, from `choices` using the integer
  # index array `a`.
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.choose is not supported.")
  util.check_arraylike('choose', a, *choices)
  if not issubdtype(_dtype(a), integer):
    raise ValueError("`a` array must be integer typed")
  num_choices = len(choices)

  if mode == 'raise':
    # Range checking inspects concrete values, so `a` must not be traced.
    index: Array = core.concrete_or_error(asarray, a,
      "The error occurred because jnp.choose was jit-compiled"
      " with mode='raise'. Use mode='wrap' or mode='clip' instead.")
    out_of_range = (index < 0) | (index >= num_choices)
    if reductions.any(out_of_range):
      raise ValueError("invalid entry in choice array")
  elif mode == 'wrap':
    index = asarray(a) % num_choices
  elif mode == 'clip':
    index = clip(a, 0, num_choices - 1)
  else:
    raise ValueError(f"mode={mode!r} not understood. Must be 'raise', 'wrap', or 'clip'")

  index, *broadcast_choices = broadcast_arrays(index, *choices)
  # Select the choice with `index`, and the position within it with sparse grids.
  selector = (index,) + indices(index.shape, sparse=True)
  return array(broadcast_choices)[selector]
def _atleast_nd(x: ArrayLike, n: int) -> Array:
  """Return `x` with leading unit axes added until it has at least `n` dimensions."""
  rank = ndim(x)
  if rank < n:
    return lax.broadcast(x, (1,) * (n - rank))
  return asarray(x)

def _block(xs: ArrayLike | list[ArrayLike]) -> tuple[Array, int]:
  """Recursively assemble nested lists of blocks.

  Returns the assembled array together with a depth counter: 1 for a leaf, and
  one more than the children's common depth for a list.
  """
  if isinstance(xs, tuple):
    raise ValueError("jax.numpy.block does not allow tuples, got {}"
                     .format(xs))
  if not isinstance(xs, list):
    # Leaf: a single array-like.
    return asarray(xs), 1
  if len(xs) == 0:
    raise ValueError("jax.numpy.block does not allow empty list arguments")

  blocks, depths = unzip2([_block(x) for x in xs])
  child_depth = depths[0]
  if any(d != child_depth for d in depths[1:]):
    raise ValueError("Mismatched list depths in jax.numpy.block")
  # Bring every child to a common rank before joining.
  rank = max(child_depth, max(ndim(b) for b in blocks))
  blocks = tuple(_atleast_nd(b, rank) for b in blocks)
  return concatenate(blocks, axis=-child_depth), child_depth + 1
@util.implements(np.block)
@jit
def block(arrays: ArrayLike | list[ArrayLike]) -> Array:
  # _block returns (array, depth); only the assembled array is exposed.
  return _block(arrays)[0]
@overload
def atleast_1d() -> list[Array]:
  ...
@overload
def atleast_1d(x: ArrayLike, /) -> Array:
  ...
@overload
def atleast_1d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> list[Array]:
  ...
@util.implements(np.atleast_1d, update_doc=False, lax_description=_ARRAY_VIEW_DOC)
@jit
def atleast_1d(*arys: ArrayLike) -> Array | list[Array]:
  # One input yields a single array; any other number of inputs yields a list.
  # TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error.
  util.check_arraylike("atleast_1d", *arys, emit_warning=True)
  promoted = [array(arr, copy=False, ndmin=1) for arr in arys]
  return promoted[0] if len(promoted) == 1 else promoted
@overload
def atleast_2d() -> list[Array]:
  ...
@overload
def atleast_2d(x: ArrayLike, /) -> Array:
  ...
@overload
def atleast_2d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> list[Array]:
  ...
@util.implements(np.atleast_2d, update_doc=False, lax_description=_ARRAY_VIEW_DOC)
@jit
def atleast_2d(*arys: ArrayLike) -> Array | list[Array]:
  # One input yields a single array; any other number of inputs yields a list.
  # TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error.
  util.check_arraylike("atleast_2d", *arys, emit_warning=True)
  promoted = [array(arr, copy=False, ndmin=2) for arr in arys]
  return promoted[0] if len(promoted) == 1 else promoted
@overload
def atleast_3d() -> list[Array]:
  ...
@overload
def atleast_3d(x: ArrayLike, /) -> Array:
  ...
@overload
def atleast_3d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> list[Array]:
  ...
@util.implements(np.atleast_3d, update_doc=False, lax_description=_ARRAY_VIEW_DOC)
@jit
def atleast_3d(*arys: ArrayLike) -> Array | list[Array]:
  # One input yields a single array; any other number of inputs yields a list.
  # TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error.
  util.check_arraylike("atleast_3d", *arys, emit_warning=True)
  if len(arys) != 1:
    return [atleast_3d(arr) for arr in arys]

  arr = asarray(arys[0])
  # Unit axes to insert, keyed by the input rank; ranks of 3 or more are untouched.
  axes_by_rank = {0: (0, 1, 2), 1: (0, 2), 2: (2,)}
  new_axes = axes_by_rank.get(arr.ndim)
  if new_axes is not None:
    arr = lax.expand_dims(arr, dimensions=new_axes)
  return arr
def _supports_buffer_protocol(obj):
  """Return True if ``memoryview(obj)`` succeeds, False if it raises TypeError."""
  try:
    memoryview(obj)
  except TypeError:
    return False
  return True


# Shared note appended to the documentation of the array-creation wrappers.
_ARRAY_DOC = """
This function will create arrays on JAX's default device. For control of the
device placement of data, see :func:`jax.device_put`. More information is
available in the JAX FAQ at :ref:`faq-data-placement` (full FAQ at
https://jax.readthedocs.io/en/latest/faq.html).
"""

# Registers the "array-none" deprecation for this module.
deprecations.register(__name__, "array-none")
@util.implements(np.array, lax_description=_ARRAY_DOC)
def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
          order: str | None = "K", ndmin: int = 0) -> Array:
  # Create a JAX array from a scalar, array, nested list/tuple, or an object
  # exposing __jax_array__, __cuda_array_interface__ or the buffer protocol.
  #
  # Raises NotImplementedError for any order other than None or "K", and
  # TypeError for input types none of the branches below can handle.
  if order is not None and order != "K":
    raise NotImplementedError("Only implemented for order='K'")

  # check if the given dtype is compatible with JAX
  dtypes.check_user_dtype_supported(dtype, "array")

  # Here we make a judgment call: we only return a weakly-typed array when the
  # input object itself is weakly typed. That ensures asarray(x) is a no-op
  # whenever x is weak, but avoids introducing weak types with something like
  # array([1, 2, 3])
  weak_type = dtype is None and dtypes.is_weakly_typed(object)

  # Use device_put to avoid a copy for ndarray inputs.
  if (not copy and isinstance(object, np.ndarray) and
      (dtype is None or dtype == object.dtype) and (ndmin <= object.ndim)):
    # Keep the output uncommitted.
    return jax.device_put(object)

  # For Python scalar literals, call coerce_to_array to catch any overflow
  # errors. We don't use dtypes.is_python_scalar because we don't want this
  # triggering for traced values. We do this here because it matters whether or
  # not dtype is None. We don't assign the result because we want the raw object
  # to be used for type inference below.
  if isinstance(object, (bool, int, float, complex)):
    _ = dtypes.coerce_to_array(object, dtype)
  elif not isinstance(object, Array):
    # Check if object supports any of the data exchange protocols
    # (except dlpack, see data-apis/array-api#301). If it does,
    # consume the object as jax array and continue (but not return) so
    # that other array() arguments get processed against the input
    # object.
    #
    # Notice that data exchange protocols define dtype in the
    # corresponding data structures and it may not be available as
    # object.dtype. So, we'll resolve the protocols here before
    # evaluating object.dtype.
    if hasattr(object, '__jax_array__'):
      object = object.__jax_array__()
    elif hasattr(object, '__cuda_array_interface__'):
      # The conversion below is only attempted when the installed
      # xla_extension_version is at least 237.
      if xla_extension_version >= 237:
        cai = object.__cuda_array_interface__
        backend = xla_bridge.get_backend("cuda")
        object = xc._xla.cuda_array_interface_to_buffer(cai, backend)

  # Resolve __jax_array__ on every leaf of a (possibly nested) container.
  object = tree_map(lambda leaf: leaf.__jax_array__()
                    if hasattr(leaf, "__jax_array__") else leaf, object)
  # Flatten with None treated as a leaf so that None entries can be detected.
  leaves = tree_leaves(object, is_leaf=lambda x: x is None)
  if any(leaf is None for leaf in leaves):
    # Added Nov 16 2023
    if deprecations.is_accelerated(__name__, "array-none"):
      raise TypeError("None is not a valid value for jnp.array")
    warnings.warn(
      "None encountered in jnp.array(); this is currently treated as NaN. "
      "In the future this will result in an error.",
      FutureWarning, stacklevel=2)
    # Re-flatten with the default leaf rule for the dtype inference below.
    leaves = tree_leaves(object)
  if dtype is None:
    # Use lattice_result_type rather than result_type to avoid canonicalization.
    # Otherwise, weakly-typed inputs would have their dtypes canonicalized.
    try:
      dtype = dtypes._lattice_result_type(*leaves)[0] if leaves else dtypes.float_
    except TypeError:
      # This happens if, e.g. one of the entries is a memoryview object.
      # This is rare, so we only handle it if the normal path fails.
      leaves = [_convert_to_array_if_dtype_fails(leaf) for leaf in leaves]
      dtype = dtypes._lattice_result_type(*leaves)[0]

  if not weak_type:
    dtype = dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True)  # type: ignore[assignment]

  out: ArrayLike

  # Materialize the data; the branch is chosen by what the input contains.
  if all(not isinstance(leaf, Array) for leaf in leaves):
    # TODO(jakevdp): falling back to numpy here fails to overflow for lists
    # containing large integers; see discussion in
    # https://github.com/google/jax/pull/6047. More correct would be to call
    # coerce_to_array on each leaf, but this may have performance implications.
    out = np.array(object, dtype=dtype, ndmin=ndmin, copy=False)  # type: ignore[arg-type]
  elif isinstance(object, Array):
    assert object.aval is not None
    out = _array_copy(object) if copy else object
  elif isinstance(object, (list, tuple)):
    # A container holding at least one JAX Array: convert items and stack them.
    if object:
      out = stack([asarray(elt, dtype=dtype) for elt in object])
    else:
      out = np.array([], dtype=dtype)  # type: ignore[arg-type]
  elif _supports_buffer_protocol(object):
    out = np.array(memoryview(object), copy=copy)
  else:
    raise TypeError(f"Unexpected input type for array: {type(object)}")

  # Cast to the resolved dtype and weak-type flag.
  out_array: Array = lax_internal._convert_element_type(
      out, dtype, weak_type=weak_type)
  # Prepend unit axes until the result has at least `ndmin` dimensions.
  if ndmin > ndim(out_array):
    out_array = lax.expand_dims(out_array, range(ndmin - ndim(out_array)))
  return out_array
def _convert_to_array_if_dtype_fails(x: ArrayLike) -> ArrayLike:
  """Return `x` unchanged if ``dtypes.dtype(x)`` succeeds, else ``np.asarray(x)``."""
  try:
    dtypes.dtype(x)
  except TypeError:
    # No dtype can be determined for x directly; let NumPy convert it.
    return np.asarray(x)
  return x
@util.implements(getattr(np, "astype", None), lax_description="""
This is implemented via :func:`jax.lax.convert_element_type`, which may
have slightly different behavior than :func:`numpy.astype` in some cases.
In particular, the details of float-to-int and int-to-float casts are
implementation dependent.
""")
def astype(x: ArrayLike, dtype: DTypeLike | None, /, *, copy: bool = True) -> Array:
  # Cast `x` to `dtype`; dtype=None selects the canonicalized default float type.
  del copy  # unused in JAX
  target = dtypes.canonicalize_dtype(float_) if dtype is None else dtype
  dtypes.check_user_dtype_supported(target, "astype")
  return lax.convert_element_type(x, target)
@util.implements(np.asarray, lax_description=_ARRAY_DOC)
def asarray(a: Any, dtype: DTypeLike | None = None, order: str | None = None,
            *, copy: bool | None = None) -> Array:
  # Convert `a` to a JAX array by delegating to array().
  #
  # For copy=False, the array API specifies that we raise a ValueError if the input supports
  # the buffer protocol but a copy is required. Since array() supports the buffer protocol
  # via numpy, this is only the case when the default device is not 'cpu'
  copy_forbidden_but_required = (
      copy is False
      and not isinstance(a, Array)
      and jax.default_backend() != 'cpu'
      and _supports_buffer_protocol(a))
  if copy_forbidden_but_required:
    raise ValueError(f"jnp.asarray: cannot convert object of type {type(a)} to JAX Array "
                     f"on backend={jax.default_backend()!r} with copy=False. "
                     "Consider using copy=None or copy=True instead.")

  dtypes.check_user_dtype_supported(dtype, "asarray")
  if dtype is not None:
    dtype = dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True)  # type: ignore[assignment]
  # copy=None is forwarded as copy=False.
  force_copy = bool(copy)
  return array(a, dtype=dtype, copy=force_copy, order=order)  # type: ignore
@util.implements(np.copy, lax_description=_ARRAY_DOC)
def copy(a: ArrayLike, order: str | None = None) -> Array:
  # Validates the input, then requests a copy from array().
  util.check_arraylike("copy", a)
  duplicate = array(a, copy=True, order=order)
  return duplicate
@util.implements(np.zeros_like)
def zeros_like(a: ArrayLike | DuckTypedArray,
               dtype: DTypeLike | None = None,
               shape: Any = None, *,
               device: xc.Device | Sharding | None = None) -> Array:
  # Objects exposing both `dtype` and `shape` are accepted without the
  # array-like check (support duck typing).
  is_duck_typed = hasattr(a, 'dtype') and hasattr(a, 'shape')
  if not is_duck_typed:
    util.check_arraylike("zeros_like", a)
  dtypes.check_user_dtype_supported(dtype, "zeros_like")
  if shape is not None:
    shape = canonicalize_shape(shape)
  sharding = _normalize_to_sharding(device)
  return lax.full_like(a, 0, dtype, shape, sharding=sharding)
@util.implements(np.ones_like)
def ones_like(a: ArrayLike | DuckTypedArray,
              dtype: DTypeLike | None = None,
              shape: Any = None, *,
              device: xc.Device | Sharding | None = None) -> Array:
  # Objects exposing both `dtype` and `shape` are accepted without the
  # array-like check (support duck typing).
  is_duck_typed = hasattr(a, 'dtype') and hasattr(a, 'shape')
  if not is_duck_typed:
    util.check_arraylike("ones_like", a)
  dtypes.check_user_dtype_supported(dtype, "ones_like")
  if shape is not None:
    shape = canonicalize_shape(shape)
  sharding = _normalize_to_sharding(device)
  return lax.full_like(a, 1, dtype, shape, sharding=sharding)
@util.implements(np.empty_like, lax_description="""\
Because XLA cannot create uninitialized arrays, the JAX version will
return an array initialized with zeros.""")
def empty_like(prototype: ArrayLike | DuckTypedArray,
               dtype: DTypeLike | None = None,
               shape: Any = None, *,
               device: xc.Device | Sharding | None = None) -> Array:
  # Objects exposing both `dtype` and `shape` are accepted without the
  # array-like check (support duck typing).
  is_duck_typed = hasattr(prototype, 'dtype') and hasattr(prototype, 'shape')
  if not is_duck_typed:
    util.check_arraylike("empty_like", prototype)
  dtypes.check_user_dtype_supported(dtype, "empty_like")
  # The result is zero-filled: creation is delegated to zeros_like.
  return zeros_like(prototype, dtype=dtype, shape=shape, device=device)
def _maybe_device_put(arr: Array, device: xc.Device | Sharding | None) -> Array:
  """Return `arr` as-is when `device` is None, otherwise ``jax.device_put(arr, device)``."""
  if device is None:
    return arr
  return jax.device_put(arr, device)

def _normalize_to_sharding(device: xc.Device | Sharding | None) -> Sharding | None:
  """Wrap an ``xc.Device`` in a SingleDeviceSharding; return anything else unchanged."""
  return SingleDeviceSharding(device) if isinstance(device, xc.Device) else device
@util.implements(np.full)
def full(shape: Any, fill_value: ArrayLike,
         dtype: DTypeLike | None = None, *,
         device: xc.Device | Sharding | None = None) -> Array:
  # Create an array of the given shape filled with `fill_value`.
  dtypes.check_user_dtype_supported(dtype, "full")
  util.check_arraylike("full", fill_value)

  if ndim(fill_value) != 0:
    # Non-scalar fill value: broadcast it to the requested shape.
    filled = broadcast_to(asarray(fill_value, dtype=dtype), shape)
    return _maybe_device_put(filled, device)

  # Scalar fill value.
  shape = canonicalize_shape(shape)
  return lax.full(shape, fill_value, dtype, sharding=_normalize_to_sharding(device))
@util.implements(np.full_like)
def full_like(a: ArrayLike | DuckTypedArray,
              fill_value: ArrayLike, dtype: DTypeLike | None = None,
              shape: Any = None, *,
              device: xc.Device | Sharding | None = None) -> Array:
  # Create an array filled with `fill_value`, taking shape and dtype from `a`
  # unless they are given explicitly.
  #
  # For objects exposing `dtype` and `shape` (support duck typing), the
  # constant 0 is passed to the array-like check in place of `a`.
  is_duck_typed = hasattr(a, 'dtype') and hasattr(a, 'shape')
  util.check_arraylike("full_like", 0 if is_duck_typed else a, fill_value)
  dtypes.check_user_dtype_supported(dtype, "full_like")
  if shape is not None:
    shape = canonicalize_shape(shape)

  if ndim(fill_value) == 0:
    return lax.full_like(a, fill_value, dtype, shape, sharding=_normalize_to_sharding(device))

  # Non-scalar fill value: resolve the missing pieces from `a`, then broadcast.
  if shape is None:
    shape = np.shape(a)  # type: ignore[arg-type]
  if dtype is None:
    dtype = result_type(a)  # type: ignore[arg-type]
  filled = broadcast_to(asarray(fill_value, dtype=dtype), shape)
  return _maybe_device_put(filled, device)
@util.implements(np.zeros)
def zeros(shape: Any, dtype: DTypeLike | None = None, *,
          device: xc.Device | Sharding | None = None) -> Array:
  # Create a zero-filled array of the given shape.
  if isinstance(shape, types.GeneratorType):
    raise TypeError("expected sequence object with len >= 0 or a single integer")
  # Catch zeros(2, 3) written where zeros((2, 3)) was meant.
  hint = _check_forgot_shape_tuple("zeros", shape, dtype)
  if hint:
    raise TypeError(hint)
  dtypes.check_user_dtype_supported(dtype, "zeros")
  shape = canonicalize_shape(shape)
  return lax.full(shape, 0, _jnp_dtype(dtype), sharding=_normalize_to_sharding(device))
@util.implements(np.ones)
def ones(shape: Any, dtype: DTypeLike | None = None, *,
         device: xc.Device | Sharding | None = None) -> Array:
  # Create a one-filled array of the given shape.
  if isinstance(shape, types.GeneratorType):
    raise TypeError("expected sequence object with len >= 0 or a single integer")
  # Catch ones(2, 3) written where ones((2, 3)) was meant.
  hint = _check_forgot_shape_tuple("ones", shape, dtype)
  if hint:
    raise TypeError(hint)
  shape = canonicalize_shape(shape)
  dtypes.check_user_dtype_supported(dtype, "ones")
  return lax.full(shape, 1, _jnp_dtype(dtype), sharding=_normalize_to_sharding(device))
@util.implements(np.empty, lax_description="""\
Because XLA cannot create uninitialized arrays, the JAX version will
return an array initialized with zeros.""")
def empty(shape: Any, dtype: DTypeLike | None = None, *,
          device: xc.Device | Sharding | None = None) -> Array:
  # Catch empty(2, 3) written where empty((2, 3)) was meant.
  hint = _check_forgot_shape_tuple("empty", shape, dtype)
  if hint:
    raise TypeError(hint)
  dtypes.check_user_dtype_supported(dtype, "empty")
  # The result is zero-filled: creation is delegated to zeros.
  return zeros(shape, dtype, device=device)
def _check_forgot_shape_tuple(name, shape, dtype) -> str | None: # type: ignore if isinstance(dtype, int) and isinstance(shape, int): return (f"Cannot interpret '{dtype}' as a data type." f"\n\nDid you accidentally write " f"`jax.numpy.{name}({shape}, {dtype})` " f"when you meant `jax.numpy.{name}(({shape}, {dtype}))`, i.e. " "with a single tuple argument for the shape?")
@util.implements(np.array_equal)
def array_equal(a1: ArrayLike, a2: ArrayLike, equal_nan: bool = False) -> Array:
  # True when both inputs have the same shape and all elements compare equal.
  try:
    lhs, rhs = asarray(a1), asarray(a2)
  except Exception as err:
    # TODO(jakevdp): Deprecated 2023-11-23; change to error.
    warnings.warn("Inputs to array_equal() cannot be coerced to array. "
                  "Returning False; in the future this will raise an exception.\n"
                  f"{err!r}", DeprecationWarning, stacklevel=2)
    return bool_(False)
  if shape(lhs) != shape(rhs):
    return bool_(False)
  matches = asarray(lhs == rhs)
  if equal_nan:
    # Positions where both sides are NaN count as equal.
    both_nan = ufuncs.logical_and(ufuncs.isnan(lhs), ufuncs.isnan(rhs))
    matches = ufuncs.logical_or(matches, both_nan)
  return reductions.all(matches)
@util.implements(np.array_equiv)
def array_equiv(a1: ArrayLike, a2: ArrayLike) -> Array:
  # True when the inputs are broadcast-compatible and all elements compare equal.
  try:
    lhs, rhs = asarray(a1), asarray(a2)
  except Exception as err:
    # TODO(jakevdp): Deprecated 2023-11-23; change to error.
    warnings.warn("Inputs to array_equiv() cannot be coerced to array. "
                  "Returning False; in the future this will raise an exception.\n"
                  f"{err!r}", DeprecationWarning, stacklevel=2)
    return bool_(False)
  try:
    matches = ufuncs.equal(lhs, rhs)
  except ValueError:
    # shapes are not broadcastable
    return bool_(False)
  return reductions.all(matches)
# General np.from* style functions mostly delegate to numpy.
@util.implements(np.frombuffer)
def frombuffer(buffer: bytes | Any, dtype: DTypeLike = float,
               count: int = -1, offset: int = 0) -> Array:
  # NumPy interprets the buffer; the result is then converted to a JAX array.
  host_array = np.frombuffer(buffer=buffer, dtype=dtype, count=count, offset=offset)
  return asarray(host_array)
def fromfile(*args, **kwargs):
  """Unimplemented JAX wrapper for jnp.fromfile.

  Deliberately left unimplemented: reading from a file may be non-pure, which
  makes it unsafe under JIT and other JAX transformations. Use
  ``jnp.asarray(np.fromfile(...))`` instead, keeping in mind that
  ``np.fromfile`` consumes the file object as a side effect, which matters
  inside jax transformations. For background see `Common Gotchas: Pure Functions
  <https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions>`_.

  Raises:
    NotImplementedError: always, whatever the arguments.
  """
  message = (
    "jnp.fromfile() is not implemented because it may be non-pure and thus unsafe for use "
    "with JIT and other JAX transformations. Consider using jnp.asarray(np.fromfile(...)) "
    "instead, although care should be taken if np.fromfile is used within a jax transformations "
    "because of its potential side-effect of consuming the file object; for more information see "
    "https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions")
  raise NotImplementedError(message)
def fromiter(*args, **kwargs):
  """Unimplemented JAX wrapper for jnp.fromiter.

  Deliberately left unimplemented: consuming an iterable may be non-pure, which
  makes it unsafe under JIT and other JAX transformations. Use
  ``jnp.asarray(np.fromiter(...))`` instead, keeping in mind that
  ``np.fromiter`` consumes the iterable object as a side effect, which matters
  inside jax transformations. For background see `Common Gotchas: Pure Functions
  <https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions>`_.

  Raises:
    NotImplementedError: always, whatever the arguments.
  """
  message = (
    "jnp.fromiter() is not implemented because it may be non-pure and thus unsafe for use "
    "with JIT and other JAX transformations. Consider using jnp.asarray(np.fromiter(...)) "
    "instead, although care should be taken if np.fromiter is used within a jax transformations "
    "because of its potential side-effect of consuming the iterable object; for more information see "
    "https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions")
  raise NotImplementedError(message)
@util.implements(getattr(np, "from_dlpack", None), lax_description="""
.. note::

  While JAX arrays are always immutable, dlpack buffers cannot be marked as
  immutable, and it is possible for processes external to JAX to mutate them
  in-place. If a jax Array is constructed from a dlpack buffer and the buffer
  is later modified in-place, it may lead to undefined behavior when using
  the associated JAX array.
""")
def from_dlpack(x: Any) -> Array:
  # Function-scope import, aliased so it cannot be confused with this wrapper.
  from jax.dlpack import from_dlpack as _dlpack_to_array  # pylint: disable=g-import-not-at-top
  capsule = x.__dlpack__()
  return _dlpack_to_array(capsule)
@util.implements(np.fromfunction)
def fromfunction(function: Callable[..., Array], shape: Any,
                 *, dtype: DTypeLike = float, **kwargs) -> Array:
  # Evaluate `function` over index coordinates, one argument per dimension.
  shape = core.canonicalize_shape(shape, context="shape argument of jnp.fromfunction()")
  rank = len(shape)
  # Wrap the function in one vmap per dimension; each layer maps exactly one
  # positional argument, starting with the last argument.
  for mapped_arg in reversed(range(rank)):
    in_axes = tuple(0 if mapped_arg == j else None for j in range(rank))
    function = jax.vmap(function, in_axes=in_axes)
  coordinates = (arange(size, dtype=dtype) for size in shape)
  return function(*coordinates, **kwargs)
@util.implements(np.fromstring)
def fromstring(string: str, dtype: DTypeLike = float, count: int = -1, *, sep: str) -> Array:
  # NumPy parses the string; the result is then converted to a JAX array.
  host_array = np.fromstring(string=string, dtype=dtype, count=count, sep=sep)
  return asarray(host_array)
@util.implements(np.eye)
def eye(N: DimSize, M: DimSize | None = None, k: int = 0,
        dtype: DTypeLike | None = None) -> Array:
  # Create an N x M array with ones on the k-th diagonal; M defaults to N.
  dtypes.check_user_dtype_supported(dtype, "eye")
  N_int = core.canonicalize_dim(N, "'N' argument of jnp.eye()")
  if M is None:
    M_int = N_int
  else:
    M_int = core.canonicalize_dim(M, "'M' argument of jnp.eye()")
  if N_int < 0 or M_int < 0:
    raise ValueError(f"negative dimensions are not allowed, got {N} and {M}")
  diagonal = operator.index(k)
  return lax_internal._eye(_jnp_dtype(dtype), (N_int, M_int), diagonal)
@util.implements(np.identity)
def identity(n: DimSize, dtype: DTypeLike | None = None) -> Array:
  # The square n x n case of eye().
  dtypes.check_user_dtype_supported(dtype, "identity")
  square = eye(n, dtype=dtype)
  return square
@util.implements(np.arange,lax_description= """
.. note::

   Using ``arange`` with the ``step`` argument can lead to precision errors,
   especially with lower-precision data types like ``fp8`` and ``bf16``.
   For more details, see the docstring of :func:`numpy.arange`.
   To avoid precision errors, consider using an expression like
   ``(jnp.arange(-600, 600) * .01).astype(jnp.bfloat16)`` to generate a sequence in a higher precision
   and then convert it to the desired lower precision.
""")
def arange(start: DimSize, stop: DimSize | None = None,
           step: DimSize | None = None, dtype: DTypeLike | None = None) -> Array:
  # Create a 1D array of evenly spaced values.
  #
  # When only `start` is given it acts as the stop value: the result counts up
  # from 0. Raises ValueError if any argument is not a scalar.
  dtypes.check_user_dtype_supported(dtype, "arange")
  if not config.dynamic_shapes.value:
    util.check_arraylike("arange", start)
    # With a single argument, `start` is really the stop value, so the error
    # context names it 'stop'.
    if stop is None and step is None:
      start = core.concrete_or_error(None, start, "It arose in the jnp.arange argument 'stop'")
    else:
      start = core.concrete_or_error(None, start, "It arose in the jnp.arange argument 'start'")
  util.check_arraylike_or_none("arange", None, stop, step)
  stop = core.concrete_or_error(None, stop, "It arose in the jnp.arange argument 'stop'")
  step = core.concrete_or_error(None, step, "It arose in the jnp.arange argument 'step'")
  # Same naming rule as above for the scalar-check error message.
  start_name = "stop" if stop is None and step is None else "start"
  for name, val in [(start_name, start), ("stop", stop), ("step", step)]:
    if val is not None and np.ndim(val) != 0:
      raise ValueError(f"jax.numpy.arange: arguments must be scalars; got {name}={val}")
  if any(core.is_symbolic_dim(v) for v in (start, stop, step)):
    # Some dynamic shapes
    if stop is None and step is None:
      stop = start
      start = 0
      step = 1
    elif stop is not None and step is None:
      step = 1
    return _arange_dynamic(start, stop, step, dtype or dtypes.canonicalize_dtype(np.int64))
  if dtype is None:
    # Infer the dtype from the arguments that were actually supplied.
    dtype = result_type(start, *(x for x in [stop, step] if x is not None))
  dtype = _jnp_dtype(dtype)
  if stop is None and step is None:
    start_dtype = _dtype(start)
    # Unless `start` is an integer or extended dtype, round the length up to
    # an integer (ufuncs.ceil for tracers, np.ceil otherwise).
    if (not dtypes.issubdtype(start_dtype, np.integer) and
        not dtypes.issubdtype(start_dtype, dtypes.extended)):
      ceil_ = ufuncs.ceil if isinstance(start, core.Tracer) else np.ceil
      start = ceil_(start).astype(int)  # type: ignore
    return lax.iota(dtype, start)
  else:
    # arange(0, stop) with no step is produced directly with lax.iota.
    if step is None and start == 0 and stop is not None:
      return lax.iota(dtype, np.ceil(stop).astype(int))
    # Everything else is computed by NumPy and converted to a JAX array.
    return array(np.arange(start, stop=stop, step=step, dtype=dtype))
def _arange_dynamic(
    start: DimSize, stop: DimSize, step: DimSize, dtype: DTypeLike) -> Array:
  """Build an arange whose start, stop or step is a dimension expression.

  The result is computed as ``start + step * iota(size)``, where ``size`` is
  the ceiling of ``distance / |step|`` clamped at zero.

  Raises:
    ValueError: if any argument is not a dimension (per ``core.is_dim``), or
      if ``step`` is zero.
    core.InconclusiveDimensionOperation: if the sign of ``step`` cannot be
      decided statically. The original error is attached as the cause.
  """
  # Here if at least one of start, stop, step are dynamic.
  if any(not core.is_dim(v) for v in (start, stop, step)):
    raise ValueError(
      "In arange with non-constant arguments all of start, stop, and step "
      f"must be either dimension expressions or integers: start={start}, "
      f"stop={stop}, step={step}")
  # Must resolve statically if step is {<0, ==0, >0}
  try:
    if step == 0:
      raise ValueError("arange has step == 0")
    step_gt_0 = (step > 0)
  except core.InconclusiveDimensionOperation as e:
    # Chain explicitly so the traceback shows the comparison failure as the
    # direct cause of this error.
    raise core.InconclusiveDimensionOperation(
        f"In arange with non-constant arguments the step ({step}) must " +
        f"be resolved statically if it is > 0 or < 0.\nDetails: {e}") from e
  # Work with a positive gap and a distance measured in the step's direction.
  gap = step if step_gt_0 else - step
  distance = (stop - start) if step_gt_0 else (start - stop)
  # Ceiling division; max_dim clamps the numerator at zero.
  size = core.max_dim(0, distance + gap - 1) // gap
  return (array(start, dtype=dtype) +
          array(step, dtype=dtype) * lax.iota(dtype, size))

@overload
def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
             endpoint: bool = True, retstep: Literal[False] = False,
             dtype: DTypeLike | None = None,
             axis: int = 0) -> Array: ...
@overload
def linspace(start: ArrayLike, stop: ArrayLike, num: int,
             endpoint: bool, retstep: Literal[True],
             dtype: DTypeLike | None = None,
             axis: int = 0) -> tuple[Array, Array]: ...
@overload
def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
             endpoint: bool = True, *, retstep: Literal[True],
             dtype: DTypeLike | None = None,
             axis: int = 0) -> tuple[Array, Array]: ...
@overload
def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
             endpoint: bool = True, retstep: bool = False,
             dtype: DTypeLike | None = None,
             axis: int = 0) -> Array | tuple[Array, Array]: ...
@util.implements(np.linspace)
def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
             endpoint: bool = True, retstep: bool = False,
             dtype: DTypeLike | None = None,
             axis: int = 0) -> Array | tuple[Array, Array]:
  # `num` and `axis` must be concrete integers; they are static arguments of
  # the jitted implementation.
  num_samples = core.concrete_or_error(operator.index, num, "'num' argument of jnp.linspace")
  out_axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.linspace")
  return _linspace(start, stop, num_samples, endpoint, retstep, dtype, out_axis)
@partial(jit, static_argnames=('num', 'endpoint', 'retstep', 'dtype', 'axis'))
def _linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
              endpoint: bool = True, retstep: bool = False,
              dtype: DTypeLike | None = None,
              axis: int = 0) -> Array | tuple[Array, Array]:
  """Implementation of linspace differentiable in start and stop args.

  Returns the samples converted to ``dtype``, or the pair ``(samples, delta)``
  when ``retstep`` is true. Raises ValueError when ``num`` is negative.
  """
  dtypes.check_user_dtype_supported(dtype, "linspace")
  if num < 0:
    raise ValueError(f"Number of samples, {num}, must be non-negative.")
  util.check_arraylike("linspace", start, stop)

  if dtype is None:
    dtype = dtypes.to_inexact_dtype(result_type(start, stop))
  dtype = _jnp_dtype(dtype)
  # The arithmetic is done in an inexact dtype; the result is converted to
  # `dtype` at the end.
  computation_dtype = dtypes.to_inexact_dtype(dtype)
  start = asarray(start, dtype=computation_dtype)
  stop = asarray(stop, dtype=computation_dtype)

  bounds_shape = list(lax.broadcast_shapes(shape(start), shape(stop)))
  broadcast_start = broadcast_to(start, bounds_shape)
  broadcast_stop = broadcast_to(stop, bounds_shape)
  # Normalize a negative axis relative to the output rank (bounds rank + 1).
  axis = len(bounds_shape) + axis + 1 if axis < 0 else axis
  bounds_shape.insert(axis, 1)
  # Number of intervals between consecutive samples.
  div = (num - 1) if endpoint else num
  if num > 1:
    delta: Array = lax.convert_element_type(stop - start, computation_dtype) / div
    iota_shape = [1,] * len(bounds_shape)
    iota_shape[axis] = div
    # This approach recovers the endpoints with float32 arithmetic,
    # but can lead to rounding errors for integer outputs.
    real_dtype = finfo(computation_dtype).dtype
    # Interpolation weights 0, 1/div, ..., (div-1)/div along the sample axis.
    step = reshape(lax.iota(real_dtype, div), iota_shape) / div
    step = step.astype(computation_dtype)
    out = (reshape(broadcast_start, bounds_shape) * (1 - step) +
      reshape(broadcast_stop, bounds_shape) * step)

    if endpoint:
      # The weights above stop short of 1, so `stop` is appended explicitly.
      out = lax.concatenate([out, lax.expand_dims(broadcast_stop, (axis,))],
                            _canonicalize_axis(axis, out.ndim))

  elif num == 1:
    delta = asarray(nan if endpoint else stop - start, dtype=computation_dtype)
    out = reshape(broadcast_start, bounds_shape)
  else:  # num == 0 degenerate case, match numpy behavior
    empty_shape = list(lax.broadcast_shapes(shape(start), shape(stop)))
    empty_shape.insert(axis, 0)
    delta = asarray(nan, dtype=computation_dtype)
    out = reshape(array([], dtype=dtype), empty_shape)

  # For integer outputs, round down before the final conversion.
  if issubdtype(dtype, integer) and not issubdtype(out.dtype, integer):
    out = lax.floor(out)

  if retstep:
    return lax.convert_element_type(out, dtype), delta
  else:
    return lax.convert_element_type(out, dtype)
@util.implements(np.logspace)
def logspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
             endpoint: bool = True, base: ArrayLike = 10.0,
             dtype: DTypeLike | None = None, axis: int = 0) -> Array:
  # `num` and `axis` must be concrete integers; they are static arguments of
  # the jitted implementation.
  num_samples = core.concrete_or_error(operator.index, num, "'num' argument of jnp.logspace")
  out_axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.logspace")
  return _logspace(start, stop, num_samples, endpoint, base, dtype, out_axis)
@partial(jit, static_argnames=('num', 'endpoint', 'dtype', 'axis'))
def _logspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
              endpoint: bool = True, base: ArrayLike = 10.0,
              dtype: DTypeLike | None = None, axis: int = 0) -> Array:
  """Implementation of logspace differentiable in start and stop args."""
  dtypes.check_user_dtype_supported(dtype, "logspace")
  # Output dtype: the requested one, or the inexact promotion of the bounds.
  out_dtype = _jnp_dtype(
      dtypes.to_inexact_dtype(result_type(start, stop)) if dtype is None
      else dtype)
  # Exponents are always computed in an inexact dtype.
  work_dtype = dtypes.to_inexact_dtype(out_dtype)
  util.check_arraylike("logspace", start, stop)
  lo = asarray(start, dtype=work_dtype)
  hi = asarray(stop, dtype=work_dtype)
  exponents = linspace(lo, hi, num,
                       endpoint=endpoint, retstep=False, dtype=None, axis=axis)
  return lax.convert_element_type(ufuncs.power(base, exponents), out_dtype)
# Public entry point for ``jnp.geomspace``. ``num`` and ``axis`` must be
# concrete integers: they are validated here and then forwarded to the jitted
# ``_geomspace`` implementation, which treats them as static arguments.
@util.implements(np.geomspace)
def geomspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
              endpoint: bool = True, dtype: DTypeLike | None = None,
              axis: int = 0) -> Array:
  static_num = core.concrete_or_error(
      operator.index, num, "'num' argument of jnp.geomspace")
  static_axis = core.concrete_or_error(
      operator.index, axis, "'axis' argument of jnp.geomspace")
  return _geomspace(start, stop, static_num, endpoint, dtype, static_axis)
@partial(jit, static_argnames=('num', 'endpoint', 'dtype', 'axis'))
def _geomspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
               endpoint: bool = True, dtype: DTypeLike | None = None,
               axis: int = 0) -> Array:
  """Implementation of geomspace differentiable in start and stop args."""
  dtypes.check_user_dtype_supported(dtype, "geomspace")
  out_dtype = _jnp_dtype(
      dtypes.to_inexact_dtype(result_type(start, stop)) if dtype is None
      else dtype)
  work_dtype = dtypes.to_inexact_dtype(out_dtype)
  util.check_arraylike("geomspace", start, stop)
  lo = asarray(start, dtype=work_dtype)
  hi = asarray(stop, dtype=work_dtype)
  # follow the numpy geomspace convention for negative and complex endpoints
  lo_term = 1 - ufuncs.sign(ufuncs.real(lo))
  hi_term = 1 - ufuncs.sign(ufuncs.real(hi))
  signflip = (1 - (lo_term * hi_term) // 2).astype(work_dtype)
  # Samples are generated along axis 0 in log10 space and moved afterwards.
  samples = signflip * logspace(ufuncs.log10(signflip * lo),
                                ufuncs.log10(signflip * hi),
                                num, endpoint=endpoint, base=10.0,
                                dtype=work_dtype, axis=0)
  if axis != 0:
    samples = moveaxis(samples, 0, axis)
  return lax.convert_element_type(samples, out_dtype)
# Build coordinate grids from 1D coordinate vectors. Only ``copy=True`` is
# supported. With ``indexing='xy'`` the first two inputs are swapped before
# broadcasting and the first two outputs are swapped back afterwards.
@util.implements(np.meshgrid, lax_description=_ARRAY_VIEW_DOC)
def meshgrid(*xi: ArrayLike, copy: bool = True, sparse: bool = False,
             indexing: str = 'xy') -> list[Array]:
  util.check_arraylike("meshgrid", *xi)
  coords = [asarray(x) for x in xi]
  if not copy:
    raise ValueError("jax.numpy.meshgrid only supports copy=True")
  if indexing not in ["xy", "ij"]:
    raise ValueError(f"Valid values for indexing are 'xy' and 'ij', got {indexing}")
  if any(c.ndim != 1 for c in coords):
    raise ValueError("Arguments to jax.numpy.meshgrid must be 1D, got shapes "
                     f"{[c.shape for c in coords]}")

  swap_leading_pair = indexing == "xy" and len(coords) >= 2
  if swap_leading_pair:
    coords[0], coords[1] = coords[1], coords[0]

  # Dense grids all share the full shape; sparse grids are size 1 everywhere
  # except along their own axis.
  base_shape = [1 if sparse else c.shape[0] for c in coords]
  grids = []
  for pos, c in enumerate(coords):
    if sparse:
      target = [*base_shape[:pos], c.shape[0], *base_shape[pos + 1:]]
    else:
      target = base_shape
    grids.append(lax.broadcast_in_dim(c, target, (pos,)))

  if swap_leading_pair:
    grids[0], grids[1] = grids[1], grids[0]
  return grids
# Computes exp(|x|) * bessel_i0e(|x|) for floating-point input; any other
# (promoted) dtype is rejected with a ValueError.
@custom_jvp
@util.implements(np.i0)
@jit
def i0(x: ArrayLike) -> Array:
  (promoted,) = util.promote_args_inexact("i0", x)
  if not issubdtype(promoted.dtype, np.floating):
    raise ValueError(f"Unsupported input type to jax.numpy.i0: {_dtype(x)}")
  magnitude = lax.abs(promoted)
  growth = lax.exp(magnitude)
  return lax.mul(growth, lax.bessel_i0e(magnitude))
@i0.defjvp
def _i0_jvp(primals, tangents):
  """JVP rule for ``i0``: the default tangent, forced to 0.0 where the input is 0."""
  value, raw_tangent = jax.jvp(i0.fun, primals, tangents)
  at_origin = primals[0] == 0
  return value, where(at_origin, 0.0, raw_tangent)
# Build an open mesh from 1D sequences: argument ``i`` is returned reshaped so
# that its entries run along axis ``i`` of an ``len(args)``-dimensional array
# and every other axis has size 1.
#
# Raises ValueError for inputs that are not 1-dimensional and
# NotImplementedError for boolean inputs.
@util.implements(np.ix_)
def ix_(*args: ArrayLike) -> tuple[Array, ...]:
  # Report the public function name ("ix_") in the array-like check.
  util.check_arraylike("ix_", *args)
  n = len(args)
  output = []
  for i, a in enumerate(args):
    a = asarray(a)
    if len(a.shape) != 1:
      msg = "Arguments to jax.numpy.ix_ must be 1-dimensional, got shape {}"
      raise ValueError(msg.format(a.shape))
    if _dtype(a) == bool_:
      raise NotImplementedError(
        "Boolean arguments to jax.numpy.ix_ are not implemented")
    # Named out_shape so the module-level ``shape`` helper is not shadowed.
    out_shape = [1] * n
    out_shape[i] = a.shape[0]
    if a.size == 0:
      # Numpy uses an integer index type for empty arrays.
      output.append(lax.full(out_shape, np.zeros((), np.intp)))
    else:
      output.append(lax.broadcast_in_dim(a, out_shape, (i,)))
  return tuple(output)
# Typing overloads for ``indices``: the return type depends on ``sparse``.

# sparse=False (the default): a single Array is returned.
@overload
def indices(dimensions: Sequence[int], dtype: DTypeLike = int32,
            sparse: Literal[False] = False) -> Array: ...
# sparse=True (keyword-only in this overload): a tuple of Arrays is returned.
@overload
def indices(dimensions: Sequence[int], dtype: DTypeLike = int32,
            *, sparse: Literal[True]) -> tuple[Array, ...]: ...
# sparse given as a plain bool: either form may be returned.
@overload
def indices(dimensions: Sequence[int], dtype: DTypeLike = int32,
            sparse: bool = False) -> Array | tuple[Array, ...]: ...
# Index grids for an array of shape ``dimensions``. Every entry of
# ``dimensions`` must be a concrete integer. The dense form stacks one full
# grid per axis along a new leading axis; the sparse form returns a tuple of
# grids that are size 1 along every axis but their own.
@util.implements(np.indices)
def indices(dimensions: Sequence[int], dtype: DTypeLike = int32,
            sparse: bool = False) -> Array | tuple[Array, ...]:
  dims = tuple(
      core.concrete_or_error(operator.index, d, "dimensions argument of jnp.indices")
      for d in dimensions)
  rank = len(dims)

  def grid_shape(axis, extent):
    # Shape that the 1D index vector for `axis` is broadcast to.
    if sparse:
      return (1,) * axis + (extent,) + (1,) * (rank - axis - 1)
    return dims

  grids = [
      lax.broadcast_in_dim(lax.iota(dtype, extent), grid_shape(axis, extent), (axis,))
      for axis, extent in enumerate(dims)
  ]
  if sparse:
    return tuple(grids)
  if not grids:
    # No dimensions at all: nothing to stack, return an empty array.
    return array([], dtype=dtype)
  return stack(grids, 0)
_TOTAL_REPEAT_LENGTH_DOC = """\ JAX adds the optional `total_repeat_length` parameter which specifies the total number of repeat, and defaults to sum(repeats). It must be specified for repeat to be compilable. If `sum(repeats)` is larger than the specified `total_repeat_length` the remaining values will be discarded. In the case of `sum(repeats)` being smaller than the specified target length, the final value will be repeated. """
# Repeat the entries of ``a`` along ``axis`` (the flattened array when
# ``axis`` is None). Without ``total_repeat_length`` the repeat counts must be
# concrete; with it, the output length along ``axis`` is fixed to that value
# and the counts are handled as JAX arrays.
@util.implements(np.repeat, lax_description=_TOTAL_REPEAT_LENGTH_DOC)
def repeat(a: ArrayLike, repeats: ArrayLike, axis: int | None = None, *,
           total_repeat_length: int | None = None) -> Array:
  util.check_arraylike("repeat", a)
  if not core.is_dim(repeats):
    util.check_arraylike("repeat", repeats)

  flatten = axis is None
  a = ravel(a) if flatten else asarray(a)
  if flatten:
    axis = 0

  axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.repeat()")
  assert isinstance(axis, int)  # to appease mypy

  if core.is_symbolic_dim(repeats) and total_repeat_length is not None:
    raise ValueError("jnp.repeat with a non-constant `repeats` is supported only "
                     f"when `total_repeat_length` is None. ({repeats=} {total_repeat_length=})")

  if total_repeat_length is None:
    # No explicit output length: the counts must be concrete so that the
    # output length can be derived from them.
    repeats = core.concrete_or_error(None, repeats,
      "When jit-compiling jnp.repeat, the total number of repeats must be static. "
      "To fix this, either specify a static value for `repeats`, or pass a static "
      "value to `total_repeat_length`.")

    # Fast path for a scalar count: tile along a temporary axis, then fold
    # that axis back into `axis`.
    if np.ndim(repeats) == 0 and ndim(a) != 0:
      original_shape = shape(a)
      tile_axis = axis if axis < 0 else axis + 1
      a = expand_dims(a, tile_axis)
      tiling: list[DimSize] = [1] * len(shape(a))
      tiling[tile_axis] = repeats
      a = tile(a, tiling)
      folded_shape: list[DimSize] = list(original_shape)
      folded_shape[axis] *= repeats
      return reshape(a, folded_shape)

    repeats = np.ravel(repeats)
    if ndim(a) != 0:
      repeats = np.broadcast_to(repeats, [shape(a)[axis]])
    total_repeat_length = np.sum(repeats)
  else:
    repeats = ravel(repeats)
    if ndim(a) != 0:
      repeats = broadcast_to(repeats, [shape(a)[axis]])

  # Special case when a is a scalar.
  if ndim(a) == 0:
    if shape(repeats) != (1,):
      raise ValueError('`repeat` with a scalar parameter `a` is only '
                       'implemented for scalar values of the parameter `repeats`.')
    return full([total_repeat_length], a)

  # Special case if total_repeat_length is zero.
  if total_repeat_length == 0:
    empty_shape = list(shape(a))
    empty_shape[axis] = 0
    return reshape(array([], dtype=_dtype(a)), empty_shape)

  # If repeats is on a zero sized axis, then return the array.
  if shape(a)[axis] == 0:
    return asarray(a)

  # This implementation of repeat avoids having to instantiate a large
  # intermediate tensor.

  # Modify repeats from e.g. [1,2,0,5] -> [0,1,2,0] for exclusive repeat.
  shifted_counts = roll(repeats, shift=1).at[0].set(0)
  # Cumsum to get indices of new number in repeated tensor, e.g. [0, 1, 3, 3]
  block_starts = reductions.cumsum(shifted_counts)
  # Scatter these onto a zero buffer, e.g. [1,1,0,2,0,0,0,0]
  boundaries = zeros([total_repeat_length], dtype=int32)
  boundaries = boundaries.at[block_starts].add(1)
  # Cumsum again to get scatter indices for repeat, e.g. [0,1,1,3,3,3,3,3]
  source_index = reductions.cumsum(boundaries) - 1
  return take(a, source_index, axis=axis)
# (N, M) triangular array built by ``lax_internal._tri`` with diagonal offset
# ``k``. ``M`` defaults to ``N`` and ``dtype`` defaults to float32.
@util.implements(np.tri)
def tri(N: int, M: int | None = None, k: int = 0, dtype: DTypeLike | None = None) -> Array:
  dtypes.check_user_dtype_supported(dtype, "tri")
  n_cols = N if M is None else M
  out_dtype = dtype or float32
  return lax_internal._tri(out_dtype, (N, n_cols), k)
# Keep the entries selected by ``tri(N, M, k)`` in the last two dimensions of
# ``m`` and zero out the rest; leading dimensions are treated as batch.
@util.implements(np.tril)
@partial(jit, static_argnames=('k',))
def tril(m: ArrayLike, k: int = 0) -> Array:
  util.check_arraylike("tril", m)
  dims = shape(m)
  if len(dims) < 2:
    raise ValueError("Argument to jax.numpy.tril must be at least 2D")
  n_rows, n_cols = dims[-2:]
  keep = tri(n_rows, n_cols, k=k, dtype=bool)
  batched_keep = lax.broadcast(keep, dims[:-2])
  return lax.select(batched_keep, m, zeros_like(m))
# Zero out the entries selected by ``tri(N, M, k - 1)`` in the last two
# dimensions of ``m`` and keep the rest; leading dimensions are treated as
# batch.
@util.implements(np.triu, update_doc=False)
@partial(jit, static_argnames=('k',))
def triu(m: ArrayLike, k: int = 0) -> Array:
  util.check_arraylike("triu", m)
  dims = shape(m)
  if len(dims) < 2:
    raise ValueError("Argument to jax.numpy.triu must be at least 2D")
  n_rows, n_cols = dims[-2:]
  drop = tri(n_rows, n_cols, k=k - 1, dtype=bool)
  batched_drop = lax.broadcast(drop, dims[:-2])
  return lax.select(batched_drop, zeros_like(m), m)
# Sum along the ``offset`` diagonal of the (axis1, axis2) planes of ``a``.
# When no dtype is given, integer inputs narrower than the default integer
# type are accumulated in the default integer type. ``out`` is not supported.
@util.implements(np.trace, skip_params=['out'])
@partial(jit, static_argnames=('offset', 'axis1', 'axis2', 'dtype'))
def trace(a: ArrayLike, offset: int = 0, axis1: int = 0, axis2: int = 1,
          dtype: DTypeLike | None = None, out: None = None) -> Array:
  util.check_arraylike("trace", a)
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.trace is not supported.")
  dtypes.check_user_dtype_supported(dtype, "trace")

  in_shape = shape(a)
  acc_dtype = dtype
  if acc_dtype is None:
    acc_dtype = _dtype(a)
    if issubdtype(acc_dtype, integer):
      default_int = dtypes.canonicalize_dtype(int)
      if iinfo(acc_dtype).bits < iinfo(default_int).bits:
        acc_dtype = default_int

  planes = moveaxis(a, (axis1, axis2), (-2, -1))

  # Mask out the diagonal and reduce.
  on_diagonal = eye(in_shape[axis1], in_shape[axis2], k=offset, dtype=bool)
  masked = where(on_diagonal, planes, zeros_like(planes))
  return reductions.sum(masked, axis=(-2, -1), dtype=acc_dtype)
def _wrap_indices_function(f):
  """Wrap the index-generating function ``f`` for use from ``jnp``.

  Every positional and keyword argument must be concrete; the sequence that
  ``f`` returns is converted to a tuple of arrays via ``asarray``.
  """
  def concretize(value, label):
    return core.concrete_or_error(
        None, value, f"argument {label} of jnp.{f.__name__}()")

  @util.implements(f, update_doc=False)
  def wrapper(*args, **kwargs):
    static_args = [concretize(arg, pos) for pos, arg in enumerate(args)]
    static_kwargs = {key: concretize(val, f"'{key}'")
                     for key, val in kwargs.items()}
    return tuple(asarray(x) for x in f(*static_args, **static_kwargs))
  return wrapper

mask_indices = _wrap_indices_function(np.mask_indices)


def _triu_size(n, m, k):
  """Static ``size`` used by ``triu_indices`` for ``nonzero(triu(ones((n, m)), k=k))``."""
  if k < 0:
    # Complement: all n*m entries minus the transposed triangle at offset 1 - k.
    return n * m - _triu_size(m, n, (1 - k))
  if k >= m:
    return 0
  span = min(n, m - k)
  return span * (span + 1) // 2 + span * (m - k - span)
# Row and column indices of the nonzero entries of ``triu(ones((n, m)), k)``.
# ``n``, ``k`` and ``m`` must be concrete integers; ``m`` defaults to ``n``.
@util.implements(np.triu_indices)
def triu_indices(n: int, k: int = 0, m: int | None = None) -> tuple[Array, Array]:
  n_rows = core.concrete_or_error(operator.index, n, "n argument of jnp.triu_indices")
  diag = core.concrete_or_error(operator.index, k, "k argument of jnp.triu_indices")
  if m is None:
    n_cols = n_rows
  else:
    n_cols = core.concrete_or_error(operator.index, m, "m argument of jnp.triu_indices")
  # The number of hits is computed statically so nonzero has a fixed size.
  hit_count = _triu_size(n_rows, n_cols, diag)
  rows, cols = nonzero(triu(ones((n_rows, n_cols)), k=diag), size=hit_count)
  return rows, cols
# Row and column indices of the nonzero entries of ``tril(ones((n, m)), k)``.
# ``n``, ``k`` and ``m`` must be concrete integers; ``m`` defaults to ``n``.
# The error messages name ``jnp.tril_indices`` (they previously named
# ``jnp.triu_indices``).
@util.implements(np.tril_indices)
def tril_indices(n: int, k: int = 0, m: int | None = None) -> tuple[Array, Array]:
  n = core.concrete_or_error(operator.index, n, "n argument of jnp.tril_indices")
  k = core.concrete_or_error(operator.index, k, "k argument of jnp.tril_indices")
  m = n if m is None else core.concrete_or_error(operator.index, m, "m argument of jnp.tril_indices")
  # The static size reuses the upper-triangle count with the roles of n and m
  # swapped and the diagonal offset negated.
  i, j = nonzero(tril(ones((n, m)), k=k), size=_triu_size(m, n, -k))
  return i, j
# ``triu_indices`` for the last two dimensions of ``arr``.
@util.implements(np.triu_indices_from)
def triu_indices_from(arr: ArrayLike, k: int = 0) -> tuple[Array, Array]:
  dims = shape(arr)
  n_rows = dims[-2]
  n_cols = dims[-1]
  return triu_indices(n_rows, k=k, m=n_cols)
# ``tril_indices`` for the last two dimensions of ``arr``.
@util.implements(np.tril_indices_from)
def tril_indices_from(arr: ArrayLike, k: int = 0) -> tuple[Array, Array]:
  dims = shape(arr)
  n_rows = dims[-2]
  n_cols = dims[-1]
  return tril_indices(n_rows, k=k, m=n_cols)
# Return a copy of ``a`` whose main diagonal is set to ``val``. The caller
# must pass ``inplace=False``; ``wrap=True`` is not implemented. Non-scalar
# ``val`` is flattened and tiled to the diagonal length.
@util.implements(np.fill_diagonal, lax_description="""
The semantics of :func:`numpy.fill_diagonal` is to modify arrays in-place, which
JAX cannot do because JAX arrays are immutable. Thus :func:`jax.numpy.fill_diagonal`
adds the ``inplace`` parameter, which must be set to ``False`` by the user as a
reminder of this API difference.
""", extra_params="""
inplace : bool, default=True
    If left to its default value of True, JAX will raise an error. This is because
    the semantics of :func:`numpy.fill_diagonal` are to modify the array in-place,
    which is not possible in JAX due to the immutability of JAX arrays.
""")
def fill_diagonal(a: ArrayLike, val: ArrayLike, wrap: bool = False, *,
                  inplace: bool = True) -> Array:
  if inplace:
    raise NotImplementedError("JAX arrays are immutable, must use inplace=False")
  if wrap:
    raise NotImplementedError("wrap=True is not implemented, must use wrap=False")
  util.check_arraylike("fill_diagonal", a, val)
  target = asarray(a)
  fill = asarray(val)
  if target.ndim < 2:
    raise ValueError("array must be at least 2-d")
  # Arrays with more than two dimensions must have equal extents on all axes.
  is_hypercube = all(extent == target.shape[0] for extent in target.shape[1:])
  if target.ndim > 2 and not is_hypercube:
    raise ValueError("All dimensions of input must be of equal length")
  diag_len = min(target.shape)
  diag_idx = diag_indices(diag_len, target.ndim)
  if fill.ndim != 0:
    fill = _tile_to_size(fill.ravel(), diag_len)
  return target.at[diag_idx].set(fill)
# Indices of the main diagonal: the same ``iota(int_, n)`` array repeated
# ``ndim`` times. Both arguments must be concrete, non-negative integers.
@util.implements(np.diag_indices)
def diag_indices(n: int, ndim: int = 2) -> tuple[Array, ...]:
  size = core.concrete_or_error(operator.index, n, "'n' argument of jnp.diag_indices()")
  rank = core.concrete_or_error(operator.index, ndim, "'ndim' argument of jnp.diag_indices()")
  if size < 0:
    raise ValueError(
        f"n argument to diag_indices must be nonnegative, got {size}")
  if rank < 0:
    raise ValueError(
        f"ndim argument to diag_indices must be nonnegative, got {rank}")
  diag = lax.iota(int_, size)
  return (diag,) * rank
# ``diag_indices`` for an existing array, which must be at least 2-d and have
# the same length along every dimension.
@util.implements(np.diag_indices_from)
def diag_indices_from(arr: ArrayLike) -> tuple[Array, ...]:
  util.check_arraylike("diag_indices_from", arr)
  rank = ndim(arr)
  if rank < 2:
    raise ValueError("input array must be at least 2-d")

  dims = shape(arr)
  if len(set(dims)) != 1:
    raise ValueError("All dimensions of input must be of equal length")

  return diag_indices(dims[0], ndim=rank)
# Extract the ``offset`` diagonal of the (axis1, axis2) planes of ``a``.
# The two axes are moved to the end and the diagonal is gathered with a pair
# of index vectors, so the diagonal becomes the last output dimension.
@util.implements(np.diagonal, lax_description=_ARRAY_VIEW_DOC)
@partial(jit, static_argnames=('offset', 'axis1', 'axis2'))
def diagonal(a: ArrayLike, offset: int = 0, axis1: int = 0,
             axis2: int = 1) -> Array:
  util.check_arraylike("diagonal", a)
  in_shape = shape(a)
  if ndim(a) < 2:
    raise ValueError("diagonal requires an array of at least two dimensions.")
  offset = core.concrete_or_error(operator.index, offset, "'offset' argument of jnp.diagonal()")

  planes = moveaxis(a, (axis1, axis2), (-2, -1))

  # A negative offset shortens the usable rows, a positive one the columns.
  rows_left = in_shape[axis1] + min(offset, 0)
  cols_left = in_shape[axis2] - max(offset, 0)
  diag_size = max(0, min(rows_left, cols_left))
  shift = abs(offset)
  near = arange(diag_size)
  far = arange(shift, shift + diag_size)
  if offset >= 0:
    return planes[..., near, far]
  return planes[..., far, near]
# Public wrapper: converts ``k`` with ``operator.index`` and passes it to the
# jitted ``_diag``, where ``k`` is a static argument.
@util.implements(np.diag, lax_description=_ARRAY_VIEW_DOC)
def diag(v: ArrayLike, k: int = 0) -> Array:
  static_k = operator.index(k)
  return _diag(v, static_k)
@partial(jit, static_argnames=('k',))
def _diag(v, k):
  """Jitted implementation of ``diag``.

  A 2D input returns its ``k``-th diagonal; a 1D input is placed on the
  ``k``-th diagonal of a square matrix of zeros. Any other rank raises
  ``ValueError``.
  """
  util.check_arraylike("diag", v)
  dims = shape(v)
  rank = len(dims)
  if rank == 2:
    return diagonal(v, offset=k)
  if rank != 1:
    raise ValueError("diag input must be 1d or 2d")

  side = dims[0] + abs(k)
  # Pad the vector to length `side` so it can be selected against the
  # diagonal mask by broadcasting.
  pad_value = lax.full_like(v, shape=(), fill_value=0)
  padded = lax.pad(v, pad_value, ((max(0, k), max(0, -k), 0),))
  return where(eye(side, k=k, dtype=bool), padded, zeros_like(padded))

_SCALAR_VALUE_DOC = """\
This differs from np.diagflat for some scalar values of v,
jax always returns a two-dimensional array, whereas numpy may
return a scalar depending on the type of v.
"""
# Flatten ``v`` and place it on the ``k``-th diagonal of a square matrix of
# zeros. The matrix is built as a flat buffer, filled at the flat positions
# of the diagonal, and reshaped.
@util.implements(np.diagflat, lax_description=_SCALAR_VALUE_DOC)
def diagflat(v: ArrayLike, k: int = 0) -> Array:
  util.check_arraylike("diagflat", v)
  flat = ravel(v)
  shift = abs(k)
  side = len(flat) + shift
  canvas = zeros(side * side, dtype=flat.dtype)
  pos = arange(0, side - shift)
  # Flat (row-major) index of diagonal entry `pos`.
  if k >= 0:
    flat_idx = pos + k + pos * side
  else:
    flat_idx = pos + (pos - k) * side
  return canvas.at[flat_idx].set(flat).reshape(side, side)
# Strip zeros from the front ('f' in ``trim``) and/or back ('b' in ``trim``)
# of ``filt``, which must be concrete. An all-zero input yields an empty
# array of the same dtype.
@util.implements(np.trim_zeros)
def trim_zeros(filt, trim='fb'):
  filt = core.concrete_or_error(asarray, filt,
    "Error arose in the `filt` argument of trim_zeros()")
  is_zero = (filt == 0)
  if reductions.all(is_zero):
    return empty(0, _dtype(filt))
  mode = trim.lower()
  # argmin over a boolean mask finds the first False, i.e. the first entry
  # that is not zero.
  lead = argmin(is_zero) if 'f' in mode else 0
  tail = argmin(is_zero[::-1]) if 'b' in mode else 0
  return filt[lead:len(filt) - tail]
def trim_zeros_tol(filt, tol, trim='fb'):
  """Like ``trim_zeros``, but treats entries with ``abs(x) < tol`` as zero.

  ``filt`` must be concrete. If every entry is below the tolerance, an empty
  array of the same dtype is returned.
  """
  filt = core.concrete_or_error(asarray, filt,
    "Error arose in the `filt` argument of trim_zeros_tol()")
  negligible = (ufuncs.abs(filt) < tol)
  if reductions.all(negligible):
    return empty(0, _dtype(filt))
  mode = trim.lower()
  # argmin over a boolean mask finds the first False, i.e. the first entry
  # that is not below the tolerance.
  lead = argmin(negligible) if 'f' in mode else 0
  tail = argmin(negligible[::-1]) if 'b' in mode else 0
  return filt[lead:len(filt) - tail]
# Concatenate ``values`` onto ``arr``. With ``axis=None`` both inputs are
# flattened first; otherwise they are joined along ``axis`` as given.
@util.implements(np.append)
@partial(jit, static_argnames=('axis',))
def append(
    arr: ArrayLike, values: ArrayLike, axis: int | None = None
) -> Array:
  if axis is not None:
    return concatenate([arr, values], axis=axis)
  flat_parts = [ravel(arr), ravel(values)]
  return concatenate(flat_parts, 0)
# Return ``arr`` with the entries selected by ``obj`` removed along ``axis``
# (the flattened array when ``axis`` is None).
#
# ``obj`` may be a static integer, a slice, an integer array, or a boolean
# mask. Integer arrays are handled without requiring concrete values only
# when ``assume_unique_indices=True``; every other array case needs a
# concrete ``obj``. Raises ValueError for a boolean mask of the wrong shape
# and for index arrays that are neither integer nor boolean.
@util.implements(np.delete, lax_description=_dedent("""
delete() usually requires the index specification to be static. If the index
is an integer array that is guaranteed to contain unique entries, you may
specify ``assume_unique_indices=True`` to perform the operation in a
manner that does not require static indices."""), extra_params=_dedent("""
assume_unique_indices : bool, optional (default=False)
    In case of array-like integer (not boolean) indices, assume the indices are unique,
    and perform the deletion in a way that is compatible with JIT and other JAX
    transformations."""))
def delete(
    arr: ArrayLike,
    obj: ArrayLike | slice,
    axis: int | None = None,
    *,
    assume_unique_indices: bool = False,
) -> Array:
  util.check_arraylike("delete", arr)
  if axis is None:
    arr = ravel(arr)
    axis = 0
  a = asarray(arr)
  axis = _canonicalize_axis(axis, a.ndim)

  # Case 1: obj is a static integer.
  try:
    obj = operator.index(obj)  # type: ignore[arg-type]
    obj = _canonicalize_axis(obj, a.shape[axis])
  except TypeError:
    # Not usable as a static integer; fall through to the other cases.
    pass
  else:
    # Stitch together everything before and everything after position obj.
    idx = tuple(slice(None) for i in range(axis))
    return concatenate([a[idx + (slice(0, obj),)], a[idx + (slice(obj + 1, None),)]], axis=axis)

  # Case 2: obj is a static slice.
  if isinstance(obj, slice):
    # Expand the slice into explicit indices and route it through case 3a.
    obj = arange(a.shape[axis])[obj]
    assume_unique_indices = True

  # Case 3: obj is an array
  # NB: pass both arrays to check for appropriate error message.
  util.check_arraylike("delete", a, obj)

  # Case 3a: unique integer indices; delete in a JIT-compatible way
  if issubdtype(_dtype(obj), integer) and assume_unique_indices:
    obj = asarray(obj).ravel()
    # Wrap negative indices, then clip into [0, axis length].
    obj = clip(where(obj < 0, obj + a.shape[axis], obj), 0, a.shape[axis])
    obj = sort(obj)
    obj -= arange(len(obj))  # type: ignore[arg-type,operator]
    # i enumerates the output positions; each is shifted right by the number
    # of adjusted deletion indices at or before it.
    i = arange(a.shape[axis] - obj.size)
    i += (i[None, :] >= obj[:, None]).sum(0)
    return a[(slice(None),) * axis + (i,)]

  # Case 3b: non-unique indices: must be static.
  obj_array = core.concrete_or_error(np.asarray, obj, "'obj' array argument of jnp.delete()")
  if issubdtype(obj_array.dtype, integer):
    # TODO(jakevdp): in theory this could be done dynamically if obj has no duplicates,
    # but this would require the complement of lax.gather.
    mask = np.ones(a.shape[axis], dtype=bool)
    mask[obj_array] = False
  elif obj_array.dtype == bool:
    if obj_array.shape != (a.shape[axis],):
      raise ValueError("np.delete(arr, obj): for boolean indices, obj must be one-dimensional "
                       "with length matching specified axis.")
    mask = ~obj_array
  else:
    raise ValueError(f"np.delete(arr, obj): got obj.dtype={obj_array.dtype}; must be integer or bool.")
  # mask is True for the entries that are kept.
  return a[tuple(slice(None) for i in range(axis)) + (mask,)]
# Insert ``values`` into ``arr`` before the positions given by ``obj`` along
# ``axis`` (the flattened array when ``axis`` is None). ``obj`` may be a
# slice, a scalar or a one-dimensional integer array; ``axis`` must be
# concrete.
@util.implements(np.insert)
def insert(arr: ArrayLike, obj: ArrayLike | slice, values: ArrayLike,
           axis: int | None = None) -> Array:
  util.check_arraylike("insert", arr, 0 if isinstance(obj, slice) else obj, values)
  base = asarray(arr)
  new_vals = asarray(values)

  if axis is None:
    base = ravel(base)
    axis = 0
  axis = core.concrete_or_error(None, axis, "axis argument of jnp.insert()")
  axis = _canonicalize_axis(axis, base.ndim)

  if isinstance(obj, slice):
    positions = arange(*obj.indices(base.shape[axis]))
  else:
    positions = asarray(obj)

  if positions.ndim > 1:
    raise ValueError("jnp.insert(): obj must be a slice, a one-dimensional "
                     f"array, or a scalar; got {obj}")
  if not np.issubdtype(positions.dtype, np.integer):
    # Only an empty, non-Array index may be non-integer; it is cast to int.
    if positions.size != 0 or isinstance(obj, Array):
      # Note: np.insert allows boolean inputs but the behavior is deprecated.
      raise ValueError("jnp.insert(): index array must be "
                       f"integer typed; got {obj}")
    positions = positions.astype(int)
  new_vals = array(new_vals, ndmin=base.ndim, dtype=base.dtype, copy=False)

  if positions.size == 1:
    # A single position is repeated once per inserted entry.
    single = ravel(positions)[0]
    if positions.ndim == 0:
      new_vals = moveaxis(new_vals, 0, axis)
    positions = full(new_vals.shape[axis], single)
  n_input = base.shape[axis]
  n_insert = broadcast_shapes(positions.shape, (new_vals.shape[axis],))[0]
  out_shape = list(base.shape)
  out_shape[axis] += n_insert
  out = zeros_like(base, shape=tuple(out_shape))

  # Wrap negative positions, then clip into [0, n_input].
  positions = where(positions < 0, positions + n_input, positions)
  positions = clip(positions, 0, n_input)

  # Output slots for the inserted values, and the remaining slots for the
  # original entries.
  values_ind = positions.at[argsort(positions)].add(arange(n_insert, dtype=positions.dtype))
  arr_mask = ones(n_input + n_insert, dtype=bool).at[values_ind].set(False)
  arr_ind = where(arr_mask, size=n_input)[0]

  lead = (slice(None),) * axis
  out = out.at[lead + (values_ind,)].set(new_vals)
  out = out.at[lead + (arr_ind,)].set(base)
  return out
# Apply ``func1d`` to the 1D slices of ``arr`` along ``axis`` by wrapping it
# in one ``jax.vmap`` per remaining dimension. Extra positional and keyword
# arguments are forwarded to ``func1d`` on every call.
@util.implements(np.apply_along_axis)
def apply_along_axis(
    func1d: Callable, axis: int, arr: ArrayLike, *args, **kwargs
) -> Array:
  # TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error.
  util.check_arraylike("apply_along_axis", arr, emit_warning=True)
  rank = ndim(arr)
  axis = _canonicalize_axis(axis, rank)
  per_slice = lambda vec: func1d(vec, *args, **kwargs)
  # Map over the dimensions after `axis` ...
  for in_axis in range(1, rank - axis):
    per_slice = jax.vmap(per_slice, in_axes=in_axis, out_axes=-1)
  # ... then over the dimensions before it.
  for _ in range(axis):
    per_slice = jax.vmap(per_slice, in_axes=0, out_axes=0)
  return per_slice(arr)
# Apply ``func(array, axis)`` successively for each axis in ``axes``. A result
# that lost one dimension gets it re-inserted at ``axis``; any other change in
# rank raises ValueError.
@util.implements(np.apply_over_axes)
def apply_over_axes(func: Callable[[ArrayLike, int], Array], a: ArrayLike,
                    axes: Sequence[int]) -> Array:
  # TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error.
  util.check_arraylike("apply_over_axes", a, emit_warning=True)
  current = asarray(a)
  for axis in axes:
    reduced = func(current, axis)
    dims_lost = current.ndim - reduced.ndim
    if dims_lost == 0:
      current = reduced
    elif dims_lost == 1:
      current = expand_dims(reduced, axis)
    else:
      raise ValueError("function is not returning an array of the correct shape")
  return current
### Tensor contraction operations _DOT_PREFERRED_ELEMENT_TYPE_DESCRIPTION = """ preferred_element_type : dtype, optional If specified, accumulate results and return a result of the given data type. If not specified, the accumulation dtype is determined from the type promotion rules of the input array dtypes. """
# Dot product of two arrays. If either operand is 0-d the result is an
# elementwise multiply; otherwise the last axis of ``a`` is contracted with
# the only axis of ``b`` (1-d) or its second-to-last axis.
@util.implements(np.dot, lax_description=_PRECISION_DOC,
                 extra_params=_DOT_PREFERRED_ELEMENT_TYPE_DESCRIPTION)
@partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True)
def dot(a: ArrayLike, b: ArrayLike, *,
        precision: PrecisionLike = None,
        preferred_element_type: DTypeLike | None = None) -> Array:
  util.check_arraylike("dot", a, b)
  dtypes.check_user_dtype_supported(preferred_element_type, "dot")
  lhs, rhs = asarray(a), asarray(b)
  if preferred_element_type is None:
    out_dtype, out_weak = dtypes.result_type(lhs, rhs, return_weak_type_flag=True)
  else:
    out_dtype, out_weak = preferred_element_type, False

  lhs_rank, rhs_rank = ndim(lhs), ndim(rhs)
  if lhs_rank == 0 or rhs_rank == 0:
    # TODO(jakevdp): lower this case to dot_general as well?
    # Currently, doing so causes issues in remat tests due to #16805
    if out_dtype is not None:
      lhs = lhs.astype(out_dtype)
      rhs = rhs.astype(out_dtype)
    product = lax.mul(lhs, rhs)
  else:
    rhs_contract = 0 if rhs_rank == 1 else rhs_rank - 2
    contract_dims = ((lhs_rank - 1,), (rhs_contract,))
    no_batch_dims = ((), ())
    product = lax.dot_general(lhs, rhs,
                              dimension_numbers=(contract_dims, no_batch_dims),
                              precision=precision,
                              preferred_element_type=out_dtype)
  return lax_internal._convert_element_type(product, out_dtype, out_weak)
# Batched matrix multiply implemented with a single ``lax.dot_general``.
# Size-1 batch dimensions are squeezed out of the inputs instead of being
# broadcast, and the output of dot_general is transposed into the final
# dimension order at the end.
@util.implements(np.matmul, module='numpy', lax_description=_PRECISION_DOC,
                 extra_params=_DOT_PREFERRED_ELEMENT_TYPE_DESCRIPTION)
@partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True)
def matmul(a: ArrayLike, b: ArrayLike, *,
           precision: PrecisionLike = None,
           preferred_element_type: DTypeLike | None = None,
           ) -> Array:
  """Matrix Multiply."""
  util.check_arraylike("matmul", a, b)
  dtypes.check_user_dtype_supported(preferred_element_type, "matmul")
  a, b = asarray(a), asarray(b)
  for i, x in enumerate((a, b)):
    if ndim(x) < 1:
      msg = (f"matmul input operand {i} must have ndim at least 1, "
             f"but it has ndim {ndim(x)}")
      raise ValueError(msg)
  if preferred_element_type is None:
    preferred_element_type, output_weak_type = dtypes.result_type(a, b, return_weak_type_flag=True)
  else:
    output_weak_type = False

  # An operand is a "matrix" if it has more than one dimension; only
  # matrices carry batch dimensions (everything before the last two).
  a_is_mat, b_is_mat = (ndim(a) > 1), (ndim(b) > 1)
  a_batch_dims: tuple[int | None, ...] = shape(a)[:-2] if a_is_mat else ()
  b_batch_dims: tuple[int | None, ...] = shape(b)[:-2] if b_is_mat else ()
  num_batch_dims = max(len(a_batch_dims), len(b_batch_dims))
  # Left-pad the shorter batch shape with None so both line up from the right.
  a_batch_dims = (None,) * (num_batch_dims - len(a_batch_dims)) + a_batch_dims
  b_batch_dims = (None,) * (num_batch_dims - len(b_batch_dims)) + b_batch_dims

  # Dimensions to squeeze from the inputs.
  a_squeeze: list[int] = []
  b_squeeze: list[int] = []

  # Positions of batch dimensions in squeezed inputs.
  a_batch = []
  b_batch = []

  # Desired index in final output of each kind of dimension, in the order that
  # lax.dot_general will emit them.
  idx_batch: list[int] = []
  idx_a_other: list[int] = []  # other = non-batch, non-contracting.
  idx_b_other: list[int] = []
  for i, (ba, bb) in enumerate(zip(a_batch_dims, b_batch_dims)):
    if ba is None:
      # a has no dimension here: the output dimension comes from b alone.
      idx_b_other.append(i)
    elif bb is None:
      # b has no dimension here: the output dimension comes from a alone.
      idx_a_other.append(i)
    elif core.definitely_equal(ba, 1):
      # a's size-1 dimension is squeezed away; b supplies the output dim.
      idx_b_other.append(i)
      a_squeeze.append(len(idx_batch) + len(idx_a_other) + len(a_squeeze))
    elif core.definitely_equal(bb, 1):
      # b's size-1 dimension is squeezed away; a supplies the output dim.
      idx_a_other.append(i)
      b_squeeze.append(len(idx_batch) + len(idx_b_other) + len(b_squeeze))
    elif core.definitely_equal(ba, bb):
      # Matching sizes: a genuine batch dimension for dot_general.
      a_batch.append(len(idx_batch) + len(idx_a_other))
      b_batch.append(len(idx_batch) + len(idx_b_other))
      idx_batch.append(i)
    else:
      raise ValueError("Incompatible shapes for matmul arguments: {} and {}"
                       .format(shape(a), shape(b)))

  # The non-contracted matrix dimensions come after all batch positions.
  if a_is_mat: idx_a_other.append(num_batch_dims)
  if b_is_mat: idx_b_other.append(num_batch_dims + a_is_mat)
  # Permutation that reorders dot_general's output into the desired order.
  perm = np.argsort(np.concatenate([idx_batch, idx_a_other, idx_b_other]))

  a = lax.squeeze(a, tuple(a_squeeze))
  b = lax.squeeze(b, tuple(b_squeeze))
  # Contract the last dimension of a with the second-to-last dimension of b
  # (or b's only dimension when b is a vector).
  out = lax.dot_general(
    a, b, (((ndim(a) - 1,), (ndim(b) - 1 - b_is_mat,)), (a_batch, b_batch)),
    precision=precision, preferred_element_type=preferred_element_type)
  result = lax.transpose(out, perm)
  return lax_internal._convert_element_type(result, preferred_element_type, output_weak_type)
# Dot product of the flattened inputs; ``a`` is conjugated first when its
# dtype is complex.
@util.implements(np.vdot, lax_description=_PRECISION_DOC,
                 extra_params=_DOT_PREFERRED_ELEMENT_TYPE_DESCRIPTION)
@partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True)
def vdot(
    a: ArrayLike, b: ArrayLike, *,
    precision: PrecisionLike = None,
    preferred_element_type: DTypeLike | None = None,
) -> Array:
  util.check_arraylike("vdot", a, b)
  lhs = ufuncs.conj(a) if issubdtype(_dtype(a), complexfloating) else a
  return dot(ravel(lhs), ravel(b), precision=precision,
             preferred_element_type=preferred_element_type)
# Vector dot product along ``axis``: that axis is moved to the end of both
# inputs and ``vdot`` is vectorized over the remaining dimensions. Raises
# ValueError when the two inputs differ in length along ``axis``.
@util.implements(getattr(np, "vecdot", None), lax_description=_PRECISION_DOC,
                 extra_params=_DOT_PREFERRED_ELEMENT_TYPE_DESCRIPTION)
def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1,
           precision: PrecisionLike = None,
           preferred_element_type: DTypeLike | None = None) -> Array:
  util.check_arraylike("jnp.vecdot", x1, x2)
  lhs, rhs = asarray(x1), asarray(x2)
  if lhs.shape[axis] != rhs.shape[axis]:
    raise ValueError(f"axes must match; got shapes {lhs.shape} and {rhs.shape} with {axis=}")
  lhs = jax.numpy.moveaxis(lhs, axis, -1)
  rhs = jax.numpy.moveaxis(rhs, axis, -1)
  pairwise_vdot = partial(vdot, precision=precision,
                          preferred_element_type=preferred_element_type)
  batched_vdot = vectorize(pairwise_vdot, signature="(n),(n)->()")
  return batched_vdot(lhs, rhs)
# Tensor contraction over the axes described by ``axes``:
#   * an int N contracts the last N axes of ``a`` with the first N of ``b``;
#   * a pair of ints contracts one axis of each operand;
#   * a pair of equal-length lists/tuples contracts the listed axes pairwise.
# Any other ``axes`` value raises TypeError.
@util.implements(np.tensordot, lax_description=_PRECISION_DOC,
                 extra_params=_DOT_PREFERRED_ELEMENT_TYPE_DESCRIPTION)
def tensordot(a: ArrayLike, b: ArrayLike,
              axes: int | Sequence[int] | Sequence[Sequence[int]] = 2,
              *, precision: PrecisionLike = None,
              preferred_element_type: DTypeLike | None = None) -> Array:
  util.check_arraylike("tensordot", a, b)
  dtypes.check_user_dtype_supported(preferred_element_type, "tensordot")
  lhs, rhs = asarray(a), asarray(b)
  lhs_rank = ndim(lhs)
  rhs_rank = ndim(rhs)
  if preferred_element_type is None:
    out_dtype, out_weak = dtypes.result_type(lhs, rhs, return_weak_type_flag=True)
  else:
    out_dtype, out_weak = preferred_element_type, False

  if type(axes) is int:
    if axes > min(lhs_rank, rhs_rank):
      raise TypeError(
          f"Number of tensordot axes (axes {axes}) exceeds input ranks "
          f"({lhs.shape} and {rhs.shape})")
    contracting_dims = tuple(range(lhs_rank - axes, lhs_rank)), tuple(range(axes))
  elif isinstance(axes, (tuple, list)) and len(axes) == 2:
    lhs_axes, rhs_axes = axes
    if type(lhs_axes) == type(rhs_axes) == int:
      contracting_dims = ((_canonicalize_axis(lhs_axes, lhs_rank),),
                          (_canonicalize_axis(rhs_axes, rhs_rank),))
    elif isinstance(lhs_axes, (tuple, list)) and isinstance(rhs_axes, (tuple, list)):
      if len(lhs_axes) != len(rhs_axes):
        raise TypeError(
            "tensordot requires axes lists to have equal length, "
            f"got {lhs_axes} and {rhs_axes}.")
      contracting_dims = (tuple(_canonicalize_axis(i, lhs_rank) for i in lhs_axes),
                          tuple(_canonicalize_axis(i, rhs_rank) for i in rhs_axes))
    else:
      raise TypeError(
          "tensordot requires both axes lists to be either ints, tuples or "
          f"lists, got {lhs_axes} and {rhs_axes}")
  else:
    raise TypeError(
        "tensordot axes argument must be an int, a pair of ints, or a pair "
        "of lists/tuples of ints.")
  contracted = lax.dot_general(lhs, rhs, (contracting_dims, ((), ())),
                               precision=precision,
                               preferred_element_type=out_dtype)
  return lax_internal._convert_element_type(contracted, out_dtype, out_weak)
# JAX-specific notes for einsum, appended to the shared precision notes and
# passed as ``lax_description`` to ``util.implements`` on ``einsum``.
_EINSUM_DOC = _PRECISION_DOC + """\
A tuple ``precision`` does not necessarily map to multiple arguments of ``einsum()``;
rather, the specified ``precision`` is forwarded to each ``dot_general`` call used in
the implementation.

:func:`jax.numpy.einsum` also differs from :func:`numpy.einsum` in that the ``optimize``
keyword defaults to ``"optimal"`` rather than ``False``.
"""

# Typing overload: subscripts given as a string, followed by the operands.
@overload
def einsum(
    subscript: str, /,
    *operands: ArrayLike,
    out: None = None,
    optimize: str = "optimal",
    precision: PrecisionLike = None,
    preferred_element_type: DTypeLike | None = None,
    _dot_general: Callable[..., Array] = lax.dot_general,
) -> Array: ...

# Typing overload: operands interleaved with their axis sequences.
@overload
def einsum(
    arr: ArrayLike,
    axes: Sequence[Any], /,
    *operands: ArrayLike | Sequence[Any],
    out: None = None,
    optimize: str = "optimal",
    precision: PrecisionLike = None,
    preferred_element_type: DTypeLike | None = None,
    _dot_general: Callable[..., Array] = lax.dot_general,
) -> Array: ...
# Einstein summation. The contraction order is planned with opt_einsum (or a
# registered handler when an operand has non-constant dimensions) and then
# executed by the jitted ``_einsum`` helper. ``out`` is not supported.
@util.implements(np.einsum, lax_description=_EINSUM_DOC, skip_params=['out'])
def einsum(
    subscripts, /,
    *operands,
    out: None = None,
    optimize: str = "optimal",
    precision: PrecisionLike = None,
    preferred_element_type: DTypeLike | None = None,
    _dot_general: Callable[..., Array] = lax.dot_general,
) -> Array:
  all_operands = (subscripts, *operands)
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.einsum is not supported.")
  # The subscript string, when there is one, names the traced computation.
  spec = all_operands[0] if isinstance(all_operands[0], str) else None
  if optimize is True:
    optimize = 'optimal'

  # Allow handling of shape polymorphism
  non_constant_dim_types = {
      type(d) for op in all_operands if not isinstance(op, str)
      for d in np.shape(op) if not core.is_constant_dim(d)
  }
  if non_constant_dim_types:
    dim_type = next(iter(non_constant_dim_types))
    plan_path = _poly_einsum_handlers.get(dim_type, _default_poly_einsum_handler)
  else:
    plan_path = opt_einsum.contract_path
  # using einsum_call=True here is an internal api for opt_einsum... sorry
  planned_operands, raw_contractions = plan_path(
        *all_operands, einsum_call=True, use_blas=True, optimize=optimize)

  # Keep the first three fields of each step, in hashable form, so the plan
  # can be passed as a static argument.
  contractions = tuple(
      (a, frozenset(b), c) for a, b, c, *_ in raw_contractions)

  einsum_impl = jit(_einsum, static_argnums=(1, 2, 3, 4), inline=True)
  if spec is not None:
    einsum_impl = jax.named_call(einsum_impl, name=spec)
  return einsum_impl(planned_operands, contractions, precision,  # type: ignore[operator]
                     preferred_element_type, _dot_general)
# Enable other modules to override the einsum contract_path.
# Indexed by the type of the non constant dimension
_poly_einsum_handlers = {}  # type: ignore

def _default_poly_einsum_handler(*operands, **kwargs):
  """Plan an einsum contraction for operands with non-constant dimensions.

  Each operand that has a ``dtype`` is replaced by a stand-in carrying only a
  shape and a dtype, with every non-``int`` dimension replaced by 8. The path
  is planned on the stand-ins, and the operands returned by opt_einsum are
  mapped back to the originals.
  """
  dummy = collections.namedtuple('dummy', ['shape', 'dtype'])
  stand_ins = []
  position_of = {}
  for pos, op in enumerate(operands):
    if hasattr(op, 'dtype'):
      sizes = tuple(d if type(d) is int else 8 for d in op.shape)
      stand_in = dummy(sizes, op.dtype)
    else:
      stand_in = op
    stand_ins.append(stand_in)
    # Keyed by identity so stand-ins can be matched after planning.
    position_of[id(stand_in)] = pos
  out_stand_ins, contractions = opt_einsum.contract_path(*stand_ins, **kwargs)
  contract_operands = [operands[position_of[id(d)]] for d in out_stand_ins]
  return contract_operands, contractions
@util.implements(np.einsum_path)
def einsum_path(subscripts, *operands, optimize='greedy'):
  # Thin wrapper: forwards directly to opt_einsum.contract_path and returns
  # its result unchanged.
  return opt_einsum.contract_path(subscripts, *operands, optimize=optimize)
def _removechars(s, chars):
  """Return ``s`` with every character that occurs in ``chars`` deleted."""
  return s.translate(str.maketrans(dict.fromkeys(chars)))


def _einsum(
    operands: Sequence,
    contractions: Sequence[tuple[tuple[int, ...], frozenset[str], str]],
    precision,
    preferred_element_type,
    _dot_general=lax.dot_general,
):
  """Execute a planned sequence of einsum contractions.

  Args:
    operands: the arrays to contract; each is converted with ``asarray``.
    contractions: one entry per step, ``(operand_indices, contracted_names,
      einstr)``. ``operand_indices`` selects (and pops) one or two entries of
      the working operand list, ``contracted_names`` is the set of index
      labels summed in this step, and ``einstr`` is an ``"in[,in]->out"``
      string for the step.
    precision: forwarded to ``_dot_general``.
    preferred_element_type: accumulation/result dtype; when ``None`` it is
      derived from the operands with ``dtypes.result_type``.
    _dot_general: the callable used for two-operand contractions.

  Returns:
    The single remaining operand, converted to ``preferred_element_type``.

  Raises:
    NotImplementedError: if a step names a number of operands other than 1 or 2.
  """
  dtypes.check_user_dtype_supported(preferred_element_type, "einsum")
  operands = list(map(asarray, operands))
  if preferred_element_type is None:
    preferred_element_type, output_weak_type = dtypes.result_type(*operands, return_weak_type_flag=True)
  else:
    output_weak_type = False

  def sum(x, axes):
    # Reduce over ``axes``, first casting to the preferred element type when
    # promotion would change x's dtype. Booleans are reduced with bitwise-or
    # rather than add.
    if dtypes.result_type(x, preferred_element_type) != x.dtype:
      x = x.astype(preferred_element_type)
    return lax.reduce(x, np.array(0, x.dtype),
                      lax.add if x.dtype != bool_ else lax.bitwise_or, axes)

  def sum_uniques(operand, names, uniques):
    # Sum out the axes labelled by ``uniques`` and drop those labels.
    if uniques:
      axes = [names.index(name) for name in uniques]
      operand = sum(operand, axes)
      names = _removechars(names, uniques)
    return operand, names

  def sum_repeats(operand, names, counts, keep_names):
    # For each label that occurs more than once in ``names``: zero out
    # everything not selected by the mask from lax_internal._delta
    # (presumably the "all repeated indices equal" mask -- confirm), then sum
    # over the repeated axes. If the label is still needed (in keep_names)
    # the last of its axes is retained.
    for name, count in counts.items():
      if count > 1:
        axes = [i for i, n in enumerate(names) if n == name]
        eye = lax_internal._delta(np.dtype('bool'), operand.shape, axes)
        operand = lax.select(eye, operand, zeros_like(operand))
        if name not in keep_names:
          operand = sum(operand, axes)
          names = names.replace(name, '')
        else:
          operand = sum(operand, axes[:-1])
          names = names.replace(name, '', count - 1)
    return operand, names

  def filter_singleton_dims(operand, names, other_shape, other_names):
    # An axis is kept unless its size is definitely 1, its label also occurs
    # in the other operand, and the other operand's size for that label is not
    # definitely 1. The remaining axes are squeezed away.
    eq = core.definitely_equal
    keep = [not eq(operand.shape[i], 1) or j == -1 or eq(other_shape[j], 1)
            for i, j in enumerate(map(other_names.find, names))]
    sqez_axes, keep_axes = partition_list(keep, list(range(operand.ndim)))
    return lax.squeeze(operand, sqez_axes), "".join(names[i] for i in keep_axes)

  for operand_indices, contracted_names_set, einstr in contractions:
    contracted_names = sorted(contracted_names_set)
    input_str, result_names = einstr.split('->')
    input_names = input_str.split(',')

    # switch on the number of operands to be processed in this loop iteration.
    # every case here sets 'operand' and 'names'.
    if len(operand_indices) == 1:
      operand = operands.pop(operand_indices[0])
      names, = input_names
      counts = collections.Counter(names)

      # sum out unique contracted indices with a single reduce-sum
      uniques = [name for name in contracted_names if counts[name] == 1]
      operand, names = sum_uniques(operand, names, uniques)

      # for every repeated index, do a contraction against an identity matrix
      operand, names = sum_repeats(operand, names, counts, result_names)

    elif len(operand_indices) == 2:
      lhs, rhs = map(operands.pop, operand_indices)
      lhs_names, rhs_names = input_names

      # handle cases where one side of a contracting or batch dimension is 1
      # but its counterpart is not.
      lhs, lhs_names = filter_singleton_dims(lhs, lhs_names, shape(rhs),
                                             rhs_names)
      rhs, rhs_names = filter_singleton_dims(rhs, rhs_names, shape(lhs),
                                             lhs_names)

      lhs_counts = collections.Counter(lhs_names)
      rhs_counts = collections.Counter(rhs_names)

      # sum out unique contracted indices in lhs and rhs
      lhs_uniques = [name for name in contracted_names
                     if lhs_counts[name] == 1 and rhs_counts[name] == 0]
      lhs, lhs_names = sum_uniques(lhs, lhs_names, lhs_uniques)

      rhs_uniques = [name for name in contracted_names
                     if rhs_counts[name] == 1 and lhs_counts[name] == 0]
      rhs, rhs_names = sum_uniques(rhs, rhs_names, rhs_uniques)

      # for every repeated index, contract against an identity matrix
      lhs, lhs_names = sum_repeats(lhs, lhs_names, lhs_counts,
                                   result_names + rhs_names)
      rhs, rhs_names = sum_repeats(rhs, rhs_names, rhs_counts,
                                   result_names + lhs_names)

      # Earlier steps may have removed labels, so restrict the contracted
      # labels to those still present; labels shared by both sides that also
      # appear in the result are batch dimensions.
      lhs_or_rhs_names = set(lhs_names) | set(rhs_names)
      contracted_names = [x for x in contracted_names if x in lhs_or_rhs_names]
      lhs_and_rhs_names = set(lhs_names) & set(rhs_names)
      batch_names = [x for x in result_names if x in lhs_and_rhs_names]

      lhs_batch, rhs_batch = unzip2((lhs_names.find(n), rhs_names.find(n))
                                    for n in batch_names)

      # NOTE(mattjj): this can fail non-deterministically in python3, maybe
      # due to opt_einsum
      assert config.dynamic_shapes.value or all(
        name in lhs_names and name in rhs_names and
        lhs.shape[lhs_names.index(name)] == rhs.shape[rhs_names.index(name)]
        for name in contracted_names), (
          "Incompatible reduction dimensions: "
          f"lhs.shape={lhs.shape} lhs_names={lhs_names} "
          f"rhs.shape={rhs.shape} rhs_names={rhs_names}")

      # contract using dot_general
      batch_names_str = ''.join(batch_names)
      lhs_cont, rhs_cont = unzip2((lhs_names.index(n), rhs_names.index(n))
                                  for n in contracted_names)
      deleted_names = batch_names_str + ''.join(contracted_names)
      remaining_lhs_names = _removechars(lhs_names, deleted_names)
      remaining_rhs_names = _removechars(rhs_names, deleted_names)
      # Try both orders of lhs and rhs, in the hope that one of them means we
      # don't need an explicit transpose. opt_einsum likes to contract from
      # right to left, so we expect (rhs,lhs) to have the best chance of not
      # needing a transpose.
      names = batch_names_str + remaining_rhs_names + remaining_lhs_names
      if names == result_names:
        dimension_numbers = ((rhs_cont, lhs_cont), (rhs_batch, lhs_batch))
        operand = _dot_general(rhs, lhs, dimension_numbers, precision,
                               preferred_element_type=preferred_element_type)
      else:
        names = batch_names_str + remaining_lhs_names + remaining_rhs_names
        dimension_numbers = ((lhs_cont, rhs_cont), (lhs_batch, rhs_batch))
        operand = _dot_general(lhs, rhs, dimension_numbers, precision,
                               preferred_element_type=preferred_element_type)
    else:
      raise NotImplementedError  # if this is actually reachable, open an issue!

    # the resulting 'operand' with axis labels 'names' should be a permutation
    # of the desired result
    assert len(names) == len(result_names) == len(set(names))
    assert set(names) == set(result_names)
    if names != result_names:
      perm = tuple(names.index(name) for name in result_names)
      operand = lax.transpose(operand, perm)
    operands.append(operand)  # used in next iteration

  return lax_internal._convert_element_type(operands[0], preferred_element_type, output_weak_type)
@util.implements(np.inner, lax_description=_PRECISION_DOC,
                 extra_params=_DOT_PREFERRED_ELEMENT_TYPE_DESCRIPTION)
@partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True)
def inner(
    a: ArrayLike, b: ArrayLike, *, precision: PrecisionLike = None,
    preferred_element_type: DType | None = None,
) -> Array:
  # Inner product over the last axis of both inputs; if either input is a
  # scalar the result is a plain elementwise product.
  # TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error.
  util.check_arraylike("inner", a, b, emit_warning=True)
  if ndim(a) != 0 and ndim(b) != 0:
    return tensordot(a, b, (-1, -1), precision=precision,
                     preferred_element_type=preferred_element_type)
  # Scalar case: convert both sides with the requested dtype and multiply.
  lhs = asarray(a, dtype=preferred_element_type)
  rhs = asarray(b, dtype=preferred_element_type)
  return lhs * rhs
@util.implements(np.outer, skip_params=['out'])
@partial(jit, inline=True)
def outer(a: ArrayLike, b: ArrayLike, out: None = None) -> Array:
  # Outer product of the flattened inputs, after promoting them to a common
  # dtype.
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.outer is not supported.")
  # TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error.
  util.check_arraylike("outer", a, b, emit_warning=True)
  lhs, rhs = util.promote_dtypes(a, b)
  column = ravel(lhs)[:, None]
  row = ravel(rhs)[None, :]
  return column * row
@util.implements(np.cross)
@partial(jit, static_argnames=('axisa', 'axisb', 'axisc', 'axis'))
def cross(a, b, axisa: int = -1, axisb: int = -1, axisc: int = -1,
          axis: int | None = None):
  # TODO(jakevdp): NumPy 2.0 deprecates 2D inputs. Follow suit here.
  # TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error.
  util.check_arraylike("cross", a, b, emit_warning=True)
  # A single ``axis`` overrides all three per-array axis arguments.
  if axis is not None:
    axisa = axisb = axisc = axis
  # Bring the vector axis of each input to the end.
  a = moveaxis(a, axisa, -1)
  b = moveaxis(b, axisb, -1)

  if a.shape[-1] not in (2, 3) or b.shape[-1] not in (2, 3):
    raise ValueError("Dimension must be either 2 or 3 for cross product")

  # Two 2-vectors: only the z component exists, returned without a new axis.
  if a.shape[-1] == 2 and b.shape[-1] == 2:
    return a[..., 0] * b[..., 1] - a[..., 1] * b[..., 0]

  def components(vec):
    # Split into x, y, z; a 2-vector gets a zero z component.
    x = vec[..., 0]
    y = vec[..., 1]
    z = vec[..., 2] if vec.shape[-1] == 3 else zeros_like(x)
    return x, y, z

  ax, ay, az = components(a)
  bx, by, bz = components(b)
  stacked = array([ay * bz - az * by, az * bx - ax * bz, ax * by - ay * bx])
  # The components were stacked along axis 0; move them where requested.
  return moveaxis(stacked, 0, axisc)
@util.implements(np.kron)
@jit
def kron(a: ArrayLike, b: ArrayLike) -> Array:
  # Kronecker product computed by interleaving singleton axes into both
  # inputs, multiplying with broadcasting, and reshaping.
  # TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error.
  util.check_arraylike("kron", a, b, emit_warning=True)
  a, b = util.promote_dtypes(a, b)

  # Give both inputs the same rank by prepending singleton axes to the
  # lower-rank one.
  rank_gap = ndim(a) - ndim(b)
  if rank_gap < 0:
    a = expand_dims(a, range(-rank_gap))
  elif rank_gap > 0:
    b = expand_dims(b, range(rank_gap))

  # a gets new axes at odd positions, b at even positions, so that axes of a
  # and b alternate in the broadcast product.
  a_interleaved = expand_dims(a, range(1, 2 * ndim(a), 2))
  b_interleaved = expand_dims(b, range(0, 2 * ndim(b), 2))
  product = lax.mul(a_interleaved, b_interleaved)
  result_shape = tuple(np.multiply(shape(a), shape(b)))
  return reshape(product, result_shape)
@util.implements(np.vander)
@partial(jit, static_argnames=('N', 'increasing'))
def vander(
    x: ArrayLike, N: int | None = None, increasing: bool = False
) -> Array:
  # Vandermonde matrix: column j holds x raised to the j-th exponent, with
  # exponents descending from N - 1 unless ``increasing`` is set.
  util.check_arraylike("vander", x)
  x = asarray(x)
  if x.ndim != 1:
    raise ValueError("x must be a one-dimensional array")

  # The number of columns defaults to the length of x.
  if N is None:
    N = x.shape[0]
  else:
    N = core.concrete_or_error(
        operator.index, N, "'N' argument of jnp.vander()")
  if N < 0:
    raise ValueError("N must be nonnegative")

  exponents = lax.iota(x.dtype, N)
  if not increasing:
    exponents = lax.sub(_lax_const(exponents, N - 1), exponents)

  leading_axes = tuple(range(x.ndim))
  return ufuncs.power(x[..., None], expand_dims(exponents, leading_axes))
### Misc _ARGWHERE_DOC = """\ Because the size of the output of ``argwhere`` is data-dependent, the function is not typically compatible with JIT. The JAX version adds the optional ``size`` argument, which specifies the size of the leading dimension of the output - it must be specified statically for ``jnp.argwhere`` to be compiled with non-static operands. If ``size`` is specified, the indices of the first ``size`` True elements will be returned; if there are fewer nonzero elements than `size` indicates, the index arrays will be zero-padded. """
@util.implements(np.argwhere,
  lax_description=_dedent("""
    Because the size of the output of ``argwhere`` is data-dependent, the function is not
    typically compatible with JIT. The JAX version adds the optional ``size`` argument which
    must be specified statically for ``jnp.argwhere`` to be used within some of JAX's
    transformations."""),
  extra_params=_dedent("""
    size : int, optional
        If specified, the indices of the first ``size`` True elements will be returned. If there
        are fewer results than ``size`` indicates, the return value will be padded with ``fill_value``.
    fill_value : array_like, optional
        When ``size`` is specified and there are fewer than the indicated number of elements, the
        remaining elements will be filled with ``fill_value``, which defaults to zero."""))
def argwhere(
  a: ArrayLike,
  *,
  size: int | None = None,
  fill_value: ArrayLike | None = None,
) -> Array:
  # Stack the per-axis index arrays from ``nonzero`` into rows of coordinates.
  per_axis_indices = nonzero(a, size=size, fill_value=fill_value)
  coords = transpose(vstack(per_axis_indices))
  num_rows = coords.shape[0]
  if ndim(a) == 0:
    # Scalar input: keep the row count but expose zero coordinate columns.
    return coords[:0].reshape(num_rows, 0)
  return coords.reshape(num_rows, ndim(a))
@util.implements(np.argmax, skip_params=['out'])
def argmax(a: ArrayLike, axis: int | None = None, out: None = None,
           keepdims: bool | None = None) -> Array:
  # Validates the arguments eagerly, then defers to the jitted ``_argmax``.
  util.check_arraylike("argmax", a)
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.argmax is not supported.")
  # ``axis`` is normalised to a Python int because it is static downstream.
  static_axis = axis if axis is None else operator.index(axis)
  return _argmax(asarray(a), static_axis, keepdims=bool(keepdims))
@partial(jit, static_argnames=('axis', 'keepdims'), inline=True)
def _argmax(a: Array, axis: int | None = None, keepdims: bool = False) -> Array:
  """Jitted implementation of ``argmax``.

  With ``axis=None`` the input is flattened and reduced over its single axis.
  Raises ``ValueError`` if the reduced axis has length zero.
  """
  if axis is None:
    # All original axes are reinstated when keepdims is requested.
    restored_axes = list(range(ndim(a)))
    a = ravel(a)
    axis = 0
  else:
    restored_axes = [axis]
  if a.shape[axis] == 0:
    raise ValueError("attempt to get argmax of an empty sequence")
  index_dtype = dtypes.canonicalize_dtype(int_)
  result = lax.argmax(a, _canonicalize_axis(axis, a.ndim), index_dtype)
  if keepdims:
    return expand_dims(result, restored_axes)
  return result
@util.implements(np.argmin, skip_params=['out'])
def argmin(a: ArrayLike, axis: int | None = None, out: None = None,
           keepdims: bool | None = None) -> Array:
  # Validates the arguments eagerly, then defers to the jitted ``_argmin``.
  util.check_arraylike("argmin", a)
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.argmin is not supported.")
  # ``axis`` is normalised to a Python int because it is static downstream.
  static_axis = axis if axis is None else operator.index(axis)
  return _argmin(asarray(a), static_axis, keepdims=bool(keepdims))
@partial(jit, static_argnames=('axis', 'keepdims'), inline=True)
def _argmin(a: Array, axis: int | None = None, keepdims: bool = False) -> Array:
  """Jitted implementation of ``argmin``.

  With ``axis=None`` the input is flattened and reduced over its single axis.
  Raises ``ValueError`` if the reduced axis has length zero.
  """
  if axis is None:
    # All original axes are reinstated when keepdims is requested.
    restored_axes = list(range(ndim(a)))
    a = ravel(a)
    axis = 0
  else:
    restored_axes = [axis]
  if a.shape[axis] == 0:
    raise ValueError("attempt to get argmin of an empty sequence")
  index_dtype = dtypes.canonicalize_dtype(int_)
  result = lax.argmin(a, _canonicalize_axis(axis, a.ndim), index_dtype)
  if keepdims:
    return expand_dims(result, restored_axes)
  return result


# Template for the nanargmax/nanargmin notes; ``{}`` receives "max" or "min".
_NANARG_DOC = """\
Warning: jax.numpy.arg{} returns -1 for all-NaN slices and does not raise
an error.
"""
@util.implements(np.nanargmax, lax_description=_NANARG_DOC.format("max"), skip_params=['out'])
def nanargmax(
    a: ArrayLike,
    axis: int | None = None,
    out: None = None,
    keepdims: bool | None = None,
) -> Array:
  # Rejects ``out`` eagerly and forwards to the jitted ``_nanargmax``.
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.nanargmax is not supported.")
  static_axis = axis if axis is None else operator.index(axis)
  return _nanargmax(a, static_axis, keepdims=bool(keepdims))
@partial(jit, static_argnames=('axis', 'keepdims'))
def _nanargmax(a, axis: int | None = None, keepdims: bool = False):
  """Jitted implementation of ``nanargmax``.

  NaNs are replaced by ``-inf`` before taking ``argmax``; positions whose
  whole slice is NaN yield -1. Non-inexact dtypes go straight to ``argmax``.
  """
  util.check_arraylike("nanargmax", a)
  if not issubdtype(_dtype(a), inexact):
    return argmax(a, axis=axis, keepdims=keepdims)
  is_nan = ufuncs.isnan(a)
  without_nans = where(is_nan, -inf, a)
  best = argmax(without_nans, axis=axis, keepdims=keepdims)
  all_nan = reductions.all(is_nan, axis=axis, keepdims=keepdims)
  return where(all_nan, -1, best)
@util.implements(np.nanargmin, lax_description=_NANARG_DOC.format("min"), skip_params=['out'])
def nanargmin(
    a: ArrayLike,
    axis: int | None = None,
    out: None = None,
    keepdims: bool | None = None,
) -> Array:
  # Rejects ``out`` eagerly and forwards to the jitted ``_nanargmin``.
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.nanargmin is not supported.")
  static_axis = axis if axis is None else operator.index(axis)
  return _nanargmin(a, static_axis, keepdims=bool(keepdims))
@partial(jit, static_argnames=('axis', 'keepdims'))
def _nanargmin(a, axis: int | None = None, keepdims : bool = False):
  """Jitted implementation of ``nanargmin``.

  NaNs are replaced by ``inf`` before taking ``argmin``; positions whose
  whole slice is NaN yield -1. Non-inexact dtypes go straight to ``argmin``.
  """
  util.check_arraylike("nanargmin", a)
  if not issubdtype(_dtype(a), inexact):
    return argmin(a, axis=axis, keepdims=keepdims)
  is_nan = ufuncs.isnan(a)
  without_nans = where(is_nan, inf, a)
  best = argmin(without_nans, axis=axis, keepdims=keepdims)
  all_nan = reductions.all(is_nan, axis=axis, keepdims=keepdims)
  return where(all_nan, -1, best)
@util.implements(np.sort, extra_params="""
kind : deprecated; specify sort algorithm using stable=True or stable=False
order : not supported
stable : bool, default=True
    Specify whether to use a stable sort.
descending : bool, default=False
    Specify whether to do a descending sort.
""")
@partial(jit, static_argnames=('axis', 'kind', 'order', 'stable', 'descending'))
def sort(
    a: ArrayLike,
    axis: int | None = -1,
    kind: str | None = None,
    order: None = None, *,
    stable: bool = True,
    descending: bool = False,
) -> Array:
  # Sorts along ``axis`` (the flattened array when ``axis`` is None); a
  # descending sort is an ascending sort followed by a reversal.
  util.check_arraylike("sort", a)
  if kind is not None:
    # Deprecated 2024-01-05
    warnings.warn("The 'kind' argument to sort has no effect, and is deprecated. "
                  "Use stable=True or stable=False to specify sort stability.",
                  category=DeprecationWarning, stacklevel=2)
  if order is not None:
    raise ValueError("'order' argument to sort is not supported.")

  if axis is not None:
    arr = asarray(a)
  else:
    arr = ravel(a)
    axis = 0

  sort_dim = _canonicalize_axis(axis, arr.ndim)
  ascending = lax.sort(arr, dimension=sort_dim, is_stable=stable)
  if descending:
    return lax.rev(ascending, dimensions=[sort_dim])
  return ascending
@util.implements(np.sort_complex)
@jit
def sort_complex(a: ArrayLike) -> Array:
  # Returns a sorted copy of ``a`` converted to a complex dtype.
  #
  # NumPy's sort_complex sorts along the *last* axis (it calls ``ndarray.sort``
  # with its default axis), so the sort dimension here is -1 rather than 0.
  # The two coincide for 1-D inputs; for N-d inputs sorting along axis 0 would
  # disagree with NumPy.
  util.check_arraylike("sort_complex", a)
  sorted_arr = lax.sort(asarray(a), dimension=-1)
  return lax.convert_element_type(
      sorted_arr, dtypes.to_complex_dtype(sorted_arr.dtype))
@util.implements(np.lexsort)
@partial(jit, static_argnames=('axis',))
def lexsort(keys: Array | np.ndarray | Sequence[ArrayLike], axis: int = -1) -> Array:
  # Indirect sort on several keys: the keys are sorted together with an index
  # array, and the permuted indices are returned. The last key is the primary
  # one, hence the reversal before sorting.
  key_tuple = tuple(keys)
  # TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error.
  util.check_arraylike("lexsort", *key_tuple, emit_warning=True)
  key_arrays = tuple(map(asarray, key_tuple))
  if not key_arrays:
    raise TypeError("need sequence of keys with len > 0 in lexsort")
  distinct_shapes = {shape(key) for key in key_arrays}
  if len(distinct_shapes) > 1:
    raise ValueError("all keys need to be the same shape")

  first_key = key_arrays[0]
  if ndim(first_key) == 0:
    return array(0, dtype=dtypes.canonicalize_dtype(int_))

  axis = _canonicalize_axis(axis, ndim(first_key))
  # Axis lengths of 2**31 or more need 64-bit indices.
  index_dtype = int64 if first_key.shape[axis] >= (1 << 31) else int_
  positions = lax.broadcasted_iota(index_dtype, shape(first_key), axis)
  sorted_operands = lax.sort((*key_arrays[::-1], positions),
                             dimension=axis, num_keys=len(key_arrays))
  return sorted_operands[-1]
@util.implements(np.argsort, extra_params="""
kind : deprecated; specify sort algorithm using stable=True or stable=False
order : not supported
stable : bool, default=True
    Specify whether to use a stable sort.
descending : bool, default=False
    Specify whether to do a descending sort.
""")
@partial(jit, static_argnames=('axis', 'kind', 'order', 'stable', 'descending'))
def argsort(
    a: ArrayLike,
    axis: int | None = -1,
    kind: str | None = None,
    order: None = None,
    *, stable: bool = True,
    descending: bool = False,
) -> Array:
  # Returns the indices that sort ``a`` along ``axis`` (the flattened array
  # when ``axis`` is None), by sorting an index array alongside the values.
  util.check_arraylike("argsort", a)
  arr = asarray(a)
  if kind is not None:
    # Deprecated 2024-01-05
    warnings.warn("The 'kind' argument to argsort has no effect, and is deprecated. "
                  "Use stable=True or stable=False to specify sort stability.",
                  category=DeprecationWarning, stacklevel=2)
  if order is not None:
    raise ValueError("'order' argument to argsort is not supported.")
  if axis is None:
    arr = ravel(arr)
    axis = 0

  sort_dim = _canonicalize_axis(axis, arr.ndim)
  axis_len = arr.shape[sort_dim]
  # 64-bit indices are used when the axis length is not a constant or is at
  # least 2**31.
  needs_int64 = not core.is_constant_dim(axis_len) or axis_len >= (1 << 31)
  positions = lax.broadcasted_iota(int64 if needs_int64 else int_,
                                   arr.shape, sort_dim)

  # For stable descending sort, we reverse the array and indices to ensure that
  # duplicates remain in their original order when the final indices are reversed.
  # For non-stable descending sort, we can avoid these extra operations.
  if descending and stable:
    arr = lax.rev(arr, dimensions=[sort_dim])
    positions = lax.rev(positions, dimensions=[sort_dim])
  _, order_indices = lax.sort_key_val(arr, positions, dimension=sort_dim,
                                      is_stable=stable)
  if descending:
    return lax.rev(order_indices, dimensions=[sort_dim])
  return order_indices
@util.implements(np.partition, lax_description="""
The JAX version requires the ``kth`` argument to be a static integer rather than
a general array. This is implemented via two calls to :func:`jax.lax.top_k`. If
you're only accessing the top or bottom k values of the output, it may be more
efficient to call :func:`jax.lax.top_k` directly.

The JAX version differs from the NumPy version in the treatment of NaN entries;
NaNs which have the negative bit set are sorted to the beginning of the array.
""")
@partial(jit, static_argnames=['kth', 'axis'])
def partition(a: ArrayLike, kth: int, axis: int = -1) -> Array:
  # TODO(jakevdp): handle NaN values like numpy.
  util.check_arraylike("partition", a)
  arr = asarray(a)
  if issubdtype(arr.dtype, np.complexfloating):
    raise NotImplementedError("jnp.partition for complex dtype is not implemented.")
  axis = _canonicalize_axis(axis, arr.ndim)
  # ``kth`` is canonicalized like an axis index so negative values wrap.
  kth = _canonicalize_axis(kth, arr.shape[axis])

  # top_k works on the last axis, so move the partition axis there.
  arr = swapaxes(arr, axis, -1)
  axis_len = arr.shape[-1]
  # The kth + 1 smallest values are the largest values of the negated array.
  negated_smallest = lax.top_k(-arr, kth + 1)[0]
  lower = -negated_smallest
  upper = lax.top_k(arr, axis_len - kth - 1)[0]
  joined = lax.concatenate([lower, upper], dimension=arr.ndim - 1)
  return swapaxes(joined, -1, axis)
@util.implements(np.argpartition, lax_description="""
The JAX version requires the ``kth`` argument to be a static integer rather than
a general array. This is implemented via two calls to :func:`jax.lax.top_k`. If
you're only accessing the top or bottom k values of the output, it may be more
efficient to call :func:`jax.lax.top_k` directly.

The JAX version differs from the NumPy version in the treatment of NaN entries;
NaNs which have the negative bit set are sorted to the beginning of the array.
""")
@partial(jit, static_argnames=['kth', 'axis'])
def argpartition(a: ArrayLike, kth: int, axis: int = -1) -> Array:
  # Returns indices that partition ``a`` along ``axis`` around position ``kth``.
  # Raises NotImplementedError for complex inputs.
  # TODO(jakevdp): handle NaN values like numpy.
  # The label passed here appears in the error message for non-array input,
  # so it must name this function rather than ``partition``.
  util.check_arraylike("argpartition", a)
  arr = asarray(a)
  if issubdtype(arr.dtype, np.complexfloating):
    raise NotImplementedError("jnp.argpartition for complex dtype is not implemented.")
  axis = _canonicalize_axis(axis, arr.ndim)
  # ``kth`` is canonicalized like an axis index so negative values wrap.
  kth = _canonicalize_axis(kth, arr.shape[axis])

  # top_k works on the last axis, so move the partition axis there.
  arr = swapaxes(arr, axis, -1)
  # Indices of the kth + 1 smallest entries (largest of the negated array).
  bottom_ind = lax.top_k(-arr, kth + 1)[1]

  # To avoid issues with duplicate values, we compute the top indices via a proxy
  set_to_zero = lambda a, i: a.at[i].set(0)
  # vmap once per leading axis so the update applies along the last axis only.
  for _ in range(arr.ndim - 1):
    set_to_zero = jax.vmap(set_to_zero)
  # The proxy is 1 everywhere except at the positions already selected, so its
  # top entries are exactly the positions not yet chosen.
  proxy = set_to_zero(ones(arr.shape), bottom_ind)
  top_ind = lax.top_k(proxy, arr.shape[-1] - kth - 1)[1]
  out = lax.concatenate([bottom_ind, top_ind], dimension=arr.ndim - 1)
  return swapaxes(out, -1, axis)
@partial(jit, static_argnums=(2,))
def _roll_dynamic(a: Array, shift: Array, axis: Sequence[int]) -> Array:
  """Roll ``a`` by traced shift amounts along the static axes ``axis``.

  Each roll concatenates the array with itself along the axis and takes a
  dynamic slice of the original length.
  """
  common_shape = lax.broadcast_shapes(shift.shape, np.shape(axis))
  if len(common_shape) != 1:
    msg = "'shift' and 'axis' arguments to roll must be scalars or 1D arrays"
    raise ValueError(msg)

  shifts = broadcast_to(shift, common_shape)
  axes = np.broadcast_to(axis, common_shape)
  for amount, ax in zip(shifts, axes):
    axis_len = array(a.shape[ax], dtype=np.int32)
    # Reduce the shift modulo the axis length; max(..., 1) avoids a zero
    # divisor for empty axes.
    amount = ufuncs.remainder(lax.convert_element_type(amount, np.int32),
                              lax.max(axis_len, np.int32(1)))
    doubled = lax.concatenate((a, a), ax)
    a = lax.dynamic_slice_in_dim(doubled, axis_len - amount, a.shape[ax], axis=ax)
  return a


@partial(jit, static_argnums=(1, 2))
def _roll_static(a: Array, shift: Sequence[int], axis: Sequence[int]) -> Array:
  """Roll ``a`` by static shift amounts along the static axes ``axis``.

  Each roll splits the array at the wrap point and concatenates the two
  pieces in swapped order. Axes of length zero are left untouched.
  """
  for ax, amount in zip(*np.broadcast_arrays(axis, shift)):
    axis_len = a.shape[ax]
    if axis_len != 0:
      split = (-amount) % axis_len
      tail = lax.slice_in_dim(a, split, axis_len, axis=ax)
      head = lax.slice_in_dim(a, 0, split, axis=ax)
      a = lax.concatenate([tail, head], dimension=ax)
  return a
@util.implements(np.roll)
def roll(a: ArrayLike, shift: ArrayLike | Sequence[int],
         axis: int | Sequence[int] | None = None) -> Array:
  # Dispatches to the static implementation when ``shift`` can be turned into
  # a tuple of Python ints, and to the dynamic one otherwise.
  util.check_arraylike("roll", a)
  arr = asarray(a)
  if axis is None:
    # Roll the flattened array, then restore the original shape.
    flat_rolled = roll(arr.ravel(), shift, 0)
    return flat_rolled.reshape(arr.shape)

  axes = tuple(_canonicalize_axis(ax, arr.ndim)
               for ax in _ensure_index_tuple(axis))
  try:
    static_shift = _ensure_index_tuple(shift)
  except TypeError:
    return _roll_dynamic(arr, asarray(shift), axes)
  return _roll_static(arr, static_shift, axes)
@util.implements(np.rollaxis, lax_description=_ARRAY_VIEW_DOC)
@partial(jit, static_argnames=('axis', 'start'))
def rollaxis(a: ArrayLike, axis: int, start: int = 0) -> Array:
  # Translates the rollaxis convention for ``start`` into a destination index
  # for ``moveaxis``.
  util.check_arraylike("rollaxis", a)
  start = core.concrete_or_error(operator.index, start, "'start' argument of jnp.rollaxis()")
  rank = ndim(a)
  axis = _canonicalize_axis(axis, rank)
  if start < -rank or start > rank:
    raise ValueError(f"{start=} must satisfy {-rank}<=start<={rank}")
  destination = start + rank if start < 0 else start
  # Removing ``axis`` shifts every later position down by one.
  if destination > axis:
    destination -= 1
  return moveaxis(a, axis, destination)
@util.implements(np.packbits)
@partial(jit, static_argnames=('axis', 'bitorder'))
def packbits(
    a: ArrayLike, axis: int | None = None, bitorder: str = "big"
) -> Array:
  # Packs groups of eight truth values along ``axis`` (the flattened array
  # when ``axis`` is None) into uint8 bytes.
  util.check_arraylike("packbits", a)
  arr = asarray(a)
  if not issubdtype(arr.dtype, integer) and not issubdtype(arr.dtype, bool_):
    raise TypeError('Expected an input array of integer or boolean data type')
  if bitorder not in ['little', 'big']:
    raise ValueError("'order' must be either 'little' or 'big'")

  # Any value greater than zero counts as a set bit.
  arr = lax.gt(arr, _lax_const(a, 0)).astype('uint8')
  shifts = arange(8, dtype='uint8')
  if bitorder == 'big':
    shifts = shifts[::-1]

  if axis is None:
    arr = ravel(arr)
    axis = 0
  arr = swapaxes(arr, axis, -1)

  # Zero-pad the last axis up to a multiple of eight.
  leftover = arr.shape[-1] % 8
  if leftover:
    padding = (arr.ndim - 1) * [(0, 0, 0)] + [(0, 8 - leftover, 0)]
    arr = lax.pad(arr, np.uint8(0), padding)

  # Group into bytes, shift each bit into place, and add the bits up.
  grouped = arr.reshape(arr.shape[:-1] + (arr.shape[-1] // 8, 8))
  shifts = expand_dims(shifts, tuple(range(grouped.ndim - 1)))
  packed = (grouped << shifts).sum(-1).astype('uint8')
  return swapaxes(packed, axis, -1)
@util.implements(np.unpackbits)
@partial(jit, static_argnames=('axis', 'count', 'bitorder'))
def unpackbits(
    a: ArrayLike,
    axis: int | None = None,
    count: int | None = None,
    bitorder: str = "big",
) -> Array:
  # Expands each uint8 along ``axis`` (the flattened array when ``axis`` is
  # None) into eight 0/1 values, optionally trimmed or zero-padded to
  # ``count`` entries.
  util.check_arraylike("unpackbits", a)
  arr = asarray(a)
  if _dtype(a) != uint8:
    raise TypeError("Expected an input array of unsigned byte data type")
  if bitorder not in ['little', 'big']:
    raise ValueError("'order' must be either 'little' or 'big'")

  masks = asarray(1) << arange(8, dtype='uint8')
  if bitorder == 'big':
    masks = masks[::-1]

  if axis is None:
    arr = ravel(arr)
    axis = 0
  arr = swapaxes(arr, axis, -1)

  # Test every byte against each single-bit mask on a new trailing axis, then
  # merge that axis into the last one.
  masks = expand_dims(masks, tuple(range(arr.ndim)))
  bit_set = (arr[..., None] & masks) > 0
  unpacked = bit_set.astype('uint8')
  unpacked = unpacked.reshape(unpacked.shape[:-2] + (-1,))

  if count is not None:
    available = unpacked.shape[-1]
    if count > available:
      pad_widths = [(0, 0)] * (unpacked.ndim - 1) + [(0, count - available)]
      unpacked = pad(unpacked, pad_widths)
    else:
      unpacked = unpacked[..., :count]
  return swapaxes(unpacked, axis, -1)
@util.implements(np.take, skip_params=['out'],
                 lax_description="""
By default, JAX assumes that all indices are in-bounds. Alternative out-of-bound
index semantics can be specified via the ``mode`` parameter (see below).
""",
                 extra_params="""
mode : string, default="fill"
    Out-of-bounds indexing mode. The default mode="fill" returns invalid values
    (e.g. NaN) for out-of bounds indices (see also ``fill_value`` below).
    For more discussion of mode options, see :attr:`jax.numpy.ndarray.at`.
fill_value : optional
    The fill value to return for out-of-bounds slices when mode is 'fill'. Ignored
    otherwise. Defaults to NaN for inexact types, the largest negative value for
    signed types, the largest positive value for unsigned types, and True for booleans.
unique_indices : bool, default=False
    If True, the implementation will assume that the indices are unique,
    which can result in more efficient execution on some backends.
indices_are_sorted : bool, default=False
    If True, the implementation will assume that the indices are sorted in
    ascending order, which can lead to more efficient execution on some backends.
""")
def take(
    a: ArrayLike,
    indices: ArrayLike,
    axis: int | None = None,
    out: None = None,
    mode: str | None = None,
    unique_indices: bool = False,
    indices_are_sorted: bool = False,
    fill_value: ArrayLike | None = None,
) -> Array:
  # Normalises ``axis`` to a Python int (it is static in ``_take``) and
  # forwards everything else unchanged.
  static_axis = axis if axis is None else operator.index(axis)
  return _take(a, indices, static_axis, out, mode,
               unique_indices=unique_indices,
               indices_are_sorted=indices_are_sorted,
               fill_value=fill_value)
@partial(jit, static_argnames=('axis', 'mode', 'unique_indices', 'indices_are_sorted', 'fill_value'))
def _take(a, indices, axis: int | None = None, out=None, mode=None,
          unique_indices=False, indices_are_sorted=False, fill_value=None):
  """Jitted implementation behind :func:`take`.

  Args:
    a: array-like to take values from.
    indices: array-like of positions along ``axis``.
    axis: axis to index along, or ``None`` to index into the flattened array.
    out: unsupported; must be ``None``.
    mode: ``None`` or ``"fill"`` (negative indices are wrapped once, remaining
      out-of-bounds entries are filled), ``"wrap"`` (indices taken modulo the
      axis size) or ``"clip"``. ``"raise"`` is rejected.
    unique_indices: forwarded to ``lax.gather``.
    indices_are_sorted: forwarded to ``lax.gather``.
    fill_value: forwarded to ``lax.gather``.

  Returns:
    An array of shape ``a.shape[:axis] + indices.shape + a.shape[axis + 1:]``
    (computed on the flattened ``a`` when ``axis`` is ``None``).

  Raises:
    NotImplementedError: if ``out`` is given or ``mode == "raise"``.
    ValueError: if ``mode`` is not one of the recognised strings.
    IndexError: if a non-empty ``indices`` is applied to an axis of length 0.
  """
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.take is not supported.")
  util.check_arraylike("take", a, indices)
  a = asarray(a)
  indices = asarray(indices)

  if axis is None:
    a = ravel(a)
    axis_idx = 0
  else:
    axis_idx = _canonicalize_axis(axis, ndim(a))

  if mode is None or mode == "fill":
    gather_mode = lax.GatherScatterMode.FILL_OR_DROP
    # lax.gather() does not support negative indices, so we wrap them here
    indices = where(indices < 0, indices + a.shape[axis_idx], indices)
  elif mode == "raise":
    # TODO(phawkins): we have no way to report out of bounds errors yet.
    raise NotImplementedError("The 'raise' mode to jnp.take is not supported.")
  elif mode == "wrap":
    indices = ufuncs.mod(indices, _lax_const(indices, a.shape[axis_idx]))
    gather_mode = lax.GatherScatterMode.PROMISE_IN_BOUNDS
  elif mode == "clip":
    gather_mode = lax.GatherScatterMode.CLIP
  else:
    raise ValueError(f"Invalid mode '{mode}' for np.take")

  index_dims = len(shape(indices))
  slice_sizes = list(shape(a))
  if slice_sizes[axis_idx] == 0 and indices.size != 0:
    raise IndexError("Cannot do a non-empty jnp.take() from an empty axis.")

  if indices.size == 0:
    # Nothing to gather. This branch also covers an empty indexed axis: the
    # result must still take its shape from `indices`, so `a` itself cannot be
    # returned (its shape only matches when indices.shape == (0,)).
    out_shape = (slice_sizes[:axis_idx] + list(indices.shape) +
                 slice_sizes[axis_idx + 1:])
    return full_like(a, 0, shape=out_shape)

  # Each gathered slice spans every axis of `a` fully except the indexed one,
  # where it has length 1 and is collapsed away.
  slice_sizes[axis_idx] = 1
  dnums = lax.GatherDimensionNumbers(
    offset_dims=tuple(
      list(range(axis_idx)) +
      list(range(axis_idx + index_dims, len(a.shape) + index_dims - 1))),
    collapsed_slice_dims=(axis_idx,),
    start_index_map=(axis_idx,))
  return lax.gather(a, indices[..., None], dimension_numbers=dnums,
                    slice_sizes=tuple(slice_sizes),
                    mode=gather_mode, unique_indices=unique_indices,
                    indices_are_sorted=indices_are_sorted, fill_value=fill_value)


def _normalize_index(index, axis_size):
  """Normalizes an index value in the range [-N, N) to the range [0, N).

  Args:
    index: a Python/NumPy integer or an integer array of indices.
    axis_size: size ``N`` of the axis being indexed; may be a non-constant
      dimension, in which case it is converted to a value of ``index``'s dtype.

  Returns:
    ``index`` unchanged if its dtype is unsigned; otherwise ``index`` with
    ``axis_size`` added wherever it is negative.
  """
  # Unsigned indices can never be negative, so there is nothing to wrap.
  if issubdtype(_dtype(index), np.unsignedinteger):
    return index
  if core.is_constant_dim(axis_size):
    axis_size_val = _lax_const(index, axis_size)
  else:
    axis_size_val = lax.convert_element_type(core.dimension_as_value(axis_size),
                                             _dtype(index))
  if isinstance(index, (int, np.integer)):
    # Concrete scalar: decide in Python rather than emitting a select.
    return lax.add(index, axis_size_val) if index < 0 else index
  else:
    return lax.select(index < 0, lax.add(index, axis_size_val), index)


# Extra description passed as `lax_description` to the documentation wrapper
# of `take_along_axis`.
TAKE_ALONG_AXIS_DOC = """
Unlike :func:`numpy.take_along_axis`, :func:`jax.numpy.take_along_axis` takes
an optional ``mode`` parameter controlling how out-of-bounds indices should be
handled. By default, out-of-bounds indices yield invalid values (e.g., ``NaN``).
See :attr:`jax.numpy.ndarray.at` for further discussion of out-of-bounds
indexing in JAX.
"""
@util.implements(np.take_along_axis, update_doc=False,
        lax_description=TAKE_ALONG_AXIS_DOC)
@partial(jit, static_argnames=('axis', 'mode'))
def take_along_axis(
    arr: ArrayLike,
    indices: ArrayLike,
    axis: int | None,
    mode: str | lax.GatherScatterMode | None = None,
) -> Array:
  # Gathers values of `arr` along `axis` at the positions given by `indices`,
  # broadcasting the two against each other on every other axis.
  #
  # arr:     array-like to read from.
  # indices: integer array-like; 1D when axis is None, otherwise of the same
  #          rank as `arr`.
  # axis:    axis to gather along, or None to index the flattened `arr`.
  # mode:    out-of-bounds mode handed to lax.gather; None means "fill".
  #
  # Raises TypeError for non-integer indices and ValueError for rank
  # mismatches.
  util.check_arraylike("take_along_axis", arr, indices)
  a = asarray(arr)
  index_dtype = dtypes.dtype(indices)
  idx_shape = shape(indices)
  if not dtypes.issubdtype(index_dtype, integer):
    raise TypeError("take_along_axis indices must be of integer type, got "
            f"{index_dtype}")
  if axis is None:
    if ndim(indices) != 1:
      msg = "take_along_axis indices must be 1D if axis=None, got shape {}"
      raise ValueError(msg.format(idx_shape))
    # Forward `mode` so the caller's out-of-bounds semantics also apply to the
    # flattened case instead of silently reverting to the default.
    return take_along_axis(a.ravel(), indices, 0, mode=mode)
  rank = ndim(arr)
  if rank != ndim(indices):
    msg = "indices and arr must have the same number of dimensions; {} vs. {}"
    raise ValueError(msg.format(ndim(indices), a.ndim))
  axis_int = _canonicalize_axis(axis, rank)

  def replace(tup, val):
    # Copy of `tup` with the entry at the gather axis replaced by `val`.
    lst = list(tup)
    lst[axis_int] = val
    return tuple(lst)

  # 32-bit indices suffice unless some dimension is non-constant or >= 2**31.
  use_64bit_index = any(not core.is_constant_dim(d) or d >= (1 << 31) for d in a.shape)
  index_dtype = dtype(int64 if use_64bit_index else int32)
  indices = lax.convert_element_type(indices, index_dtype)

  axis_size = a.shape[axis_int]
  arr_shape = replace(a.shape, 1)
  out_shape = lax.broadcast_shapes(idx_shape, arr_shape)
  if axis_size == 0:
    return zeros(out_shape, a.dtype)

  # Output dimensions that are driven by an explicit index component: the
  # gather axis itself plus every axis on which `indices` is not of size 1.
  index_dims = [i for i, idx in enumerate(idx_shape) if i == axis_int or not core.definitely_equal(idx, 1)]

  gather_index_shape = tuple(np.array(out_shape)[index_dims]) + (1,)
  gather_indices = []
  slice_sizes = []
  offset_dims = []
  start_index_map = []
  collapsed_slice_dims = []
  # `j` counts the index components emitted so far.
  j = 0
  for i in range(rank):
    if i == axis_int:
      indices = _normalize_index(indices, axis_size)
      gather_indices.append(lax.reshape(indices, gather_index_shape))
      slice_sizes.append(1)
      start_index_map.append(i)
      collapsed_slice_dims.append(i)
      j += 1
    elif core.definitely_equal(idx_shape[i], 1):
      # If idx_shape[i] == 1, we can just take the entirety of the arr's axis
      # and avoid forming an iota index.
      offset_dims.append(i)
      slice_sizes.append(arr_shape[i])
    elif core.definitely_equal(arr_shape[i], 1):
      # If the array dimension is 1 but the index dimension is not, we
      # broadcast the array dimension to the index dimension by repeatedly
      # gathering the first element.
      gather_indices.append(zeros(gather_index_shape, dtype=index_dtype))
      slice_sizes.append(1)
      start_index_map.append(i)
      collapsed_slice_dims.append(i)
      j += 1
    else:
      # Otherwise, idx_shape[i] == arr_shape[i]. Use an iota index so
      # corresponding elements of array and index are gathered.
      # TODO(mattjj): next line needs updating for dynamic shapes
      iota = lax.broadcasted_iota(index_dtype, gather_index_shape, j)
      gather_indices.append(iota)
      slice_sizes.append(1)
      start_index_map.append(i)
      collapsed_slice_dims.append(i)
      j += 1

  # Stack the per-axis index components along the trailing axis.
  gather_indices_arr = lax.concatenate(gather_indices, dimension=j)
  dnums = lax.GatherDimensionNumbers(
    offset_dims=tuple(offset_dims),
    collapsed_slice_dims=tuple(collapsed_slice_dims),
    start_index_map=tuple(start_index_map))
  return lax.gather(a, gather_indices_arr, dnums, tuple(slice_sizes),
                    mode="fill" if mode is None else mode)
### Indexing def _is_integer_index(idx: Any) -> bool: return isinstance(idx, (int, np.integer)) and not isinstance(idx, (bool, np.bool_)) def _is_simple_reverse_slice(idx: Any) -> bool: return (isinstance(idx, slice) and idx.start is idx.stop is None and isinstance(idx.step, int) and idx.step == -1) def _is_valid_integer_index_for_slice(idx, size, mode): if size == 0: return False if _is_integer_index(idx): return -size <= idx < size try: shape, dtype = np.shape(idx), _dtype(idx) except: return False if shape == () and np.issubdtype(dtype, np.integer): # For dynamic integer indices, dynamic_slice semantics require index clipping: return mode in [None, 'promise_inbounds', 'clip'] return False def _is_contiguous_slice(idx): return (isinstance(idx, slice) and (idx.start is None or _is_integer_index(idx.start)) and (idx.stop is None or _is_integer_index(idx.stop)) and (idx.step is None or (_is_integer_index(idx.step) and idx.step == 1))) def _attempt_rewriting_take_via_slice(arr: Array, idx: Any, mode: str | None) -> Array | None: # attempt to compute _rewriting_take via lax.slice(); return None if not possible. 
idx = idx if isinstance(idx, tuple) else (idx,) if not all(isinstance(i, int) for i in arr.shape): return None if len(idx) > arr.ndim: return None if any(i is None for i in idx): return None # TODO(jakevdp): handle newaxis case # For symbolic dimensions fallback to gather if any(core.is_symbolic_dim(elt) for i in idx if isinstance(i, slice) for elt in (i.start, i.stop, i.step)): return None simple_revs = {i for i, ind in enumerate(idx) if _is_simple_reverse_slice(ind)} int_indices = {i for i, (ind, size) in enumerate(zip(idx, arr.shape)) if _is_valid_integer_index_for_slice(ind, size, mode)} contiguous_slices = {i for i, ind in enumerate(idx) if _is_contiguous_slice(ind)} # For sharded inputs, indexing (like x[0]) and partial slices (like x[:2] as # opposed to x[:]) lead to incorrect sharding semantics when computed via # dynamic_slice, so we fall back to gather. # TODO(yashkatariya): fix dynamic_slice with sharding is_sharded = (isinstance(arr, ArrayImpl) and not dispatch.is_single_device_sharding(arr.sharding)) has_partial_slices = any(idx[i].indices(arr.shape[i]) != (0, arr.shape[i], 1) for i in contiguous_slices) if is_sharded and (int_indices or has_partial_slices): return None if len(simple_revs) + len(int_indices) + len(contiguous_slices) != len(idx): return None if simple_revs: arr = lax.rev(arr, tuple(simple_revs)) idx = tuple(slice(None) if i in simple_revs else ind for i, ind in enumerate(idx)) contiguous_slices |= simple_revs if not (int_indices or has_partial_slices): return arr idx += (arr.ndim - len(idx)) * (slice(None),) start_indices: Sequence[ArrayLike] = [] slice_sizes: Sequence[int] = [] for ind, size in safe_zip(idx, arr.shape): if isinstance(ind, slice): start, stop, step = ind.indices(size) assert step == 1 # checked above start_indices.append(start) slice_sizes.append(max(0, stop - start)) else: assert np.issubdtype(_dtype(ind), np.integer) # checked above assert np.shape(ind) == () # checked above start_indices.append(ind) 
slice_sizes.append(1) # Try to use static slicing when possible. if all(isinstance(i, (int, np.integer)) and i >= 0 for i in start_indices): int_start_indices = [int(i) for i in start_indices] # type: ignore int_limit_indices = [i + s for i, s in zip(int_start_indices, slice_sizes)] arr = lax.slice( arr, start_indices=int_start_indices, limit_indices=int_limit_indices) else: # We must be careful with dtypes because dynamic_slice requires all # start indices to have matching types. if len(start_indices) > 1: start_indices = util.promote_dtypes(*start_indices) arr = lax.dynamic_slice( arr, start_indices=start_indices, slice_sizes=slice_sizes) if int_indices: arr = lax.squeeze(arr, tuple(int_indices)) return arr def _rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=False, mode=None, fill_value=None): # Computes arr[idx]. # All supported cases of indexing can be implemented as an XLA gather, # followed by an optional reverse and broadcast_in_dim. # For simplicity of generated primitives, we call lax.dynamic_slice in the # simplest cases: i.e. non-dynamic arrays indexed with integers and slices. if (result := _attempt_rewriting_take_via_slice(arr, idx, mode)) is not None: return result # TODO(mattjj,dougalm): expand dynamic shape indexing support if config.dynamic_shapes.value and arr.ndim > 0: try: aval = core.get_aval(idx) except: pass else: if (isinstance(aval, core.DShapedArray) and aval.shape == () and dtypes.issubdtype(aval.dtype, np.integer) and not dtypes.issubdtype(aval.dtype, dtypes.bool_) and isinstance(arr.shape[0], int)): return lax.dynamic_index_in_dim(arr, idx, keepdims=False) treedef, static_idx, dynamic_idx = _split_index_for_jit(idx, arr.shape) return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted, unique_indices, mode, fill_value) # TODO(phawkins): re-enable jit after fixing excessive recompilation for # slice indexes (e.g., slice(0, 5, None), slice(10, 15, None), etc.). 
# @partial(jit, static_argnums=(1, 2)) def _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted, unique_indices, mode, fill_value): idx = _merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx) indexer = _index_to_gather(shape(arr), idx) # shared with _scatter_update y = arr if fill_value is not None: core.concrete_or_error(None, fill_value, "fill_value argument to indexed get()") if np.ndim(fill_value) != 0: raise ValueError("fill_value argument to indexed get() must be a scalar") if isinstance(fill_value, np.ndarray): fill_value = fill_value.item() if indexer.scalar_bool_dims: y = lax.expand_dims(y, indexer.scalar_bool_dims) # Avoid calling gather if the slice shape is empty, both as a fast path and to # handle cases like zeros(0)[array([], int32)]. if core.is_empty_shape(indexer.slice_shape): return zeros_like(y, shape=indexer.slice_shape) # We avoid generating a gather when indexer.gather_indices.size is empty. if not core.is_empty_shape(indexer.gather_indices.shape): y = lax.gather( y, indexer.gather_indices, indexer.dnums, indexer.gather_slice_shape, unique_indices=unique_indices or indexer.unique_indices, indices_are_sorted=indices_are_sorted or indexer.indices_are_sorted, mode=mode, fill_value=fill_value) # Reverses axes with negative strides. if indexer.reversed_y_dims: y = lax.rev(y, indexer.reversed_y_dims) # This adds np.newaxis/None dimensions. return expand_dims(y, indexer.newaxis_dims) class _Indexer(NamedTuple): # The expected shape of the slice output. slice_shape: Sequence[int] # The slice shape to pass to lax.gather(). gather_slice_shape: Sequence[int] # The gather indices to use. gather_indices: ArrayLike # A GatherDimensionNumbers object describing the gather to perform. dnums: lax.GatherDimensionNumbers # Are the gather_indices known to be non-overlapping and/or sorted? # (In practice, these translate to "there no advanced indices", because # only advanced indices could lead to index repetition.) 
unique_indices: bool indices_are_sorted: bool # Slice dimensions that have negative strides, and so must be reversed after # the gather. reversed_y_dims: Sequence[int] # Keep track of any axes created by `newaxis`. These must be inserted for # gathers and eliminated for scatters. newaxis_dims: Sequence[int] # Keep track of dimensions with scalar bool indices. These must be inserted # for gathers before performing other index operations. scalar_bool_dims: Sequence[int] def _split_index_for_jit(idx, shape): """Splits indices into necessarily-static and dynamic parts. Used to pass indices into `jit`-ted function. """ # Convert list indices to tuples in cases (deprecated by NumPy.) idx = _eliminate_deprecated_list_indexing(idx) if any(isinstance(i, str) for i in idx): raise TypeError(f"JAX does not support string indexing; got {idx=}") # Expand any (concrete) boolean indices. We can then use advanced integer # indexing logic to handle them. idx = _expand_bool_indices(idx, shape) leaves, treedef = tree_flatten(idx) dynamic = [None] * len(leaves) static = [None] * len(leaves) for i, x in enumerate(leaves): if x is Ellipsis: static[i] = x elif isinstance(x, slice): # slice objects aren't hashable. static[i] = (x.start, x.stop, x.step) else: dynamic[i] = x return treedef, tuple(static), dynamic def _merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx): """Recombines indices that were split by _split_index_for_jit.""" idx = [] for s, d in zip(static_idx, dynamic_idx): if d is not None: idx.append(d) elif isinstance(s, tuple): idx.append(slice(s[0], s[1], s[2])) else: idx.append(s) return treedef.unflatten(idx) def _int(aval): return not aval.shape and issubdtype(aval.dtype, integer) def _index_to_gather(x_shape: Sequence[int], idx: Sequence[Any], normalize_indices: bool = True) -> _Indexer: # Remove ellipses and add trailing slice(None)s. 
idx = _canonicalize_tuple_index(len(x_shape), idx) # Check for scalar boolean indexing: this requires inserting extra dimensions # before performing the rest of the logic. scalar_bool_dims: Sequence[int] = [n for n, i in enumerate(idx) if isinstance(i, bool)] if scalar_bool_dims: idx = tuple(np.arange(int(i)) if isinstance(i, bool) else i