# Copyright 2022 The JAX Authors.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import builtins
from collections.abc import Sequence
from functools import partial
import math
import operator
from typing import overload, Any, Callable, Literal, Protocol, Union
import warnings

import numpy as np

from jax import lax
from jax._src import api
from jax._src import core, config
from jax._src import dtypes
from jax._src.numpy import ufuncs
from jax._src.numpy.util import (
    _broadcast_to, check_arraylike, _complex_elem_type,
    promote_dtypes_inexact, promote_dtypes_numeric, _where, implements)
from jax._src.lax import lax as lax_internal
from jax._src.typing import Array, ArrayLike, DType, DTypeLike
from jax._src.util import (
    canonicalize_axis as _canonicalize_axis, maybe_named_axis,
    NumpyComplexWarning)

# Keep a handle on the Python builtin: this module defines its own NumPy-style
# ``all`` reduction further down, which shadows the builtin name.
_all = builtins.all
# Short alias for the private lax helper used throughout this module
# (called as ``_lax_const(a, value)``).
_lax_const = lax_internal._const

# Type accepted by the ``axis`` argument of the reductions in this module.
Axis = Union[int, Sequence[int], None]

def _isscalar(element: Any) -> bool:
  """Return whether ``element`` is a Python scalar or a NumPy scalar.

  Objects that expose ``__jax_array__`` are converted through that hook
  before the check is made.
  """
  value = element.__jax_array__() if hasattr(element, '__jax_array__') else element
  python_scalar = dtypes.is_python_scalar(value)
  if python_scalar:
    return python_scalar
  return np.isscalar(value)

def _moveaxis(a: ArrayLike, source: int, destination: int) -> Array:
  """Move a single axis of ``a`` from ``source`` to ``destination``.

  A simplified, module-local stand-in for ``jnp.moveaxis()``; both axis
  arguments may be negative and are canonicalized against the rank of ``a``.
  """
  check_arraylike("moveaxis", a)
  arr = lax_internal.asarray(a)
  rank = np.ndim(arr)
  src = _canonicalize_axis(source, rank)
  dst = _canonicalize_axis(destination, rank)
  # All axes except the moved one, in their original order ...
  remaining = [ax for ax in range(rank) if ax != src]
  # ... with the moved axis spliced in at its destination.
  permutation = remaining[:dst] + [src] + remaining[dst:]
  return lax.transpose(arr, permutation)

def _upcast_f16(dtype: DTypeLike) -> DType:
  """Map float16 / bfloat16 to float32; return any other dtype unchanged."""
  resolved = np.dtype(dtype)
  half_precision = (np.float16, dtypes.bfloat16)
  return np.dtype('float32') if resolved in half_precision else resolved

# Signature of the binary operations (e.g. ``lax.add``, ``lax.max``) that are
# passed to ``_reduction`` as ``op`` / ``bool_op``.
ReductionOp = Callable[[Any, Any], Any]

def _reduction(a: ArrayLike, name: str, np_fun: Any, op: ReductionOp, init_val: ArrayLike,
               *, has_identity: bool = True,
               preproc: Callable[[ArrayLike], ArrayLike] | None = None,
               bool_op: ReductionOp | None = None,
               upcast_f16_for_computation: bool = False,
               axis: Axis = None, dtype: DTypeLike | None = None, out: None = None,
               keepdims: bool = False, initial: ArrayLike | None = None,
               where_: ArrayLike | None = None,
               parallel_reduce: Callable[..., Array] | None = None,
               promote_integers: bool = False) -> Array:
  """Shared implementation behind the NumPy-style reductions in this module.

  Args:
    a: array-like input to reduce.
    name: name of the public function; used in error messages.
    np_fun: the corresponding NumPy function (not used by the body).
    op: binary operation used to combine elements.
    init_val: identity value for ``op``; passed to ``lax.reduce`` and used to
      fill entries masked out by ``where_``.
    has_identity: if False, ``initial`` is mandatory whenever ``where_`` is
      given, and reducing over a zero-size axis without ``initial`` is an error.
    preproc: optional function applied to ``a`` before reducing.
    bool_op: operation used in place of ``op`` when the computation dtype is
      bool. Defaults to ``op``.
    upcast_f16_for_computation: if True and the result dtype is inexact, the
      computation dtype is chosen by ``_upcast_f16`` and the result is cast
      back at the end.
    axis: axis or axes to reduce over; must be known statically.
    dtype: optional result dtype.
    out: must be None; accepted only for NumPy API compatibility.
    keepdims: if True, reduced positional axes are kept with size 1.
    initial: optional scalar combined with the result via ``op``.
    where_: optional mask of elements to include.
    parallel_reduce: reduction used when ``axis`` contains named axes.
    promote_integers: if True and ``dtype`` is None, bool and narrow integer
      inputs are promoted to the default ``int_`` / ``uint`` dtype.

  Returns:
    The reduced array, converted to ``dtype`` or the inferred result dtype.

  Raises:
    NotImplementedError: if ``out`` is given, or if named axes are requested
      and no ``parallel_reduce`` is available.
    ValueError: if the operation has no identity and either a mask is given
      without ``initial`` or a reduced axis has size zero; also if
      ``initial`` is not a scalar.
  """
  bool_op = bool_op or op
  # Note: we must accept out=None as an argument, because numpy reductions delegate to
  # object methods. For example `np.sum(x)` will call `x.sum()` if the `sum()` method
  # exists, passing along all its arguments.
  if out is not None:
    raise NotImplementedError(f"The 'out' argument to jnp.{name} is not supported.")
  check_arraylike(name, a)
  dtypes.check_user_dtype_supported(dtype, name)
  axis = core.concrete_or_error(None, axis, f"axis argument to jnp.{name}().")

  if initial is None and not has_identity and where_ is not None:
    raise ValueError(f"reduction operation {name} does not have an identity, so to use a "
                     f"where mask one has to specify 'initial'")

  a = a if isinstance(a, Array) else lax_internal.asarray(a)
  a = preproc(a) if preproc else a
  # pos_dims holds only positional axes; dims may additionally hold named axes.
  pos_dims, dims = _reduction_dims(a, axis)

  if initial is None and not has_identity:
    shape = np.shape(a)
    if not _all(shape[d] >= 1 for d in pos_dims):
      raise ValueError(f"zero-size array to reduction operation {name} which has no identity")

  result_dtype = dtype or dtypes.dtype(a)

  if dtype is None and promote_integers:
    # Note: NumPy always promotes to 64-bit; jax instead promotes to the
    # default dtype as defined by dtypes.int_ or dtypes.uint.
    if dtypes.issubdtype(result_dtype, np.bool_):
      result_dtype = dtypes.int_
    elif dtypes.issubdtype(result_dtype, np.unsignedinteger):
      if np.iinfo(result_dtype).bits < np.iinfo(dtypes.uint).bits:
        result_dtype = dtypes.uint
    elif dtypes.issubdtype(result_dtype, np.integer):
      if np.iinfo(result_dtype).bits < np.iinfo(dtypes.int_).bits:
        result_dtype = dtypes.int_

  result_dtype = dtypes.canonicalize_dtype(result_dtype)

  # Only upcast when requested; otherwise compute directly in the result dtype.
  if upcast_f16_for_computation and dtypes.issubdtype(result_dtype, np.inexact):
    computation_dtype = _upcast_f16(result_dtype)
  else:
    computation_dtype = result_dtype
  a = lax.convert_element_type(a, computation_dtype)
  op = op if computation_dtype != np.bool_ else bool_op
  # NB: in XLA, init_val must be an identity for the op, so the user-specified
  # initial value must be applied afterward.
  init_val = _reduction_init_val(a, init_val)
  if where_ is not None:
    a = _where(where_, a, init_val)
  # _reduction_dims returns the *same* tuple object twice when every axis is
  # positional, so an identity mismatch means named axes are present.
  if pos_dims is not dims:
    if parallel_reduce is None:
      raise NotImplementedError(f"Named reductions not implemented for jnp.{name}()")
    result = parallel_reduce(a, dims)
  else:
    result = lax.reduce(a, init_val, op, dims)
  if initial is not None:
    initial_arr = lax.convert_element_type(initial, lax_internal.asarray(a).dtype)
    if initial_arr.shape != ():
      raise ValueError("initial value must be a scalar. "
                       f"Got array of shape {initial_arr.shape}")
    result = op(initial_arr, result)
  if keepdims:
    result = lax.expand_dims(result, pos_dims)
  return lax.convert_element_type(result, dtype or result_dtype)

