Source code for jax._src.numpy.reductions

# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import builtins
from functools import partial
import math
import operator
from typing import (
    overload, Any, Callable, Literal, Optional, Protocol, Sequence, Tuple, Union)
import warnings

import numpy as np

from jax import lax
from jax._src import api
from jax._src import core
from jax._src import dtypes
from jax._src.numpy import ufuncs
from jax._src.numpy.util import (
    _broadcast_to, check_arraylike, _complex_elem_type,
    promote_dtypes_inexact, promote_dtypes_numeric, _where, _wraps)
from jax._src.lax import lax as lax_internal
from jax._src.typing import Array, ArrayLike, DType, DTypeLike
from jax._src.util import (
    canonicalize_axis as _canonicalize_axis, maybe_named_axis)


_all = builtins.all
_lax_const = lax_internal._const


Axis = Union[None, int, Sequence[int]]

def _isscalar(element: Any) -> bool:
  if hasattr(element, '__jax_array__'):
    element = element.__jax_array__()
  return dtypes.is_python_scalar(element) or np.isscalar(element)

def _moveaxis(a: ArrayLike, source: int, destination: int) -> Array:
  # simplified version of jnp.moveaxis() for local use.
  check_arraylike("moveaxis", a)
  a = lax_internal.asarray(a)
  source = _canonicalize_axis(source, np.ndim(a))
  destination = _canonicalize_axis(destination, np.ndim(a))
  perm = [i for i in range(np.ndim(a)) if i != source]
  perm.insert(destination, source)
  return lax.transpose(a, perm)

def _upcast_f16(dtype: DTypeLike) -> DType:
  if np.dtype(dtype) in [np.float16, dtypes.bfloat16]:
    return np.dtype('float32')
  return np.dtype(dtype)

ReductionOp = Callable[[Any, Any], Any]

def _reduction(a: ArrayLike, name: str, np_fun: Any, op: ReductionOp, init_val: ArrayLike,
               *, has_identity: bool = True,
               preproc: Optional[Callable[[ArrayLike], ArrayLike]] = None,
               bool_op: Optional[ReductionOp] = None,
               upcast_f16_for_computation: bool = False,
               axis: Axis = None, dtype: DTypeLike = None, out: None = None,
               keepdims: bool = False, initial: Optional[ArrayLike] = None,
               where_: Optional[ArrayLike] = None,
               parallel_reduce: Optional[Callable[..., Array]] = None,
               promote_integers: bool = False) -> Array:
  bool_op = bool_op or op
  # Note: we must accept out=None as an argument, because numpy reductions delegate to
  # object methods. For example `np.sum(x)` will call `x.sum()` if the `sum()` method
  # exists, passing along all its arguments.
  if out is not None:
    raise NotImplementedError(f"The 'out' argument to jnp.{name} is not supported.")
  check_arraylike(name, a)
  dtypes.check_user_dtype_supported(dtype, name)
  axis = core.concrete_or_error(None, axis, f"axis argument to jnp.{name}().")

  if initial is None and not has_identity and where_ is not None:
    raise ValueError(f"reduction operation {name} does not have an identity, so to use a "
                     f"where mask one has to specify 'initial'")

  a = a if isinstance(a, Array) else lax_internal.asarray(a)
  a = preproc(a) if preproc else a
  pos_dims, dims = _reduction_dims(a, axis)

  if initial is None and not has_identity:
    shape = np.shape(a)
    if not _all(core.greater_equal_dim(shape[d], 1) for d in pos_dims):
      raise ValueError(f"zero-size array to reduction operation {name} which has no identity")

  result_dtype = dtype or dtypes.dtype(a)

  if dtype is None and promote_integers:
    # Note: NumPy always promotes to 64-bit; jax instead promotes to the
    # default dtype as defined by dtypes.int_ or dtypes.uint.
    if dtypes.issubdtype(result_dtype, np.bool_):
      result_dtype = dtypes.int_
    elif dtypes.issubdtype(result_dtype, np.unsignedinteger):
      if np.iinfo(result_dtype).bits < np.iinfo(dtypes.uint).bits:
        result_dtype = dtypes.uint
    elif dtypes.issubdtype(result_dtype, np.integer):
      if np.iinfo(result_dtype).bits < np.iinfo(dtypes.int_).bits:
        result_dtype = dtypes.int_

  result_dtype = dtypes.canonicalize_dtype(result_dtype)

  if upcast_f16_for_computation and dtypes.issubdtype(result_dtype, np.inexact):
    computation_dtype = _upcast_f16(result_dtype)
  else:
    computation_dtype = result_dtype
  a = lax.convert_element_type(a, computation_dtype)
  op = op if computation_dtype != np.bool_ else bool_op
  # NB: in XLA, init_val must be an identity for the op, so the user-specified
  # initial value must be applied afterward.
  init_val = _reduction_init_val(a, init_val)
  if where_ is not None:
    a = _where(where_, a, init_val)
  if pos_dims is not dims:
    if parallel_reduce is None:
      raise NotImplementedError(f"Named reductions not implemented for jnp.{name}()")
    result = parallel_reduce(a, dims)
  else:
    result = lax.reduce(a, init_val, op, dims)
  if initial is not None:
    initial_arr = lax.convert_element_type(initial, lax_internal.asarray(a).dtype)
    if initial_arr.shape != ():
      raise ValueError("initial value must be a scalar. "
                       f"Got array of shape {initial_arr.shape}")
    result = op(initial_arr, result)
  if keepdims:
    result = lax.expand_dims(result, pos_dims)
  return lax.convert_element_type(result, dtype or result_dtype)
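
# Illustrative sketch of how the reduction machinery above composes: for a float32
# input, a call along the lines of jnp.sum(x, where=mask, initial=2.0) roughly
# amounts to (assuming x and mask are float32/bool arrays and dims covers the
# reduced axes):
#
#   masked = _where(mask, x, 0.0)                    # identity fills masked-out slots
#   result = lax.reduce(masked, np.float32(0), lax.add, dims)
#   result = lax.add(np.float32(2.0), result)        # user 'initial' folded in afterward
#
# As noted above, XLA requires init_val to be the operator's identity, so any
# user-supplied 'initial' is applied after lax.reduce rather than used as init_val.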

def _canonicalize_axis_allow_named(x, rank):
  return maybe_named_axis(x, lambda i: _canonicalize_axis(i, rank), lambda name: name)

def _reduction_dims(a: ArrayLike, axis: Axis):
  if axis is None:
    return (tuple(range(np.ndim(a))),) * 2
  elif not isinstance(axis, (np.ndarray, tuple, list)):
    axis = (axis,)  # type: ignore[assignment]
  canon_axis = tuple(_canonicalize_axis_allow_named(x, np.ndim(a))
                     for x in axis)  # type: ignore[union-attr]
  if len(canon_axis) != len(set(canon_axis)):
    raise ValueError(f"duplicate value in 'axis': {axis}")
  canon_pos_axis = tuple(x for x in canon_axis if isinstance(x, int))
  if len(canon_pos_axis) != len(canon_axis):
    return canon_pos_axis, canon_axis
  else:
    return canon_axis, canon_axis

def _reduction_init_val(a: ArrayLike, init_val: Any) -> np.ndarray:
  # This function uses np.* functions because lax pattern matches against the
  # specific concrete values of the reduction inputs.
  a_dtype = dtypes.canonicalize_dtype(dtypes.dtype(a))
  if a_dtype == 'bool':
    return np.array(init_val > 0, dtype=a_dtype)
  try:
    return np.array(init_val, dtype=a_dtype)
  except OverflowError:
    assert dtypes.issubdtype(a_dtype, np.integer)
    sign, info = np.sign(init_val), dtypes.iinfo(a_dtype)
    return np.array(info.min if sign < 0 else info.max, dtype=a_dtype)
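
# For example (illustrative): _reduce_max uses init_val=-np.inf; for an int32 input,
# np.array(-np.inf, dtype=np.int32) raises OverflowError, so the except branch above
# returns np.array(iinfo(int32).min) instead -- the appropriate identity for max on int32.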

def _cast_to_bool(operand: ArrayLike) -> Array:
  with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=np.ComplexWarning)
    return lax.convert_element_type(operand, np.bool_)

def _cast_to_numeric(operand: ArrayLike) -> Array:
  return promote_dtypes_numeric(operand)[0]


def _ensure_optional_axes(x: Axis) -> Axis:
  def force(x):
    if x is None:
      return None
    try:
      return operator.index(x)
    except TypeError:
      return tuple(i if isinstance(i, str) else operator.index(i) for i in x)
  return core.concrete_or_error(
    force, x, "The axis argument must be known statically.")


# TODO(jakevdp) change promote_integers default to False
_PROMOTE_INTEGERS_DOC = """
promote_integers : bool, default=True
    If True, then integer inputs will be promoted to the widest available integer
    dtype, following numpy's behavior. If False, the result will have the same dtype
    as the input. ``promote_integers`` is ignored if ``dtype`` is specified.
"""


@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims', 'promote_integers'), inline=True)
def _reduce_sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None,
                out: None = None, keepdims: bool = False,
                initial: Optional[ArrayLike] = None, where: Optional[ArrayLike] = None,
                promote_integers: bool = True) -> Array:
  return _reduction(a, "sum", np.sum, lax.add, 0, preproc=_cast_to_numeric,
                    bool_op=lax.bitwise_or, upcast_f16_for_computation=True,
                    axis=axis, dtype=dtype, out=out, keepdims=keepdims,
                    initial=initial, where_=where, parallel_reduce=lax.psum,
                    promote_integers=promote_integers)

@_wraps(np.sum, skip_params=['out'], extra_params=_PROMOTE_INTEGERS_DOC)
def sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None, out: None = None,
        keepdims: bool = False, initial: Optional[ArrayLike] = None,
        where: Optional[ArrayLike] = None, promote_integers: bool = True) -> Array:
  return _reduce_sum(a, axis=_ensure_optional_axes(axis), dtype=dtype, out=out,
                     keepdims=keepdims, initial=initial, where=where,
                     promote_integers=promote_integers)

@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims', 'promote_integers'), inline=True)
def _reduce_prod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None,
                 out: None = None, keepdims: bool = False,
                 initial: Optional[ArrayLike] = None, where: Optional[ArrayLike] = None,
                 promote_integers: bool = True) -> Array:
  return _reduction(a, "prod", np.prod, lax.mul, 1, preproc=_cast_to_numeric,
                    bool_op=lax.bitwise_and, upcast_f16_for_computation=True,
                    axis=axis, dtype=dtype, out=out, keepdims=keepdims,
                    initial=initial, where_=where, promote_integers=promote_integers)

@_wraps(np.prod, skip_params=['out'], extra_params=_PROMOTE_INTEGERS_DOC)
def prod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None, out: None = None,
         keepdims: bool = False, initial: Optional[ArrayLike] = None,
         where: Optional[ArrayLike] = None, promote_integers: bool = True) -> Array:
  return _reduce_prod(a, axis=_ensure_optional_axes(axis), dtype=dtype, out=out,
                      keepdims=keepdims, initial=initial, where=where,
                      promote_integers=promote_integers)

@partial(api.jit, static_argnames=('axis', 'keepdims'), inline=True)
def _reduce_max(a: ArrayLike, axis: Axis = None, out: None = None,
                keepdims: bool = False, initial: Optional[ArrayLike] = None,
                where: Optional[ArrayLike] = None) -> Array:
  return _reduction(a, "max", np.max, lax.max, -np.inf, has_identity=False,
                    axis=axis, out=out, keepdims=keepdims,
                    initial=initial, where_=where, parallel_reduce=lax.pmax)

@_wraps(np.max, skip_params=['out'])
def max(a: ArrayLike, axis: Axis = None, out: None = None,
        keepdims: bool = False, initial: Optional[ArrayLike] = None,
        where: Optional[ArrayLike] = None) -> Array:
  return _reduce_max(a, axis=_ensure_optional_axes(axis), out=out,
                     keepdims=keepdims, initial=initial, where=where)

@partial(api.jit, static_argnames=('axis', 'keepdims'), inline=True)
def _reduce_min(a: ArrayLike, axis: Axis = None, out: None = None,
                keepdims: bool = False, initial: Optional[ArrayLike] = None,
                where: Optional[ArrayLike] = None) -> Array:
  return _reduction(a, "min", np.min, lax.min, np.inf, has_identity=False,
                    axis=axis, out=out, keepdims=keepdims,
                    initial=initial, where_=where, parallel_reduce=lax.pmin)

@_wraps(np.min, skip_params=['out'])
def min(a: ArrayLike, axis: Axis = None, out: None = None,
        keepdims: bool = False, initial: Optional[ArrayLike] = None,
        where: Optional[ArrayLike] = None) -> Array:
  return _reduce_min(a, axis=_ensure_optional_axes(axis), out=out,
                     keepdims=keepdims, initial=initial, where=where)

@partial(api.jit, static_argnames=('axis', 'keepdims'), inline=True)
def _reduce_all(a: ArrayLike, axis: Axis = None, out: None = None,
                keepdims: bool = False, *, where: Optional[ArrayLike] = None) -> Array:
  return _reduction(a, "all", np.all, lax.bitwise_and, True, preproc=_cast_to_bool,
                    axis=axis, out=out, keepdims=keepdims, where_=where)

@_wraps(np.all, skip_params=['out'])
def all(a: ArrayLike, axis: Axis = None, out: None = None,
        keepdims: bool = False, *, where: Optional[ArrayLike] = None) -> Array:
  return _reduce_all(a, axis=_ensure_optional_axes(axis), out=out,
                     keepdims=keepdims, where=where)

@partial(api.jit, static_argnames=('axis', 'keepdims'), inline=True)
def _reduce_any(a: ArrayLike, axis: Axis = None, out: None = None,
                keepdims: bool = False, *, where: Optional[ArrayLike] = None) -> Array:
  return _reduction(a, "any", np.any, lax.bitwise_or, False, preproc=_cast_to_bool,
                    axis=axis, out=out, keepdims=keepdims, where_=where)

@_wraps(np.any, skip_params=['out'])
def any(a: ArrayLike, axis: Axis = None, out: None = None,
        keepdims: bool = False, *, where: Optional[ArrayLike] = None) -> Array:
  return _reduce_any(a, axis=_ensure_optional_axes(axis), out=out,
                     keepdims=keepdims, where=where)

amin = min
amax = max

def _axis_size(a: ArrayLike, axis: Union[int, Sequence[int]]):
  if not isinstance(axis, (tuple, list)):
    axis_seq: Sequence[int] = (axis,)  # type: ignore[assignment]
  else:
    axis_seq = axis
  size = 1
  a_shape = np.shape(a)
  for a in axis_seq:
    size *= maybe_named_axis(a, lambda i: a_shape[i], lambda name: lax.psum(1, name))
  return size

@_wraps(np.mean, skip_params=['out'])
def mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None,
         out: None = None, keepdims: bool = False, *,
         where: Optional[ArrayLike] = None) -> Array:
  return _mean(a, _ensure_optional_axes(axis), dtype, out, keepdims, where=where)

@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'), inline=True)
def _mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None,
          out: None = None, keepdims: bool = False, *,
          where: Optional[ArrayLike] = None) -> Array:
  check_arraylike("mean", a)
  if dtype is None:
    dtype = dtypes.to_inexact_dtype(dtypes.dtype(a))
  else:
    dtypes.check_user_dtype_supported(dtype, "mean")
  dtype = dtypes.canonicalize_dtype(dtype)
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.mean is not supported.")

  if where is None:
    if axis is None:
      normalizer = core.dimension_as_value(np.size(a))
    else:
      normalizer = core.dimension_as_value(_axis_size(a, axis))
  else:
    normalizer = sum(_broadcast_to(where, np.shape(a)), axis, dtype=dtype, keepdims=keepdims)

  return lax.div(
      sum(a, axis, dtype=dtype, keepdims=keepdims, where=where),
      lax.convert_element_type(normalizer, dtype))

@overload
def average(a: ArrayLike, axis: Axis = None, weights: Optional[ArrayLike] = None,
            returned: Literal[False] = False, keepdims: bool = False) -> Array: ...
@overload
def average(a: ArrayLike, axis: Axis = None, weights: Optional[ArrayLike] = None, *,
            returned: Literal[True], keepdims: bool = False) -> Tuple[Array, Array]: ...
@overload
def average(a: ArrayLike, axis: Axis = None, weights: Optional[ArrayLike] = None,
            returned: bool = False, keepdims: bool = False) -> Union[Array, Tuple[Array, Array]]: ...

@_wraps(np.average)
def average(a: ArrayLike, axis: Axis = None, weights: Optional[ArrayLike] = None,
            returned: bool = False, keepdims: bool = False) -> Union[Array, Tuple[Array, Array]]:
  return _average(a, _ensure_optional_axes(axis), weights, returned, keepdims)

@partial(api.jit, static_argnames=('axis', 'returned', 'keepdims'), inline=True)
def _average(a: ArrayLike, axis: Axis = None, weights: Optional[ArrayLike] = None,
             returned: bool = False, keepdims: bool = False) -> Union[Array, Tuple[Array, Array]]:
  if weights is None:  # Treat all weights as 1
    check_arraylike("average", a)
    a, = promote_dtypes_inexact(a)
    avg = mean(a, axis=axis, keepdims=keepdims)
    if axis is None:
      weights_sum = lax.full((), core.dimension_as_value(a.size), dtype=avg.dtype)
    elif isinstance(axis, tuple):
      weights_sum = lax.full_like(avg, math.prod(core.dimension_as_value(a.shape[d]) for d in axis))
    else:
      weights_sum = lax.full_like(avg, core.dimension_as_value(a.shape[axis]))  # type: ignore[index]
  else:
    check_arraylike("average", a, weights)
    a, weights = promote_dtypes_inexact(a, weights)
    a_shape = np.shape(a)
    a_ndim = len(a_shape)
    weights_shape = np.shape(weights)

    if axis is None:
      pass
    elif isinstance(axis, tuple):
      axis = tuple(_canonicalize_axis(d, a_ndim) for d in axis)
    else:
      axis = _canonicalize_axis(axis, a_ndim)

    if a_shape != weights_shape:
      # Make sure the dimensions work out
      if len(weights_shape) != 1:
        raise ValueError("1D weights expected when shapes of a and "
                         "weights differ.")
      if axis is None:
        raise ValueError("Axis must be specified when shapes of a and "
                         "weights differ.")
      elif isinstance(axis, tuple):
        raise ValueError("Single axis expected when shapes of a and weights differ")
      elif not core.symbolic_equal_dim(weights_shape[0], a_shape[axis]):
        raise ValueError("Length of weights not "
                         "compatible with specified axis.")
      weights = _broadcast_to(weights, (a_ndim - 1) * (1,) + weights_shape)
      weights = _moveaxis(weights, -1, axis)

    weights_sum = sum(weights, axis=axis, keepdims=keepdims)
    avg = sum(a * weights, axis=axis, keepdims=keepdims) / weights_sum

  if returned:
    if avg.shape != weights_sum.shape:
      weights_sum = _broadcast_to(weights_sum, avg.shape)
    return avg, weights_sum
  return avg

@_wraps(np.var, skip_params=['out'])
def var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None, out: None = None,
        ddof: int = 0, keepdims: bool = False, *,
        where: Optional[ArrayLike] = None) -> Array:
  return _var(a, _ensure_optional_axes(axis), dtype, out, ddof, keepdims, where=where)

@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def _var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None, out: None = None,
         ddof: int = 0, keepdims: bool = False, *,
         where: Optional[ArrayLike] = None) -> Array:
  check_arraylike("var", a)
  dtypes.check_user_dtype_supported(dtype, "var")
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.var is not supported.")

  computation_dtype, dtype = _var_promote_types(dtypes.dtype(a), dtype)
  a = lax_internal.asarray(a).astype(computation_dtype)
  a_mean = mean(a, axis, dtype=computation_dtype, keepdims=True, where=where)
  centered = lax.sub(a, a_mean)
  if dtypes.issubdtype(computation_dtype, np.complexfloating):
    centered = lax.real(lax.mul(centered, lax.conj(centered)))
    computation_dtype = centered.dtype  # avoid casting to complex below.
  else:
    centered = lax.square(centered)

  if where is None:
    if axis is None:
      normalizer = core.dimension_as_value(np.size(a))
    else:
      normalizer = core.dimension_as_value(_axis_size(a, axis))
    normalizer = lax.convert_element_type(normalizer, computation_dtype)
  else:
    normalizer = sum(_broadcast_to(where, np.shape(a)), axis,
                     dtype=computation_dtype, keepdims=keepdims)
  normalizer = lax.sub(normalizer, lax.convert_element_type(ddof, computation_dtype))

  result = sum(centered, axis, dtype=computation_dtype, keepdims=keepdims, where=where)
  return lax.div(result, normalizer).astype(dtype)

def _var_promote_types(a_dtype: DTypeLike, dtype: DTypeLike) -> Tuple[DType, DType]:
  if dtype:
    if (not dtypes.issubdtype(dtype, np.complexfloating) and
        dtypes.issubdtype(a_dtype, np.complexfloating)):
      msg = ("jax.numpy.var does not yet support real dtype parameters when "
             "computing the variance of an array of complex values. The "
             "semantics of numpy.var seem unclear in this case. Please comment "
             "on https://github.com/google/jax/issues/2283 if this behavior is "
             "important to you.")
      raise ValueError(msg)
    computation_dtype = dtype
  else:
    if not dtypes.issubdtype(a_dtype, np.inexact):
      dtype = dtypes.to_inexact_dtype(a_dtype)
      computation_dtype = dtype
    else:
      dtype = _complex_elem_type(a_dtype)
      computation_dtype = a_dtype
  return _upcast_f16(computation_dtype), np.dtype(dtype)

@_wraps(np.std, skip_params=['out'])
def std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None, out: None = None,
        ddof: int = 0, keepdims: bool = False, *,
        where: Optional[ArrayLike] = None) -> Array:
  return _std(a, _ensure_optional_axes(axis), dtype, out, ddof, keepdims, where=where)

@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def _std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None, out: None = None,
         ddof: int = 0, keepdims: bool = False, *,
         where: Optional[ArrayLike] = None) -> Array:
  check_arraylike("std", a)
  dtypes.check_user_dtype_supported(dtype, "std")
  if dtype is not None and not dtypes.issubdtype(dtype, np.inexact):
    raise ValueError(f"dtype argument to jnp.std must be inexact; got {dtype}")
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.std is not supported.")
  return lax.sqrt(var(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, where=where))

@_wraps(np.ptp, skip_params=['out'])
def ptp(a: ArrayLike, axis: Axis = None, out: None = None,
        keepdims: bool = False) -> Array:
  return _ptp(a, _ensure_optional_axes(axis), out, keepdims)

@partial(api.jit, static_argnames=('axis', 'keepdims'))
def _ptp(a: ArrayLike, axis: Axis = None, out: None = None,
         keepdims: bool = False) -> Array:
  check_arraylike("ptp", a)
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.ptp is not supported.")
  x = amax(a, axis=axis, keepdims=keepdims)
  y = amin(a, axis=axis, keepdims=keepdims)
  return lax.sub(x, y)

@_wraps(np.count_nonzero)
@partial(api.jit, static_argnames=('axis', 'keepdims'))
def count_nonzero(a: ArrayLike, axis: Axis = None,
                  keepdims: bool = False) -> Array:
  check_arraylike("count_nonzero", a)
  return sum(lax.ne(a, _lax_const(a, 0)), axis=axis,
             dtype=dtypes.canonicalize_dtype(np.int_), keepdims=keepdims)

def _nan_reduction(a: ArrayLike, name: str, jnp_reduction: Callable[..., Array],
                   init_val: ArrayLike, nan_if_all_nan: bool,
                   axis: Axis = None, keepdims: bool = False, **kwargs) -> Array:
  check_arraylike(name, a)
  if not dtypes.issubdtype(dtypes.dtype(a), np.inexact):
    return jnp_reduction(a, axis=axis, keepdims=keepdims, **kwargs)

  out = jnp_reduction(_where(lax_internal._isnan(a), _reduction_init_val(a, init_val), a),
                      axis=axis, keepdims=keepdims, **kwargs)
  if nan_if_all_nan:
    return _where(all(lax_internal._isnan(a), axis=axis, keepdims=keepdims),
                  _lax_const(a, np.nan), out)
  else:
    return out

@_wraps(np.nanmin, skip_params=['out'])
@partial(api.jit, static_argnames=('axis', 'keepdims'))
def nanmin(a: ArrayLike, axis: Axis = None, out: None = None,
           keepdims: bool = False, initial: Optional[ArrayLike] = None,
           where: Optional[ArrayLike] = None) -> Array:
  return _nan_reduction(a, 'nanmin', min, np.inf, nan_if_all_nan=initial is None,
                        axis=axis, out=out, keepdims=keepdims,
                        initial=initial, where=where)

@_wraps(np.nanmax, skip_params=['out'])
@partial(api.jit, static_argnames=('axis', 'keepdims'))
def nanmax(a: ArrayLike, axis: Axis = None, out: None = None,
           keepdims: bool = False, initial: Optional[ArrayLike] = None,
           where: Optional[ArrayLike] = None) -> Array:
  return _nan_reduction(a, 'nanmax', max, -np.inf, nan_if_all_nan=initial is None,
                        axis=axis, out=out, keepdims=keepdims,
                        initial=initial, where=where)

@_wraps(np.nansum, skip_params=['out'])
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def nansum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None, out: None = None,
           keepdims: bool = False, initial: Optional[ArrayLike] = None,
           where: Optional[ArrayLike] = None) -> Array:
  dtypes.check_user_dtype_supported(dtype, "nanprod")
  return _nan_reduction(a, 'nansum', sum, 0, nan_if_all_nan=False,
                        axis=axis, dtype=dtype, out=out, keepdims=keepdims,
                        initial=initial, where=where)

# Work around a sphinx documentation warning in NumPy 1.22.
if nansum.__doc__ is not None:
  nansum.__doc__ = nansum.__doc__.replace("\n\n\n", "\n\n")

@_wraps(np.nanprod, skip_params=['out'])
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def nanprod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None, out: None = None,
            keepdims: bool = False, initial: Optional[ArrayLike] = None,
            where: Optional[ArrayLike] = None) -> Array:
  dtypes.check_user_dtype_supported(dtype, "nanprod")
  return _nan_reduction(a, 'nanprod', prod, 1, nan_if_all_nan=False,
                        axis=axis, dtype=dtype, out=out, keepdims=keepdims,
                        initial=initial, where=where)

@_wraps(np.nanmean, skip_params=['out'])
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def nanmean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None, out: None = None,
            keepdims: bool = False, where: Optional[ArrayLike] = None) -> Array:
  check_arraylike("nanmean", a)
  dtypes.check_user_dtype_supported(dtype, "nanmean")
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.nanmean is not supported.")
  if dtypes.issubdtype(dtypes.dtype(a), np.bool_) or dtypes.issubdtype(dtypes.dtype(a), np.integer):
    return mean(a, axis, dtype, out, keepdims, where=where)
  if dtype is None:
    dtype = dtypes.dtype(a)
  nan_mask = lax_internal.bitwise_not(lax_internal._isnan(a))
  normalizer = sum(nan_mask, axis=axis, dtype=np.int32, keepdims=keepdims, where=where)
  normalizer = lax.convert_element_type(normalizer, dtype)
  td = lax.div(nansum(a, axis, dtype=dtype, keepdims=keepdims, where=where), normalizer)
  return td

@_wraps(np.nanvar, skip_params=['out'])
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def nanvar(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None, out: None = None,
           ddof: int = 0, keepdims: bool = False,
           where: Optional[ArrayLike] = None) -> Array:
  check_arraylike("nanvar", a)
  dtypes.check_user_dtype_supported(dtype, "nanvar")
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.nanvar is not supported.")

  computation_dtype, dtype = _var_promote_types(dtypes.dtype(a), dtype)
  a = lax_internal.asarray(a).astype(computation_dtype)
  a_mean = nanmean(a, axis, dtype=computation_dtype, keepdims=True, where=where)

  centered = _where(lax_internal._isnan(a), 0, lax.sub(a, a_mean))  # double-where trick for gradients.
  if dtypes.issubdtype(centered.dtype, np.complexfloating):
    centered = lax.real(lax.mul(centered, lax.conj(centered)))
  else:
    centered = lax.square(centered)

  normalizer = sum(lax_internal.bitwise_not(lax_internal._isnan(a)),
                   axis=axis, keepdims=keepdims, where=where)
  normalizer = normalizer - ddof
  normalizer_mask = lax.le(normalizer, lax_internal._zero(normalizer))
  result = sum(centered, axis, keepdims=keepdims, where=where)
  result = _where(normalizer_mask, np.nan, result)
  divisor = _where(normalizer_mask, 1, normalizer)
  result = lax.div(result, lax.convert_element_type(divisor, result.dtype))
  return lax.convert_element_type(result, dtype)

@_wraps(np.nanstd, skip_params=['out'])
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def nanstd(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None, out: None = None,
           ddof: int = 0, keepdims: bool = False,
           where: Optional[ArrayLike] = None) -> Array:
  check_arraylike("nanstd", a)
  dtypes.check_user_dtype_supported(dtype, "nanstd")
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.nanstd is not supported.")
  return lax.sqrt(nanvar(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, where=where))

class CumulativeReduction(Protocol):
  def __call__(self, a: ArrayLike, axis: Axis = None,
               dtype: DTypeLike = None, out: None = None) -> Array: ...

def _make_cumulative_reduction(np_reduction: Any, reduction: Callable[..., Array],
                               fill_nan: bool = False, fill_value: ArrayLike = 0) -> CumulativeReduction:
  @_wraps(np_reduction, skip_params=['out'])
  def cumulative_reduction(a: ArrayLike, axis: Axis = None,
                           dtype: DTypeLike = None, out: None = None) -> Array:
    return _cumulative_reduction(a, _ensure_optional_axes(axis), dtype, out)

  @partial(api.jit, static_argnames=('axis', 'dtype'))
  def _cumulative_reduction(a: ArrayLike, axis: Axis = None,
                            dtype: DTypeLike = None, out: None = None) -> Array:
    check_arraylike(np_reduction.__name__, a)
    if out is not None:
      raise NotImplementedError(f"The 'out' argument to jnp.{np_reduction.__name__} "
                                f"is not supported.")
    dtypes.check_user_dtype_supported(dtype, np_reduction.__name__)

    if axis is None or _isscalar(a):
      a = lax.reshape(a, (np.size(a),))
    if axis is None:
      axis = 0

    a_shape = list(np.shape(a))
    num_dims = len(a_shape)
    axis = _canonicalize_axis(axis, num_dims)

    if fill_nan:
      a = _where(lax_internal._isnan(a), _lax_const(a, fill_value), a)

    if not dtype and dtypes.dtype(a) == np.bool_:
      dtype = dtypes.canonicalize_dtype(dtypes.int_)
    if dtype:
      a = lax.convert_element_type(a, dtype)

    return reduction(a, axis)

  return cumulative_reduction

cumsum = _make_cumulative_reduction(np.cumsum, lax.cumsum, fill_nan=False)
cumprod = _make_cumulative_reduction(np.cumprod, lax.cumprod, fill_nan=False)
nancumsum = _make_cumulative_reduction(np.nancumsum, lax.cumsum,
                                       fill_nan=True, fill_value=0)
nancumprod = _make_cumulative_reduction(np.nancumprod, lax.cumprod,
                                        fill_nan=True, fill_value=1)

# Quantiles

@_wraps(np.quantile, skip_params=['out', 'overwrite_input'])
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
                                   'keepdims', 'method'))
def quantile(a: ArrayLike, q: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None,
             out: None = None, overwrite_input: bool = False, method: str = "linear",
             keepdims: bool = False, interpolation: None = None) -> Array:
  check_arraylike("quantile", a, q)
  if overwrite_input or out is not None:
    msg = ("jax.numpy.quantile does not support overwrite_input=True or "
           "out != None")
    raise ValueError(msg)
  if interpolation is not None:
    warnings.warn("The interpolation= argument to 'quantile' is deprecated. "
                  "Use 'method=' instead.", DeprecationWarning)
  return _quantile(lax_internal.asarray(a), lax_internal.asarray(q), axis,
                   interpolation or method, keepdims, False)

@_wraps(np.nanquantile, skip_params=['out', 'overwrite_input'])
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
                                   'keepdims', 'method'))
def nanquantile(a: ArrayLike, q: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None,
                out: None = None, overwrite_input: bool = False, method: str = "linear",
                keepdims: bool = False, interpolation: None = None) -> Array:
  check_arraylike("nanquantile", a, q)
  if overwrite_input or out is not None:
    msg = ("jax.numpy.nanquantile does not support overwrite_input=True or "
           "out != None")
    raise ValueError(msg)
  if interpolation is not None:
    warnings.warn("The interpolation= argument to 'nanquantile' is deprecated. "
                  "Use 'method=' instead.", DeprecationWarning)
  return _quantile(lax_internal.asarray(a), lax_internal.asarray(q), axis,
                   interpolation or method, keepdims, True)

def _quantile(a: Array, q: Array, axis: Optional[Union[int, Tuple[int, ...]]],
              interpolation: str, keepdims: bool, squash_nans: bool) -> Array:
  if interpolation not in ["linear", "lower", "higher", "midpoint", "nearest"]:
    raise ValueError("interpolation can only be 'linear', 'lower', 'higher', "
                     "'midpoint', or 'nearest'")
  a, = promote_dtypes_inexact(a)
  keepdim = []
  if dtypes.issubdtype(a.dtype, np.complexfloating):
    raise ValueError("quantile does not support complex input, as the operation is poorly defined.")
  if axis is None:
    a = a.ravel()
    axis = 0
  elif isinstance(axis, tuple):
    keepdim = list(a.shape)
    nd = a.ndim
    axis = tuple(_canonicalize_axis(ax, nd) for ax in axis)
    if len(set(axis)) != len(axis):
      raise ValueError('repeated axis')
    for ax in axis:
      keepdim[ax] = 1

    keep = set(range(nd)) - set(axis)
    # prepare permutation
    dimensions = list(range(nd))
    for i, s in enumerate(sorted(keep)):
      dimensions[i], dimensions[s] = dimensions[s], dimensions[i]
    do_not_touch_shape = tuple(x for idx, x in enumerate(a.shape) if idx not in axis)
    touch_shape = tuple(x for idx, x in enumerate(a.shape) if idx in axis)
    a = lax.reshape(a, do_not_touch_shape + (math.prod(touch_shape),), dimensions)
    axis = _canonicalize_axis(-1, a.ndim)
  else:
    axis = _canonicalize_axis(axis, a.ndim)

  q_shape = q.shape
  q_ndim = q.ndim
  if q_ndim > 1:
    raise ValueError(f"q must have rank <= 1, got shape {q.shape}")

  a_shape = a.shape

  if squash_nans:
    a = _where(ufuncs.isnan(a), np.nan, a)  # Ensure nans are positive so they sort to the end.
    a = lax.sort(a, dimension=axis)
    counts = sum(ufuncs.logical_not(ufuncs.isnan(a)), axis=axis, dtype=q.dtype,
                 keepdims=keepdims)
    shape_after_reduction = counts.shape
    q = lax.expand_dims(
      q, tuple(range(q_ndim, len(shape_after_reduction) + q_ndim)))
    counts = lax.expand_dims(counts, tuple(range(q_ndim)))
    q = lax.mul(q, lax.sub(counts, _lax_const(q, 1)))
    low = lax.floor(q)
    high = lax.ceil(q)
    high_weight = lax.sub(q, low)
    low_weight = lax.sub(_lax_const(high_weight, 1), high_weight)

    low = lax.max(_lax_const(low, 0), lax.min(low, counts - 1))
    high = lax.max(_lax_const(high, 0), lax.min(high, counts - 1))
    low = lax.convert_element_type(low, int)
    high = lax.convert_element_type(high, int)
    out_shape = q_shape + shape_after_reduction
    index = [lax.broadcasted_iota(int, out_shape, dim + q_ndim)
             for dim in range(len(shape_after_reduction))]
    if keepdims:
      index[axis] = low
    else:
      index.insert(axis, low)
    low_value = a[tuple(index)]
    index[axis] = high
    high_value = a[tuple(index)]
  else:
    a = _where(any(ufuncs.isnan(a), axis=axis, keepdims=True), np.nan, a)
    a = lax.sort(a, dimension=axis)
    n = lax.convert_element_type(a_shape[axis], lax_internal._dtype(q))
    q = lax.mul(q, n - 1)
    low = lax.floor(q)
    high = lax.ceil(q)
    high_weight = lax.sub(q, low)
    low_weight = lax.sub(_lax_const(high_weight, 1), high_weight)

    low = lax.clamp(_lax_const(low, 0), low, n - 1)
    high = lax.clamp(_lax_const(high, 0), high, n - 1)
    low = lax.convert_element_type(low, int)
    high = lax.convert_element_type(high, int)

    slice_sizes = list(a_shape)
    slice_sizes[axis] = 1
    dnums = lax.GatherDimensionNumbers(
      offset_dims=tuple(range(
        q_ndim,
        len(a_shape) + q_ndim if keepdims else len(a_shape) + q_ndim - 1)),
      collapsed_slice_dims=() if keepdims else (axis,),
      start_index_map=(axis,))
    low_value = lax.gather(a, low[..., None], dimension_numbers=dnums,
                           slice_sizes=slice_sizes)
    high_value = lax.gather(a, high[..., None], dimension_numbers=dnums,
                            slice_sizes=slice_sizes)

  if q_ndim == 1:
    low_weight = lax.broadcast_in_dim(low_weight, low_value.shape,
                                      broadcast_dimensions=(0,))
    high_weight = lax.broadcast_in_dim(high_weight, high_value.shape,
                                       broadcast_dimensions=(0,))

  if interpolation == "linear":
    result = lax.add(lax.mul(low_value.astype(q.dtype), low_weight),
                     lax.mul(high_value.astype(q.dtype), high_weight))
  elif interpolation == "lower":
    result = low_value
  elif interpolation == "higher":
    result = high_value
  elif interpolation == "nearest":
    pred = lax.le(high_weight, _lax_const(high_weight, 0.5))
    result = lax.select(pred, low_value, high_value)
  elif interpolation == "midpoint":
    result = lax.mul(lax.add(low_value, high_value), _lax_const(low_value, 0.5))
  else:
    raise ValueError(f"interpolation={interpolation!r} not recognized")
  if keepdims and keepdim:
    if q_ndim > 0:
      keepdim = [np.shape(q)[0], *keepdim]
    result = result.reshape(keepdim)
  return lax.convert_element_type(result, a.dtype)

@_wraps(np.percentile, skip_params=['out', 'overwrite_input'])
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
                                   'keepdims', 'method'))
def percentile(a: ArrayLike, q: ArrayLike,
               axis: Optional[Union[int, Tuple[int, ...]]] = None,
               out: None = None, overwrite_input: bool = False, method: str = "linear",
               keepdims: bool = False, interpolation: None = None) -> Array:
  check_arraylike("percentile", a, q)
  q, = promote_dtypes_inexact(q)
  return quantile(a, q / 100, axis=axis, out=out, overwrite_input=overwrite_input,
                  interpolation=interpolation, method=method, keepdims=keepdims)

@_wraps(np.nanpercentile, skip_params=['out', 'overwrite_input'])
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
                                   'keepdims', 'method'))
def nanpercentile(a: ArrayLike, q: ArrayLike,
                  axis: Optional[Union[int, Tuple[int, ...]]] = None,
                  out: None = None, overwrite_input: bool = False, method: str = "linear",
                  keepdims: bool = False, interpolation: None = None) -> Array:
  check_arraylike("nanpercentile", a, q)
  q = ufuncs.true_divide(q, 100.0)
  return nanquantile(a, q, axis=axis, out=out, overwrite_input=overwrite_input,
                     interpolation=interpolation, method=method, keepdims=keepdims)

@_wraps(np.median, skip_params=['out', 'overwrite_input'])
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'keepdims'))
def median(a: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None,
           out: None = None, overwrite_input: bool = False,
           keepdims: bool = False) -> Array:
  check_arraylike("median", a)
  return quantile(a, 0.5, axis=axis, out=out, overwrite_input=overwrite_input,
                  keepdims=keepdims, method='midpoint')

@_wraps(np.nanmedian, skip_params=['out', 'overwrite_input'])
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'keepdims'))
def nanmedian(a: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None,
              out: None = None, overwrite_input: bool = False,
              keepdims: bool = False) -> Array:
  check_arraylike("nanmedian", a)
  return nanquantile(a, 0.5, axis=axis, out=out, overwrite_input=overwrite_input,
                     keepdims=keepdims, method='midpoint')