def _canonicalize_axis_allow_named(x, rank):
  """Canonicalize a positional axis against ``rank``; pass named axes through."""
  def _positional(index):
    return _canonicalize_axis(index, rank)

  def _named(axis_name):
    return axis_name

  return maybe_named_axis(x, _positional, _named)

def _reduction_dims(a: ArrayLike, axis: Axis):
  """Resolve ``axis`` into ``(positional_axes, all_axes)`` for a reduction.

  Args:
    a: the array being reduced; only its rank is used.
    axis: None (reduce over everything), a single axis, or a sequence of
      axes. Entries may be integers (possibly negative) or named axes.

  Returns:
    A pair of tuples. The first holds the canonicalized integer axes only; the
    second holds every canonicalized axis, including named ones. When there
    are no named axes the *same* tuple object is returned in both positions,
    which callers use (via ``is``) to detect the purely positional case.

  Raises:
    ValueError: if ``axis`` contains a duplicate entry.
  """
  if axis is None:
    return (tuple(range(np.ndim(a))),) * 2
  elif not isinstance(axis, (np.ndarray, tuple, list)):
    axis = (axis,)  # type: ignore[assignment]
  canon_axis = tuple(_canonicalize_axis_allow_named(x, np.ndim(a))
                     for x in axis)  # type: ignore[union-attr]
  if len(canon_axis) != len(set(canon_axis)):
    raise ValueError(f"duplicate value in 'axis': {axis}")
  canon_pos_axis = tuple(x for x in canon_axis if isinstance(x, int))
  if len(canon_pos_axis) != len(canon_axis):
    # At least one named axis is present.
    return canon_pos_axis, canon_axis
  else:
    return canon_axis, canon_axis

def _reduction_init_val(a: ArrayLike, init_val: Any) -> np.ndarray:
  """Express the reduction identity ``init_val`` in the dtype of ``a``.

  Args:
    a: array-like whose (canonicalized) dtype determines the output dtype.
    init_val: the nominal identity value, e.g. ``0``, ``1`` or ``+/-inf``.

  Returns:
    A zero-dimensional NumPy array of ``a``'s dtype. For bool dtypes the value
    is ``init_val > 0``. Infinite values are replaced by the dtype's finite
    min/max for floating dtypes without an inf representation, and for integer
    dtypes where the conversion overflows.
  """
  # This function uses np.* functions because lax pattern matches against the
  # specific concrete values of the reduction inputs.
  a_dtype = dtypes.canonicalize_dtype(dtypes.dtype(a))
  if a_dtype == 'bool':
    return np.array(init_val > 0, dtype=a_dtype)
  if (np.isinf(init_val) and dtypes.issubdtype(a_dtype, np.floating)
      and not dtypes.supports_inf(a_dtype)):
    init_val = np.array(dtypes.finfo(a_dtype).min if np.isneginf(init_val)
                        else dtypes.finfo(a_dtype).max, dtype=a_dtype)
  try:
    return np.array(init_val, dtype=a_dtype)
  except OverflowError:
    # Raised when converting +/-inf to an integer dtype: saturate instead.
    assert dtypes.issubdtype(a_dtype, np.integer)
    sign, info = np.sign(init_val), dtypes.iinfo(a_dtype)
    return np.array(info.min if sign < 0 else info.max, dtype=a_dtype)

def _cast_to_bool(operand: ArrayLike) -> Array:
  """Convert ``operand`` to bool with ``NumpyComplexWarning`` silenced."""
  with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=NumpyComplexWarning)
    as_bool = lax.convert_element_type(operand, np.bool_)
  return as_bool

def _cast_to_numeric(operand: ArrayLike) -> Array:
  """Return ``operand`` after ``promote_dtypes_numeric`` has been applied to it."""
  promoted, = promote_dtypes_numeric(operand)
  return promoted

def _ensure_optional_axes(x: Axis) -> Axis:
  """Normalize an ``axis`` argument to None, a Python int, or a tuple.

  Args:
    x: None, an integer-like object, or an iterable whose entries are
      integer-like objects or strings (named axes).

  Returns:
    None if ``x`` is None; ``operator.index(x)`` if ``x`` is integer-like;
    otherwise a tuple in which strings are kept as-is and every other entry is
    passed through ``operator.index``.

  The conversion is routed through ``core.concrete_or_error`` so that a
  non-static axis is reported with the message given below.
  """
  def force(x):
    if x is None:
      return None
    try:
      return operator.index(x)
    except TypeError:
      # Not integer-like: treat it as a sequence of axes.
      return tuple(i if isinstance(i, str) else operator.index(i) for i in x)
  return core.concrete_or_error(
    force, x, "The axis argument must be known statically.")

# Extra parameter documentation passed (as ``extra_params``) to ``implements``
# for the reductions that accept a ``promote_integers`` argument.
# TODO(jakevdp) change promote_integers default to False
_PROMOTE_INTEGERS_DOC = """
promote_integers : bool, default=True
    If True, then integer inputs will be promoted to the widest available integer
    dtype, following numpy's behavior. If False, the result will have the same dtype
    as the input. ``promote_integers`` is ignored if ``dtype`` is specified.
"""

@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims', 'promote_integers'), inline=True)
def _reduce_sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
                out: None = None, keepdims: bool = False,
                initial: ArrayLike | None = None, where: ArrayLike | None = None,
                promote_integers: bool = True) -> Array:
  """Jitted implementation of ``sum``; ``axis`` must already be static.

  Delegates to ``_reduction`` with ``lax.add`` (``lax.bitwise_or`` for bool
  computation), identity 0, and ``lax.psum`` for named axes.
  """
  return _reduction(a, "sum", np.sum, lax.add, 0, preproc=_cast_to_numeric,
                    bool_op=lax.bitwise_or, upcast_f16_for_computation=True,
                    axis=axis, dtype=dtype, out=out, keepdims=keepdims,
                    initial=initial, where_=where, parallel_reduce=lax.psum,
                    promote_integers=promote_integers)

# NumPy-compatible ``sum``. The axis argument is normalized eagerly so that it
# can be passed as a static argument to the jitted ``_reduce_sum``.
@implements(np.sum, skip_params=['out'], extra_params=_PROMOTE_INTEGERS_DOC)
def sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
        out: None = None, keepdims: bool = False, initial: ArrayLike | None = None,
        where: ArrayLike | None = None, promote_integers: bool = True) -> Array:
  return _reduce_sum(a, axis=_ensure_optional_axes(axis), dtype=dtype, out=out,
                     keepdims=keepdims, initial=initial, where=where,
                     promote_integers=promote_integers)
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims', 'promote_integers'), inline=True)
def _reduce_prod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
                 out: None = None, keepdims: bool = False,
                 initial: ArrayLike | None = None, where: ArrayLike | None = None,
                 promote_integers: bool = True) -> Array:
  """Jitted implementation of ``prod``; ``axis`` must already be static.

  Delegates to ``_reduction`` with ``lax.mul`` (``lax.bitwise_and`` for bool
  computation) and identity 1. No ``parallel_reduce`` is supplied, so named
  axes are rejected by ``_reduction``.
  """
  return _reduction(a, "prod", np.prod, lax.mul, 1, preproc=_cast_to_numeric,
                    bool_op=lax.bitwise_and, upcast_f16_for_computation=True,
                    axis=axis, dtype=dtype, out=out, keepdims=keepdims,
                    initial=initial, where_=where, promote_integers=promote_integers)
# NumPy-compatible ``prod``. The axis argument is normalized eagerly so that it
# can be passed as a static argument to the jitted ``_reduce_prod``.
@implements(np.prod, skip_params=['out'], extra_params=_PROMOTE_INTEGERS_DOC)
def prod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
         out: None = None, keepdims: bool = False,
         initial: ArrayLike | None = None, where: ArrayLike | None = None,
         promote_integers: bool = True) -> Array:
  return _reduce_prod(a, axis=_ensure_optional_axes(axis), dtype=dtype,
                      out=out, keepdims=keepdims, initial=initial, where=where,
                      promote_integers=promote_integers)
@partial(api.jit, static_argnames=('axis', 'keepdims'), inline=True)
def _reduce_max(a: ArrayLike, axis: Axis = None, out: None = None,
                keepdims: bool = False, initial: ArrayLike | None = None,
                where: ArrayLike | None = None) -> Array:
  """Jitted implementation of ``max``; ``axis`` must already be static.

  Uses ``lax.max`` with ``-inf`` as the starting value and ``lax.pmax`` for
  named axes. ``has_identity=False`` makes ``_reduction`` reject empty
  reductions and masked reductions that lack ``initial``.
  """
  return _reduction(a, "max", np.max, lax.max, -np.inf, has_identity=False,
                    axis=axis, out=out, keepdims=keepdims,
                    initial=initial, where_=where, parallel_reduce=lax.pmax)
# NumPy-compatible ``max`` (intentionally shadows the builtin within this
# module). The axis is normalized before calling the jitted ``_reduce_max``.
@implements(np.max, skip_params=['out'])
def max(a: ArrayLike, axis: Axis = None, out: None = None,
        keepdims: bool = False, initial: ArrayLike | None = None,
        where: ArrayLike | None = None) -> Array:
  return _reduce_max(a, axis=_ensure_optional_axes(axis), out=out,
                     keepdims=keepdims, initial=initial, where=where)
@partial(api.jit, static_argnames=('axis', 'keepdims'), inline=True)
def _reduce_min(a: ArrayLike, axis: Axis = None, out: None = None,
                keepdims: bool = False, initial: ArrayLike | None = None,
                where: ArrayLike | None = None) -> Array:
  """Jitted implementation of ``min``; ``axis`` must already be static.

  Uses ``lax.min`` with ``+inf`` as the starting value and ``lax.pmin`` for
  named axes. ``has_identity=False`` makes ``_reduction`` reject empty
  reductions and masked reductions that lack ``initial``.
  """
  return _reduction(a, "min", np.min, lax.min, np.inf, has_identity=False,
                    axis=axis, out=out, keepdims=keepdims,
                    initial=initial, where_=where, parallel_reduce=lax.pmin)
# NumPy-compatible ``min`` (intentionally shadows the builtin within this
# module). The axis is normalized before calling the jitted ``_reduce_min``.
@implements(np.min, skip_params=['out'])
def min(a: ArrayLike, axis: Axis = None, out: None = None,
        keepdims: bool = False, initial: ArrayLike | None = None,
        where: ArrayLike | None = None) -> Array:
  return _reduce_min(a, axis=_ensure_optional_axes(axis), out=out,
                     keepdims=keepdims, initial=initial, where=where)
@partial(api.jit, static_argnames=('axis', 'keepdims'), inline=True)
def _reduce_all(a: ArrayLike, axis: Axis = None, out: None = None,
                keepdims: bool = False, *, where: ArrayLike | None = None) -> Array:
  """Jitted implementation of ``all``; ``axis`` must already be static.

  The input is cast to bool first, then combined with ``lax.bitwise_and``
  starting from True.
  """
  return _reduction(a, "all", np.all, lax.bitwise_and, True, preproc=_cast_to_bool,
                    axis=axis, out=out, keepdims=keepdims, where_=where)
# NumPy-compatible ``all`` (intentionally shadows the builtin within this
# module; the builtin stays reachable as ``_all``).
@implements(np.all, skip_params=['out'])
def all(a: ArrayLike, axis: Axis = None, out: None = None,
        keepdims: bool = False, *, where: ArrayLike | None = None) -> Array:
  return _reduce_all(a, axis=_ensure_optional_axes(axis), out=out,
                     keepdims=keepdims, where=where)
@partial(api.jit, static_argnames=('axis', 'keepdims'), inline=True)
def _reduce_any(a: ArrayLike, axis: Axis = None, out: None = None,
                keepdims: bool = False, *, where: ArrayLike | None = None) -> Array:
  """Jitted implementation of ``any``; ``axis`` must already be static.

  The input is cast to bool first, then combined with ``lax.bitwise_or``
  starting from False.
  """
  return _reduction(a, "any", np.any, lax.bitwise_or, False, preproc=_cast_to_bool,
                    axis=axis, out=out, keepdims=keepdims, where_=where)
# NumPy-compatible ``any`` (intentionally shadows the builtin within this
# module). The axis is normalized before calling the jitted ``_reduce_any``.
@implements(np.any, skip_params=['out'])
def any(a: ArrayLike, axis: Axis = None, out: None = None,
        keepdims: bool = False, *, where: ArrayLike | None = None) -> Array:
  return _reduce_any(a, axis=_ensure_optional_axes(axis), out=out,
                     keepdims=keepdims, where=where)
# NumPy-style aliases for the reductions bound to ``min`` / ``max`` at this
# point in the module.
amin = min
amax = max

def _axis_size(a: ArrayLike, axis: int | Sequence[int]):
  """Return the number of elements of ``a`` spanned by ``axis``.

  Args:
    a: array-like; only its shape is used.
    axis: a single axis, or a tuple/list of axes. Positional axes contribute
      ``shape[axis]``; named axes contribute ``lax.psum(1, name)``.

  Returns:
    The product of the sizes of all requested axes.
  """
  if not isinstance(axis, (tuple, list)):
    axis_seq: Sequence[int] = (axis,)  # type: ignore[assignment]
  else:
    axis_seq = axis
  size = 1
  a_shape = np.shape(a)
  # The loop variable is deliberately not called ``a`` so that it does not
  # shadow the array argument.
  for ax in axis_seq:
    size *= maybe_named_axis(ax, lambda i: a_shape[i], lambda name: lax.psum(1, name))
  return size
# NumPy-compatible ``mean``. The axis argument is normalized eagerly so that
# it can be passed as a static argument to the jitted ``_mean``.
@implements(np.mean, skip_params=['out'])
def mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
         out: None = None, keepdims: bool = False, *,
         where: ArrayLike | None = None) -> Array:
  return _mean(a, _ensure_optional_axes(axis), dtype, out, keepdims,
               where=where)
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'), inline=True)
def _mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
          out: None = None, keepdims: bool = False, *,
          upcast_f16_for_computation: bool = True,
          where: ArrayLike | None = None) -> Array:
  """Jitted implementation of ``mean``; ``axis`` must already be static.

  The result dtype is ``dtype`` if given, otherwise the inexact counterpart of
  the input dtype. When ``upcast_f16_for_computation`` is set, the sum and the
  division are carried out in the dtype chosen by ``_upcast_f16`` and the
  result is cast back afterwards. The divisor is the number of reduced
  elements, or the number of selected elements when ``where`` is given.

  Raises:
    NotImplementedError: if ``out`` is not None.
  """
  check_arraylike("mean", a)
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.mean is not supported.")

  if dtype is None:
    result_dtype = dtypes.to_inexact_dtype(dtypes.dtype(a, canonicalize=True))
  else:
    dtypes.check_user_dtype_supported(dtype, "mean")
    result_dtype = dtypes.canonicalize_dtype(dtype)

  if upcast_f16_for_computation and dtypes.issubdtype(result_dtype, np.inexact):
    computation_dtype = _upcast_f16(result_dtype)
  else:
    computation_dtype = result_dtype

  if where is None:
    if axis is None:
      normalizer = core.dimension_as_value(np.size(a))
    else:
      normalizer = core.dimension_as_value(_axis_size(a, axis))
  else:
    # Count the selected elements along the reduced axes.
    normalizer = sum(_broadcast_to(where, np.shape(a)), axis, dtype=dtype, keepdims=keepdims)

  return lax.div(
      sum(a, axis, dtype=computation_dtype, keepdims=keepdims, where=where),
      lax.convert_element_type(normalizer, computation_dtype)
  ).astype(result_dtype)

# Typing overloads for ``average``: the return type depends on ``returned``.
@overload
def average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None,
            returned: Literal[False] = False, keepdims: bool = False) -> Array: ...
@overload
def average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None, *,
            returned: Literal[True], keepdims: bool = False) -> tuple[Array, Array]: ...
@overload
def average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None,
            returned: bool = False, keepdims: bool = False) -> Array | tuple[Array, Array]: ...
# NumPy-compatible ``average``. The axis argument is normalized eagerly so
# that it can be passed as a static argument to the jitted ``_average``.
@implements(np.average)
def average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None,
            returned: bool = False, keepdims: bool = False) -> Array | tuple[Array, Array]:
  return _average(a, _ensure_optional_axes(axis), weights, returned, keepdims)
@partial(api.jit, static_argnames=('axis', 'returned', 'keepdims'), inline=True)
def _average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None,
             returned: bool = False, keepdims: bool = False) -> Array | tuple[Array, Array]:
  """Jitted implementation of ``average``; ``axis`` must already be static.

  Returns:
    The (weighted) average, or the pair ``(average, sum_of_weights)`` when
    ``returned`` is True. The sum of weights is broadcast to the shape of the
    average if the two shapes differ.

  Raises:
    ValueError: if ``weights`` has a different shape from ``a`` and is not
      1D, if no single axis is specified in that case, or if its length does
      not match the size of ``a`` along that axis.
  """
  if weights is None:  # Treat all weights as 1
    check_arraylike("average", a)
    a, = promote_dtypes_inexact(a)
    avg = mean(a, axis=axis, keepdims=keepdims)
    # With unit weights the sum of weights is the number of averaged elements.
    if axis is None:
      weights_sum = lax.full((), core.dimension_as_value(a.size), dtype=avg.dtype)
    elif isinstance(axis, tuple):
      weights_sum = lax.full_like(avg, math.prod(core.dimension_as_value(a.shape[d]) for d in axis))
    else:
      weights_sum = lax.full_like(avg, core.dimension_as_value(a.shape[axis]))  # type: ignore[index]
  else:
    check_arraylike("average", a, weights)
    a, weights = promote_dtypes_inexact(a, weights)

    a_shape = np.shape(a)
    a_ndim = len(a_shape)
    weights_shape = np.shape(weights)

    if axis is None:
      pass
    elif isinstance(axis, tuple):
      axis = tuple(_canonicalize_axis(d, a_ndim) for d in axis)
    else:
      axis = _canonicalize_axis(axis, a_ndim)

    if a_shape != weights_shape:
      # Make sure the dimensions work out
      if len(weights_shape) != 1:
        raise ValueError("1D weights expected when shapes of a and "
                         "weights differ.")
      if axis is None:
        raise ValueError("Axis must be specified when shapes of a and "
                         "weights differ.")
      elif isinstance(axis, tuple):
        raise ValueError("Single axis expected when shapes of a and weights differ")
      elif not core.definitely_equal(weights_shape[0], a_shape[axis]):
        raise ValueError("Length of weights not "
                         "compatible with specified axis.")

      # Place the 1D weights along ``axis`` so they broadcast against ``a``.
      weights = _broadcast_to(weights, (a_ndim - 1) * (1,) + weights_shape)
      weights = _moveaxis(weights, -1, axis)

    weights_sum = sum(weights, axis=axis, keepdims=keepdims)
    avg = sum(a * weights, axis=axis, keepdims=keepdims) / weights_sum

  if returned:
    if avg.shape != weights_sum.shape:
      weights_sum = _broadcast_to(weights_sum, avg.shape)
    return avg, weights_sum
  return avg
# NumPy-compatible ``var``. The axis argument is normalized eagerly so that it
# can be passed as a static argument to the jitted ``_var``.
@implements(np.var, skip_params=['out'])
def var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
        out: None = None, ddof: int = 0, keepdims: bool = False, *,
        where: ArrayLike | None = None) -> Array:
  return _var(a, _ensure_optional_axes(axis), dtype, out, ddof, keepdims,
              where=where)
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def _var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
         out: None = None, ddof: int = 0, keepdims: bool = False, *,
         where: ArrayLike | None = None) -> Array:
  """Jitted implementation of ``var``; ``axis`` must already be static.

  Computes the sum of squared (absolute) deviations from the mean divided by
  ``N - ddof``, where ``N`` is the number of reduced elements, or the number
  of selected elements when ``where`` is given.

  Raises:
    NotImplementedError: if ``out`` is not None.
  """
  check_arraylike("var", a)
  dtypes.check_user_dtype_supported(dtype, "var")
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.var is not supported.")

  computation_dtype, dtype = _var_promote_types(dtypes.dtype(a), dtype)
  a = lax_internal.asarray(a).astype(computation_dtype)
  # keepdims=True so that the mean broadcasts against ``a`` below.
  a_mean = mean(a, axis, dtype=computation_dtype, keepdims=True, where=where)
  centered = lax.sub(a, a_mean)
  if dtypes.issubdtype(computation_dtype, np.complexfloating):
    centered = lax.real(lax.mul(centered, lax.conj(centered)))
    computation_dtype = centered.dtype  # avoid casting to complex below.
  else:
    centered = lax.square(centered)

  if where is None:
    if axis is None:
      normalizer = core.dimension_as_value(np.size(a))
    else:
      normalizer = core.dimension_as_value(_axis_size(a, axis))
    normalizer = lax.convert_element_type(normalizer, computation_dtype)
  else:
    normalizer = sum(_broadcast_to(where, np.shape(a)), axis,
                     dtype=computation_dtype, keepdims=keepdims)
  normalizer = lax.sub(normalizer, lax.convert_element_type(ddof, computation_dtype))
  result = sum(centered, axis, dtype=computation_dtype, keepdims=keepdims, where=where)
  return lax.div(result, normalizer).astype(dtype)


def _var_promote_types(a_dtype: DTypeLike, dtype: DTypeLike | None) -> tuple[DType, DType]:
  """Choose the computation dtype and the result dtype for a variance.

  Args:
    a_dtype: dtype of the input array.
    dtype: user-requested dtype, or None.

  Returns:
    ``(computation_dtype, result_dtype)``. The computation dtype is passed
    through ``_upcast_f16``. Without a requested dtype, non-inexact inputs use
    their inexact counterpart for both; inexact inputs compute in their own
    dtype and return ``_complex_elem_type(a_dtype)``.

  Raises:
    ValueError: if a non-complex ``dtype`` is requested for complex input.
  """
  if dtype:
    if (not dtypes.issubdtype(dtype, np.complexfloating) and
        dtypes.issubdtype(a_dtype, np.complexfloating)):
      msg = ("jax.numpy.var does not yet support real dtype parameters when "
             "computing the variance of an array of complex values. The "
             "semantics of numpy.var seem unclear in this case. Please comment on "
             "https://github.com/google/jax/issues/2283 if this behavior is "
             "important to you.")
      raise ValueError(msg)
    computation_dtype = dtype
  else:
    if not dtypes.issubdtype(a_dtype, np.inexact):
      dtype = dtypes.to_inexact_dtype(a_dtype)
      computation_dtype = dtype
    else:
      dtype = _complex_elem_type(a_dtype)
      computation_dtype = a_dtype
  return _upcast_f16(computation_dtype), np.dtype(dtype)
# NumPy-compatible ``std``. The axis argument is normalized eagerly so that it
# can be passed as a static argument to the jitted ``_std``.
@implements(np.std, skip_params=['out'])
def std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
        out: None = None, ddof: int = 0, keepdims: bool = False, *,
        where: ArrayLike | None = None) -> Array:
  return _std(a, _ensure_optional_axes(axis), dtype, out, ddof, keepdims,
              where=where)
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def _std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
         out: None = None, ddof: int = 0, keepdims: bool = False, *,
         where: ArrayLike | None = None) -> Array:
  """Jitted implementation of ``std``: the square root of ``var``.

  Raises:
    ValueError: if ``dtype`` is given and is not an inexact dtype.
    NotImplementedError: if ``out`` is not None.
  """
  check_arraylike("std", a)
  dtypes.check_user_dtype_supported(dtype, "std")
  if dtype is not None and not dtypes.issubdtype(dtype, np.inexact):
    raise ValueError(f"dtype argument to jnp.std must be inexact; got {dtype}")
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.std is not supported.")
  return lax.sqrt(var(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, where=where))
# NumPy-compatible ``ptp`` (peak-to-peak). The axis argument is normalized
# eagerly so that it can be passed as a static argument to the jitted ``_ptp``.
@implements(np.ptp, skip_params=['out'])
def ptp(a: ArrayLike, axis: Axis = None, out: None = None,
        keepdims: bool = False) -> Array:
  return _ptp(a, _ensure_optional_axes(axis), out, keepdims)
@partial(api.jit, static_argnames=('axis', 'keepdims'))
def _ptp(a: ArrayLike, axis: Axis = None, out: None = None,
         keepdims: bool = False) -> Array:
  """Jitted implementation of ``ptp``: maximum minus minimum along ``axis``.

  Raises:
    NotImplementedError: if ``out`` is not None.
  """
  check_arraylike("ptp", a)
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.ptp is not supported.")
  x = amax(a, axis=axis, keepdims=keepdims)
  y = amin(a, axis=axis, keepdims=keepdims)
  return lax.sub(x, y)
# NumPy-compatible ``count_nonzero``: sums the mask of entries that differ
# from zero, accumulating in the canonical integer dtype.
@implements(np.count_nonzero)
@partial(api.jit, static_argnames=('axis', 'keepdims'))
def count_nonzero(a: ArrayLike, axis: Axis = None,
                  keepdims: bool = False) -> Array:
  check_arraylike("count_nonzero", a)
  return sum(lax_internal.ne(a, _lax_const(a, 0)), axis=axis,
             dtype=dtypes.canonicalize_dtype(int), keepdims=keepdims)
def _nan_reduction(a: ArrayLike, name: str, jnp_reduction: Callable[..., Array],
                   init_val: ArrayLike, nan_if_all_nan: bool,
                   axis: Axis = None, keepdims: bool = False, **kwargs) -> Array:
  """Apply ``jnp_reduction`` to ``a`` while ignoring NaN entries.

  Args:
    a: array-like input.
    name: name of the public function; used for argument checking.
    jnp_reduction: the reduction to apply (e.g. ``sum``, ``min``).
    init_val: value substituted for NaN entries before reducing.
    nan_if_all_nan: if True, positions where every reduced entry is NaN yield
      NaN instead of the reduction of the substituted values.
    axis, keepdims, **kwargs: forwarded to ``jnp_reduction``.

  Returns:
    The reduced array. Inputs with a non-inexact dtype are passed to
    ``jnp_reduction`` unchanged.
  """
  check_arraylike(name, a)
  if not dtypes.issubdtype(dtypes.dtype(a), np.inexact):
    return jnp_reduction(a, axis=axis, keepdims=keepdims, **kwargs)

  out = jnp_reduction(_where(lax_internal._isnan(a), _reduction_init_val(a, init_val), a),
                      axis=axis, keepdims=keepdims, **kwargs)
  if nan_if_all_nan:
    return _where(all(lax_internal._isnan(a), axis=axis, keepdims=keepdims),
                  _lax_const(a, np.nan), out)
  else:
    return out
# NumPy-compatible ``nanmin``: NaNs are replaced by +inf before taking the
# minimum; all-NaN slices give NaN unless ``initial`` is supplied.
@implements(np.nanmin, skip_params=['out'])
@partial(api.jit, static_argnames=('axis', 'keepdims'))
def nanmin(a: ArrayLike, axis: Axis = None, out: None = None,
           keepdims: bool = False, initial: ArrayLike | None = None,
           where: ArrayLike | None = None) -> Array:
  return _nan_reduction(a, 'nanmin', min, np.inf, nan_if_all_nan=initial is None,
                        axis=axis, out=out, keepdims=keepdims,
                        initial=initial, where=where)
# NumPy-compatible ``nanmax``: NaNs are replaced by -inf before taking the
# maximum; all-NaN slices give NaN unless ``initial`` is supplied.
@implements(np.nanmax, skip_params=['out'])
@partial(api.jit, static_argnames=('axis', 'keepdims'))
def nanmax(a: ArrayLike, axis: Axis = None, out: None = None,
           keepdims: bool = False, initial: ArrayLike | None = None,
           where: ArrayLike | None = None) -> Array:
  return _nan_reduction(a, 'nanmax', max, -np.inf, nan_if_all_nan=initial is None,
                        axis=axis, out=out, keepdims=keepdims,
                        initial=initial, where=where)
# NumPy-compatible ``nansum``: NaNs are treated as zero.
@implements(np.nansum, skip_params=['out'])
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def nansum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
           out: None = None, keepdims: bool = False,
           initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array:
  # Report dtype problems under this function's own name.
  dtypes.check_user_dtype_supported(dtype, "nansum")
  return _nan_reduction(a, 'nansum', sum, 0, nan_if_all_nan=False,
                        axis=axis, dtype=dtype, out=out, keepdims=keepdims,
                        initial=initial, where=where)
# Work around a sphinx documentation warning in NumPy 1.22.
if nansum.__doc__ is not None:
  nansum.__doc__ = nansum.__doc__.replace("\n\n\n", "\n\n")
# NumPy-compatible ``nanprod``: NaNs are treated as one.
@implements(np.nanprod, skip_params=['out'])
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def nanprod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
            out: None = None, keepdims: bool = False,
            initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array:
  dtypes.check_user_dtype_supported(dtype, "nanprod")
  return _nan_reduction(a, 'nanprod', prod, 1, nan_if_all_nan=False,
                        axis=axis, dtype=dtype, out=out, keepdims=keepdims,
                        initial=initial, where=where)
# NumPy-compatible ``nanmean``: the NaN-ignoring sum divided by the number of
# non-NaN entries. Bool and integer inputs cannot hold NaN and are delegated
# to ``mean``.
@implements(np.nanmean, skip_params=['out'])
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def nanmean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
            out: None = None, keepdims: bool = False,
            where: ArrayLike | None = None) -> Array:
  check_arraylike("nanmean", a)
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.nanmean is not supported.")
  if dtypes.issubdtype(dtypes.dtype(a), np.bool_) or dtypes.issubdtype(dtypes.dtype(a), np.integer):
    return mean(a, axis, dtype, out, keepdims, where=where)
  if dtype is None:
    dtype = dtypes.to_inexact_dtype(dtypes.dtype(a, canonicalize=True))
  else:
    # Report dtype problems under this function's own name.
    dtypes.check_user_dtype_supported(dtype, "nanmean")
    dtype = dtypes.canonicalize_dtype(dtype)
  nan_mask = lax_internal.bitwise_not(lax_internal._isnan(a))
  normalizer = sum(nan_mask, axis=axis, dtype=dtype, keepdims=keepdims, where=where)
  td = lax.div(nansum(a, axis, dtype=dtype, keepdims=keepdims, where=where), normalizer)
  return td
# NumPy-compatible ``nanvar``: variance over the non-NaN entries. Positions
# where the count of non-NaN entries minus ``ddof`` is not positive give NaN.
@implements(np.nanvar, skip_params=['out'])
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def nanvar(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
           out: None = None, ddof: int = 0, keepdims: bool = False,
           where: ArrayLike | None = None) -> Array:
  check_arraylike("nanvar", a)
  dtypes.check_user_dtype_supported(dtype, "nanvar")
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.nanvar is not supported.")

  computation_dtype, dtype = _var_promote_types(dtypes.dtype(a), dtype)
  a = lax_internal.asarray(a).astype(computation_dtype)
  # keepdims=True so that the mean broadcasts against ``a`` below.
  a_mean = nanmean(a, axis, dtype=computation_dtype, keepdims=True, where=where)

  centered = _where(lax_internal._isnan(a), 0, lax.sub(a, a_mean))  # double-where trick for gradients.
  if dtypes.issubdtype(centered.dtype, np.complexfloating):
    centered = lax.real(lax.mul(centered, lax.conj(centered)))
  else:
    centered = lax.square(centered)

  normalizer = sum(lax_internal.bitwise_not(lax_internal._isnan(a)),
                   axis=axis, keepdims=keepdims, where=where)
  normalizer = normalizer - ddof
  normalizer_mask = lax.le(normalizer, lax_internal._zero(normalizer))
  result = sum(centered, axis, keepdims=keepdims, where=where)
  result = _where(normalizer_mask, np.nan, result)
  # Divide by 1 where the result is already NaN to avoid dividing by <= 0.
  divisor = _where(normalizer_mask, 1, normalizer)
  result = lax.div(result, lax.convert_element_type(divisor, result.dtype))
  return lax.convert_element_type(result, dtype)
# NumPy-compatible ``nanstd``: the square root of ``nanvar``.
@implements(np.nanstd, skip_params=['out'])
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def nanstd(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
           out: None = None, ddof: int = 0, keepdims: bool = False,
           where: ArrayLike | None = None) -> Array:
  check_arraylike("nanstd", a)
  dtypes.check_user_dtype_supported(dtype, "nanstd")
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.nanstd is not supported.")
  return lax.sqrt(nanvar(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, where=where))
class CumulativeReduction(Protocol):
  """Call signature shared by the cumulative reductions built below."""

  def __call__(self, a: ArrayLike, axis: Axis = None,
               dtype: DTypeLike | None = None, out: None = None) -> Array: ...


# TODO(jakevdp): should we change these semantics to match those of numpy?
CUML_REDUCTION_LAX_DESCRIPTION = """
Unlike the numpy counterpart, when ``dtype`` is not specified the output dtype will always
match the dtype of the input.
"""

def _make_cumulative_reduction(np_reduction: Any, reduction: Callable[..., Array],
                               fill_nan: bool = False, fill_value: ArrayLike = 0) -> CumulativeReduction:
  """Build a NumPy-style cumulative reduction from a lax cumulative primitive.

  Args:
    np_reduction: the NumPy function being implemented; supplies the name used
      in error messages and the documentation passed to ``implements``.
    reduction: callable invoked as ``reduction(array, axis)``.
    fill_nan: if True, NaN entries are replaced by ``fill_value`` first.
    fill_value: replacement for NaN entries when ``fill_nan`` is True.

  Returns:
    The public wrapper, which normalizes ``axis`` and then calls a jitted
    implementation.
  """
  @implements(np_reduction, skip_params=['out'],
              lax_description=CUML_REDUCTION_LAX_DESCRIPTION)
  def cumulative_reduction(a: ArrayLike, axis: Axis = None,
                           dtype: DTypeLike | None = None, out: None = None) -> Array:
    return _cumulative_reduction(a, _ensure_optional_axes(axis), dtype, out)

  @partial(api.jit, static_argnames=('axis', 'dtype'))
  def _cumulative_reduction(a: ArrayLike, axis: Axis = None,
                            dtype: DTypeLike | None = None, out: None = None) -> Array:
    check_arraylike(np_reduction.__name__, a)
    if out is not None:
      raise NotImplementedError(f"The 'out' argument to jnp.{np_reduction.__name__} "
                                f"is not supported.")
    dtypes.check_user_dtype_supported(dtype, np_reduction.__name__)

    # With no axis (or a scalar input) the reduction runs over the flattened
    # array.
    if axis is None or _isscalar(a):
      a = lax.reshape(a, (np.size(a),))
    if axis is None:
      axis = 0

    a_shape = list(np.shape(a))
    num_dims = len(a_shape)
    axis = _canonicalize_axis(axis, num_dims)

    if fill_nan:
      a = _where(lax_internal._isnan(a), _lax_const(a, fill_value), a)

    # Boolean inputs accumulate in the default integer dtype.
    if not dtype and dtypes.dtype(a) == np.bool_:
      dtype = dtypes.canonicalize_dtype(dtypes.int_)
    if dtype:
      a = lax.convert_element_type(a, dtype)

    return reduction(a, axis)

  return cumulative_reduction


cumsum = _make_cumulative_reduction(np.cumsum, lax.cumsum, fill_nan=False)
cumprod = _make_cumulative_reduction(np.cumprod, lax.cumprod, fill_nan=False)
nancumsum = _make_cumulative_reduction(np.nancumsum, lax.cumsum,
                                       fill_nan=True, fill_value=0)
nancumprod = _make_cumulative_reduction(np.nancumprod, lax.cumprod,
                                        fill_nan=True, fill_value=1)
# Cumulative sum with keyword-only ``axis`` / ``dtype`` / ``include_initial``.
# ``np.cumulative_sum`` is looked up with ``getattr`` because it may be absent
# from the installed NumPy.
@implements(getattr(np, 'cumulative_sum', None))
def cumulative_sum(
    x: ArrayLike, /, *, axis: int | None = None,
    dtype: DTypeLike | None = None,
    include_initial: bool = False) -> Array:
  check_arraylike("cumulative_sum", x)
  x = lax_internal.asarray(x)
  if x.ndim == 0:
    raise ValueError(
      "The input must be non-scalar to take a cumulative sum, however a "
      "scalar value or scalar array was given."
    )
  if axis is None:
    axis = 0
    # The axis may only be omitted for one-dimensional input.
    if x.ndim > 1:
      raise ValueError(
        f"The input array has rank {x.ndim}, however axis was not set to an "
        "explicit value. The axis argument is only optional for one-dimensional "
        "arrays.")

  axis = _canonicalize_axis(axis, x.ndim)
  dtypes.check_user_dtype_supported(dtype)
  kind = x.dtype.kind
  # Integer inputs narrower than the default width accumulate in the default
  # integer dtype of the same kind.
  if (dtype is None and kind in {'i', 'u'}
      and x.dtype.itemsize*8 < int(config.default_dtype_bits.value)):
    dtype = dtypes.canonicalize_dtype(dtypes._default_types[kind])
  x = x.astype(dtype=dtype or x.dtype)
  out = cumsum(x, axis=axis)
  if include_initial:
    # Prepend a slice of zeros along ``axis``.
    zeros_shape = list(x.shape)
    zeros_shape[axis] = 1
    out = lax_internal.concatenate(
      [lax_internal.full(zeros_shape, 0, dtype=out.dtype), out],
      dimension=axis)
  return out
# Quantiles
# NumPy-compatible ``quantile``. ``out`` and ``overwrite_input`` are rejected;
# the deprecated ``interpolation`` argument, when given, takes precedence over
# ``method``.
@implements(np.quantile, skip_params=['out', 'overwrite_input'])
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
                                   'keepdims', 'method'))
def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None,
             out: None = None, overwrite_input: bool = False, method: str = "linear",
             keepdims: bool = False, interpolation: None = None) -> Array:
  check_arraylike("quantile", a, q)
  if overwrite_input or out is not None:
    msg = ("jax.numpy.quantile does not support overwrite_input=True or "
           "out != None")
    raise ValueError(msg)
  if interpolation is not None:
    warnings.warn("The interpolation= argument to 'quantile' is deprecated. "
                  "Use 'method=' instead.", DeprecationWarning)
  # The final positional argument (squash_nans) is False: NaNs are not ignored.
  return _quantile(lax_internal.asarray(a), lax_internal.asarray(q), axis,
                   interpolation or method, keepdims, False)
# NumPy-compatible ``nanquantile``. ``out`` and ``overwrite_input`` are
# rejected; the deprecated ``interpolation`` argument, when given, takes
# precedence over ``method``.
@implements(np.nanquantile, skip_params=['out', 'overwrite_input'])
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
                                   'keepdims', 'method'))
def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None,
                out: None = None, overwrite_input: bool = False, method: str = "linear",
                keepdims: bool = False, interpolation: None = None) -> Array:
  check_arraylike("nanquantile", a, q)
  if overwrite_input or out is not None:
    msg = ("jax.numpy.nanquantile does not support overwrite_input=True or "
           "out != None")
    raise ValueError(msg)
  if interpolation is not None:
    warnings.warn("The interpolation= argument to 'nanquantile' is deprecated. "
                  "Use 'method=' instead.", DeprecationWarning)
  # The final positional argument (squash_nans) is True: NaNs are ignored.
  return _quantile(lax_internal.asarray(a), lax_internal.asarray(q), axis,
                   interpolation or method, keepdims, True)
def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
              interpolation: str, keepdims: bool, squash_nans: bool) -> Array:
  """Shared implementation behind the quantile-style reductions.

  Args:
    a: input array. It is promoted to an inexact dtype; complex dtypes are
      rejected.
    q: quantile(s) to compute, as a scalar or rank-1 array. Each value is
      scaled by ``count - 1`` to obtain a fractional index into the sorted
      data; the resulting indices are clamped to the valid range.
    axis: ``None`` (operate on the flattened array), a single axis, or a tuple
      of distinct axes that are collapsed into one trailing axis.
    interpolation: one of ``"linear"``, ``"lower"``, ``"higher"``,
      ``"midpoint"`` or ``"nearest"``.
    keepdims: whether the reduced axes are kept with size one.
    squash_nans: if True, NaNs are sorted to the end and excluded from the
      per-slice element count; if False, any NaN along ``axis`` turns the whole
      slice into NaN.

  Returns:
    The selected/interpolated values, converted to the promoted dtype of ``a``.
    When ``q`` is rank 1 its dimension is the leading axis of the result.

  Raises:
    ValueError: for an unknown ``interpolation``, complex input, a repeated
      axis, or ``q`` of rank greater than one.
  """
  if interpolation not in ["linear", "lower", "higher", "midpoint", "nearest"]:
    raise ValueError("interpolation can only be 'linear', 'lower', 'higher', "
                     "'midpoint', or 'nearest'")
  a, = promote_dtypes_inexact(a)
  # Target shape for the final keepdims reshape; only populated by the
  # axis=None and tuple-axis paths, which change the rank of `a` below.
  keepdim = []
  if dtypes.issubdtype(a.dtype, np.complexfloating):
    raise ValueError("quantile does not support complex input, as the operation is poorly defined.")
  if axis is None:
    if keepdims:
      keepdim = [1] * a.ndim
    a = a.ravel()
    axis = 0
  elif isinstance(axis, tuple):
    keepdim = list(a.shape)
    nd = a.ndim
    axis = tuple(_canonicalize_axis(ax, nd) for ax in axis)
    if len(set(axis)) != len(axis):
      raise ValueError('repeated axis')
    for ax in axis:
      keepdim[ax] = 1

    keep = set(range(nd)) - set(axis)
    # prepare permutation: move the kept axes to the front, in order.
    dimensions = list(range(nd))
    for i, s in enumerate(sorted(keep)):
      dimensions[i], dimensions[s] = dimensions[s], dimensions[i]
    do_not_touch_shape = tuple(x for idx, x in enumerate(a.shape) if idx not in axis)
    touch_shape = tuple(x for idx, x in enumerate(a.shape) if idx in axis)
    # Collapse every reduced axis into a single trailing axis. The relative
    # order of elements inside that axis is irrelevant since it gets sorted.
    a = lax.reshape(a, do_not_touch_shape + (math.prod(touch_shape),), dimensions)
    axis = _canonicalize_axis(-1, a.ndim)
  else:
    axis = _canonicalize_axis(axis, a.ndim)

  q_shape = q.shape
  q_ndim = q.ndim
  if q_ndim > 1:
    raise ValueError(f"q must have rank <= 1, got shape {q.shape}")

  a_shape = a.shape

  if squash_nans:
    a = _where(ufuncs.isnan(a), np.nan, a)  # Ensure nans are positive so they sort to the end.
    a = lax.sort(a, dimension=axis)
    # Number of non-NaN entries per slice; NOTE: `sum` here is expected to be
    # this module's array reduction (it takes axis/dtype/keepdims).
    counts = sum(ufuncs.logical_not(ufuncs.isnan(a)), axis=axis, dtype=q.dtype, keepdims=keepdims)
    shape_after_reduction = counts.shape
    # Lay out q on leading dims and counts on trailing dims so they broadcast
    # to q_shape + shape_after_reduction.
    q = lax.expand_dims(
      q, tuple(range(q_ndim, len(shape_after_reduction) + q_ndim)))
    counts = lax.expand_dims(counts, tuple(range(q_ndim)))
    q = lax.mul(q, lax.sub(counts, _lax_const(q, 1)))
    low = lax.floor(q)
    high = lax.ceil(q)
    high_weight = lax.sub(q, low)
    low_weight = lax.sub(_lax_const(high_weight, 1), high_weight)

    low = lax.max(_lax_const(low, 0), lax.min(low, counts - 1))
    high = lax.max(_lax_const(high, 0), lax.min(high, counts - 1))
    low = lax.convert_element_type(low, int)
    high = lax.convert_element_type(high, int)
    out_shape = q_shape + shape_after_reduction
    # Identity indices for every non-reduced position; the reduced axis is
    # then indexed with the per-slice low/high positions.
    index = [lax.broadcasted_iota(int, out_shape, dim + q_ndim)
             for dim in range(len(shape_after_reduction))]
    if keepdims:
      index[axis] = low
    else:
      index.insert(axis, low)
    low_value = a[tuple(index)]
    index[axis] = high
    high_value = a[tuple(index)]
  else:
    # Propagate NaN: a single NaN along `axis` poisons the whole slice.
    a = _where(any(ufuncs.isnan(a), axis=axis, keepdims=True), np.nan, a)
    a = lax.sort(a, dimension=axis)
    n = lax.convert_element_type(a_shape[axis], lax_internal._dtype(q))
    q = lax.mul(q, n - 1)
    low = lax.floor(q)
    high = lax.ceil(q)
    high_weight = lax.sub(q, low)
    low_weight = lax.sub(_lax_const(high_weight, 1), high_weight)

    low = lax.clamp(_lax_const(low, 0), low, n - 1)
    high = lax.clamp(_lax_const(high, 0), high, n - 1)
    low = lax.convert_element_type(low, int)
    high = lax.convert_element_type(high, int)

    slice_sizes = list(a_shape)
    slice_sizes[axis] = 1
    dnums = lax.GatherDimensionNumbers(
      offset_dims=tuple(range(
        q_ndim,
        len(a_shape) + q_ndim if keepdims else len(a_shape) + q_ndim - 1)),
      collapsed_slice_dims=() if keepdims else (axis,),
      start_index_map=(axis,))
    low_value = lax.gather(a, low[..., None], dimension_numbers=dnums,
                           slice_sizes=slice_sizes)
    high_value = lax.gather(a, high[..., None], dimension_numbers=dnums,
                            slice_sizes=slice_sizes)
    if q_ndim == 1:
      # The weights still have q's shape here; line them up with the gathered
      # values, whose leading dimension is q's.
      low_weight = lax.broadcast_in_dim(low_weight, low_value.shape,
                                        broadcast_dimensions=(0,))
      high_weight = lax.broadcast_in_dim(high_weight, high_value.shape,
                                         broadcast_dimensions=(0,))

  if interpolation == "linear":
    result = lax.add(lax.mul(low_value.astype(q.dtype), low_weight),
                     lax.mul(high_value.astype(q.dtype), high_weight))
  elif interpolation == "lower":
    result = low_value
  elif interpolation == "higher":
    result = high_value
  elif interpolation == "nearest":
    # Ties (fractional part exactly 0.5) resolve to the lower neighbour.
    pred = lax.le(high_weight, _lax_const(high_weight, 0.5))
    result = lax.select(pred, low_value, high_value)
  elif interpolation == "midpoint":
    result = lax.mul(lax.add(low_value, high_value), _lax_const(low_value, 0.5))
  else:
    raise ValueError(f"interpolation={interpolation!r} not recognized")
  if keepdims and keepdim:
    if q_ndim > 0:
      keepdim = [np.shape(q)[0], *keepdim]
    result = result.reshape(keepdim)
  return lax.convert_element_type(result, a.dtype)
@implements(np.percentile, skip_params=['out', 'overwrite_input'])
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
                                   'keepdims', 'method'))
def percentile(a: ArrayLike, q: ArrayLike,
               axis: int | tuple[int, ...] | None = None,
               out: None = None, overwrite_input: bool = False,
               method: str = "linear", keepdims: bool = False,
               interpolation: None = None) -> Array:
  # Percentiles are quantiles expressed on a 0-100 scale: promote `q` to an
  # inexact dtype, rescale it, and let `quantile` do the validation and work.
  check_arraylike("percentile", a, q)
  (q_inexact,) = promote_dtypes_inexact(q)
  fraction = q_inexact / 100
  return quantile(a, fraction, axis=axis, out=out,
                  overwrite_input=overwrite_input,
                  interpolation=interpolation, method=method,
                  keepdims=keepdims)
@implements(np.nanpercentile, skip_params=['out', 'overwrite_input'])
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
                                   'keepdims', 'method'))
def nanpercentile(a: ArrayLike, q: ArrayLike,
                  axis: int | tuple[int, ...] | None = None,
                  out: None = None, overwrite_input: bool = False,
                  method: str = "linear", keepdims: bool = False,
                  interpolation: None = None) -> Array:
  # Rescale `q` from the 0-100 percentile scale to a fraction, then let
  # `nanquantile` do the validation and work.
  check_arraylike("nanpercentile", a, q)
  fraction = ufuncs.true_divide(q, 100.0)
  return nanquantile(a, fraction, axis=axis, out=out,
                     overwrite_input=overwrite_input,
                     interpolation=interpolation, method=method,
                     keepdims=keepdims)
@implements(np.median, skip_params=['out', 'overwrite_input'])
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'keepdims'))
def median(a: ArrayLike, axis: int | tuple[int, ...] | None = None,
           out: None = None, overwrite_input: bool = False,
           keepdims: bool = False) -> Array:
  # The median is the 0.5 quantile taken with the 'midpoint' method; argument
  # validation for `out` / `overwrite_input` happens inside `quantile`.
  check_arraylike("median", a)
  half = 0.5
  return quantile(a, half, method='midpoint', axis=axis, out=out,
                  overwrite_input=overwrite_input, keepdims=keepdims)
@implements(np.nanmedian, skip_params=['out', 'overwrite_input'])
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'keepdims'))
def nanmedian(a: ArrayLike, axis: int | tuple[int, ...] | None = None,
              out: None = None, overwrite_input: bool = False,
              keepdims: bool = False) -> Array:
  # The NaN-ignoring median is the 0.5 nanquantile taken with the 'midpoint'
  # method; argument validation for `out` / `overwrite_input` happens inside
  # `nanquantile`.
  check_arraylike("nanmedian", a)
  half = 0.5
  return nanquantile(a, half, method='midpoint', axis=axis, out=out,
                     overwrite_input=overwrite_input, keepdims=keepdims)