Source code for jax._src.numpy.ufuncs

# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pytype: skip-file
"""
Implements ufuncs for jax.numpy.
"""

from functools import partial
import operator
from textwrap import dedent
from typing import Any, Callable, Tuple, Union, overload

import numpy as np

from jax._src.api import jit, custom_jvp
from jax._src import dtypes
from jax._src.lax import lax as lax_internal
from jax._src.typing import Array, ArrayLike
from jax._src.numpy.util import (
   _asarray, _check_arraylike, _promote_args, _promote_args_inexact,
   _promote_args_numeric, _promote_dtypes_inexact, _promote_dtypes_numeric,
   _promote_shapes, _where, _wraps)
from jax import core
from jax import lax

# Module-local shorthand for the internal lax helper ``_const``; it is called
# throughout this file as ``_lax_const(example, value)``.
# NOTE(review): presumably builds ``value`` with the dtype of ``example`` --
# confirm against lax_internal._const.
_lax_const = lax_internal._const

_INT_DTYPES = {
  16: np.int16,
  32: np.int32,
  64: np.int64,
}

# Type of a unary elementwise function: one array-like in, one Array out.
UnOp = Callable[[ArrayLike], Array]
# Type of a binary elementwise function: two array-likes in, one Array out.
BinOp = Callable[[ArrayLike, ArrayLike], Array]


def _constant_like(x, const):
  return np.array(const, dtype=dtypes.dtype(x))


def _replace_inf(x: ArrayLike) -> Array:
  """Return ``x`` with every entry whose real part is +inf replaced by zero."""
  real_is_posinf = isposinf(real(x))
  return lax.select(real_is_posinf, lax_internal._zeros(x), x)


def _one_to_one_unop(
    numpy_fn: Callable[..., Any], lax_fn: UnOp,
    promote_to_inexact: bool = False, lax_doc: bool = False) -> UnOp:
  """Build a jax.numpy unary function that forwards directly to ``lax_fn``.

  Args:
    numpy_fn: NumPy function being mirrored; supplies the name and is handed
      to ``_wraps``.
    lax_fn: lax primitive applied to the promoted argument.
    promote_to_inexact: promote with ``_promote_args_inexact`` instead of
      ``_promote_args``.
    lax_doc: if true, pass the lax docstring minus its first paragraph to
      ``_wraps`` as ``lax_description``.

  Returns:
    The jitted (``inline=True``) and wrapped function.
  """
  promote = _promote_args_inexact if promote_to_inexact else _promote_args
  fn = lambda x, /: lax_fn(*promote(numpy_fn.__name__, x))
  fn.__qualname__ = f"jax.numpy.{numpy_fn.__name__}"
  fn = jit(fn, inline=True)
  if not lax_doc:
    return _wraps(numpy_fn, module='numpy')(fn)
  # Everything after the first blank-line-separated paragraph of the lax doc.
  paragraphs = lax_fn.__doc__.split('\n\n')[1:]  # type: ignore[union-attr]
  doc = dedent('\n\n'.join(paragraphs)).strip()
  return _wraps(numpy_fn, lax_description=doc, module='numpy')(fn)


def _one_to_one_binop(
    numpy_fn: Callable[..., Any], lax_fn: BinOp,
    promote_to_inexact: bool = False, lax_doc: bool = False,
    promote_to_numeric: bool = False) -> BinOp:
  """Build a jax.numpy binary function that forwards directly to ``lax_fn``.

  Args:
    numpy_fn: NumPy function being mirrored; supplies the name and is handed
      to ``_wraps``.
    lax_fn: lax primitive applied to the two promoted arguments.
    promote_to_inexact: promote with ``_promote_args_inexact``; takes
      precedence over ``promote_to_numeric``.
    lax_doc: if true, pass the lax docstring minus its first paragraph to
      ``_wraps`` as ``lax_description``.
    promote_to_numeric: promote with ``_promote_args_numeric``.

  Returns:
    The jitted (``inline=True``) and wrapped function.
  """
  if promote_to_inexact:
    promote = _promote_args_inexact
  elif promote_to_numeric:
    promote = _promote_args_numeric
  else:
    promote = _promote_args
  fn = lambda x1, x2, /: lax_fn(*promote(numpy_fn.__name__, x1, x2))
  fn.__qualname__ = f"jax.numpy.{numpy_fn.__name__}"
  fn = jit(fn, inline=True)
  if not lax_doc:
    return _wraps(numpy_fn, module='numpy')(fn)
  # Everything after the first blank-line-separated paragraph of the lax doc.
  paragraphs = lax_fn.__doc__.split('\n\n')[1:]  # type: ignore[union-attr]
  doc = dedent('\n\n'.join(paragraphs)).strip()
  return _wraps(numpy_fn, lax_description=doc, module='numpy')(fn)


def _maybe_bool_binop(
    numpy_fn: Callable[..., Any], lax_fn: BinOp, bool_lax_fn: BinOp,
    lax_doc: bool = False) -> BinOp:
  """Build a binary function that uses ``bool_lax_fn`` for boolean operands.

  After promotion with ``_promote_args``, operands whose dtype is ``bool`` are
  sent to ``bool_lax_fn``; every other dtype is sent to ``lax_fn``.

  Args:
    numpy_fn: NumPy function being mirrored; supplies the name and is handed
      to ``_wraps``.
    lax_fn: lax primitive used for non-boolean operands.
    bool_lax_fn: lax primitive used for boolean operands.
    lax_doc: if true, pass the docstring of ``lax_fn`` minus its first
      paragraph to ``_wraps`` as ``lax_description``.
  """
  def fn(x1, x2, /):
    x1, x2 = _promote_args(numpy_fn.__name__, x1, x2)
    if x1.dtype == np.bool_:
      return bool_lax_fn(x1, x2)
    return lax_fn(x1, x2)
  fn.__qualname__ = f"jax.numpy.{numpy_fn.__name__}"
  fn = jit(fn, inline=True)
  if not lax_doc:
    return _wraps(numpy_fn, module='numpy')(fn)
  # Everything after the first blank-line-separated paragraph of the lax doc.
  paragraphs = lax_fn.__doc__.split('\n\n')[1:]  # type: ignore[union-attr]
  doc = dedent('\n\n'.join(paragraphs)).strip()
  return _wraps(numpy_fn, lax_description=doc, module='numpy')(fn)


def _comparison_op(numpy_fn: Callable[..., Any], lax_fn: BinOp) -> BinOp:
  """Build an ordering comparison (e.g. ``less``) on top of ``lax_fn``.

  Real inputs are compared with ``lax_fn`` directly. Complex inputs are
  ordered lexicographically on the (real, imag) pair.
  """
  def fn(x1, x2, /):
    x1, x2 = _promote_args(numpy_fn.__name__, x1, x2)
    if not dtypes.issubdtype(dtypes.dtype(x1), np.complexfloating):
      return lax_fn(x1, x2)
    # Complex: compare imaginary parts where the real parts tie, otherwise
    # compare the real parts.
    re1 = lax.real(x1)
    re2 = lax.real(x2)
    reals_tie = lax.eq(re1, re2)
    imag_cmp = lax_fn(lax.imag(x1), lax.imag(x2))
    real_cmp = lax_fn(re1, re2)
    return lax.select(reals_tie, imag_cmp, real_cmp)
  fn.__qualname__ = f"jax.numpy.{numpy_fn.__name__}"
  fn = jit(fn, inline=True)
  return _wraps(numpy_fn, module='numpy')(fn)

@overload
def _logical_op(np_op: Callable[..., Any], bitwise_op: UnOp) -> UnOp: ...
@overload
def _logical_op(np_op: Callable[..., Any], bitwise_op: BinOp) -> BinOp: ...
@overload
def _logical_op(np_op: Callable[..., Any], bitwise_op: Union[UnOp, BinOp]) -> Union[UnOp, BinOp]: ...

def _logical_op(np_op: Callable[..., Any], bitwise_op: Union[UnOp, BinOp]) -> Union[UnOp, BinOp]:
  """Build a logical operation from the corresponding bitwise lax operation.

  Each non-boolean argument is first turned into a boolean by comparing it
  against zero (``x != 0``); ``bitwise_op`` is then applied to the promoted
  booleans.
  """
  def as_bool(x):
    # Boolean inputs pass through untouched; anything else becomes `x != 0`.
    if dtypes.issubdtype(dtypes.dtype(x), np.bool_):
      return x
    return lax.ne(x, lax.full_like(x, shape=(), fill_value=0))

  @_wraps(np_op, update_doc=False, module='numpy')
  @partial(jit, inline=True)
  def op(*args):
    truth_values = [as_bool(arg) for arg in args]
    return bitwise_op(*_promote_args(np_op.__name__, *truth_values))
  return op


# Unary ufuncs that map one-to-one onto a lax primitive. Those created with
# promote_to_inexact=True promote their argument via _promote_args_inexact.
fabs = _one_to_one_unop(np.fabs, lax.abs, promote_to_inexact=True)
bitwise_not = _one_to_one_unop(np.bitwise_not, lax.bitwise_not)
invert = _one_to_one_unop(np.invert, lax.bitwise_not)
negative = _one_to_one_unop(np.negative, lax.neg)
positive = _one_to_one_unop(np.positive, lambda x: _asarray(x))
floor = _one_to_one_unop(np.floor, lax.floor, promote_to_inexact=True)
ceil = _one_to_one_unop(np.ceil, lax.ceil, promote_to_inexact=True)
exp = _one_to_one_unop(np.exp, lax.exp, promote_to_inexact=True)
log = _one_to_one_unop(np.log, lax.log, promote_to_inexact=True)
expm1 = _one_to_one_unop(np.expm1, lax.expm1, promote_to_inexact=True)
log1p = _one_to_one_unop(np.log1p, lax.log1p, promote_to_inexact=True)
sin = _one_to_one_unop(np.sin, lax.sin, promote_to_inexact=True)
cos = _one_to_one_unop(np.cos, lax.cos, promote_to_inexact=True)
tan = _one_to_one_unop(np.tan, lax.tan, promote_to_inexact=True)
arcsin = _one_to_one_unop(np.arcsin, lax.asin, promote_to_inexact=True)
arccos = _one_to_one_unop(np.arccos, lax.acos, promote_to_inexact=True)
arctan = _one_to_one_unop(np.arctan, lax.atan, promote_to_inexact=True)
sinh = _one_to_one_unop(np.sinh, lax.sinh, promote_to_inexact=True)
cosh = _one_to_one_unop(np.cosh, lax.cosh, promote_to_inexact=True)
arcsinh = _one_to_one_unop(np.arcsinh, lax.asinh, promote_to_inexact=True)
tanh = _one_to_one_unop(np.tanh, lax.tanh, promote_to_inexact=True)
arctanh = _one_to_one_unop(np.arctanh, lax.atanh, promote_to_inexact=True)
sqrt = _one_to_one_unop(np.sqrt, lax.sqrt, promote_to_inexact=True)
cbrt = _one_to_one_unop(np.cbrt, lax.cbrt, promote_to_inexact=True)

# Binary ufuncs that map one-to-one onto a lax primitive. `add` and `multiply`
# switch to bitwise or / bitwise and when the promoted operands are boolean.
add = _maybe_bool_binop(np.add, lax.add, bool_lax_fn=lax.bitwise_or)
bitwise_and = _one_to_one_binop(np.bitwise_and, lax.bitwise_and)
bitwise_or = _one_to_one_binop(np.bitwise_or, lax.bitwise_or)
bitwise_xor = _one_to_one_binop(np.bitwise_xor, lax.bitwise_xor)
left_shift = _one_to_one_binop(np.left_shift, lax.shift_left, promote_to_numeric=True)
equal = _one_to_one_binop(np.equal, lax.eq)
multiply = _maybe_bool_binop(np.multiply, lax.mul, bool_lax_fn=lax.bitwise_and)
not_equal = _one_to_one_binop(np.not_equal, lax.ne)
subtract = _one_to_one_binop(np.subtract, lax.sub)
arctan2 = _one_to_one_binop(np.arctan2, lax.atan2, promote_to_inexact=True)
minimum = _one_to_one_binop(np.minimum, lax.min)
maximum = _one_to_one_binop(np.maximum, lax.max)
float_power = _one_to_one_binop(np.float_power, lax.pow, promote_to_inexact=True)
nextafter = _one_to_one_binop(np.nextafter, lax.nextafter,
                              promote_to_inexact=True, lax_doc=True)

# Ordering comparisons, built with _comparison_op (which adds a lexicographic
# (real, imag) ordering for complex inputs).
greater_equal = _comparison_op(np.greater_equal, lax.ge)
greater = _comparison_op(np.greater, lax.gt)
less_equal = _comparison_op(np.less_equal, lax.le)
less = _comparison_op(np.less, lax.lt)

# Logical operations, built with _logical_op from the matching bitwise lax
# operation (non-boolean inputs are first compared against zero).
logical_and = _logical_op(np.logical_and, lax.bitwise_and)
logical_not = _logical_op(np.logical_not, lax.bitwise_not)
logical_or = _logical_op(np.logical_or, lax.bitwise_or)
logical_xor = _logical_op(np.logical_xor, lax.bitwise_xor)


@_wraps(np.arccosh, module='numpy')
@jit
def arccosh(x: ArrayLike, /) -> Array:
  # Inverse hyperbolic cosine of the input, promoted to an inexact dtype.
  x, = _promote_args_inexact("arccosh", x)
  result = lax.acosh(x)
  if not dtypes.issubdtype(result.dtype, np.complexfloating):
    return result
  # Note: arccosh is multi-valued for complex input, and lax.acosh uses a
  # different convention than np.arccosh. Negate results whose real part is
  # negative.
  return _where(real(result) < 0, lax.neg(result), result)
@_wraps(np.right_shift, module='numpy')
@partial(jit, inline=True)
def right_shift(x1: ArrayLike, x2: ArrayLike, /) -> Array:
  # Unsigned operands use a logical shift; all other dtypes use an arithmetic
  # shift.
  x1, x2 = _promote_args_numeric(np.right_shift.__name__, x1, x2)
  if np.issubdtype(x1.dtype, np.unsignedinteger):
    return lax.shift_right_logical(x1, x2)
  return lax.shift_right_arithmetic(x1, x2)
@_wraps(np.absolute, module='numpy')
@partial(jit, inline=True)
def absolute(x: ArrayLike, /) -> Array:
  # Boolean and unsigned inputs are returned as arrays without calling
  # lax.abs; every other dtype goes through lax.abs.
  _check_arraylike('absolute', x)
  dtype = dtypes.dtype(x)
  skip_abs = dtype == np.bool_ or dtypes.issubdtype(dtype, np.unsignedinteger)
  if skip_abs:
    return _asarray(x)
  return lax.abs(x)
# Public `abs`: `absolute` passed through `_wraps` against np.abs.
# NOTE: this module-level name shadows the builtin `abs` inside this module;
# code here calls `lax.abs` explicitly.
abs = _wraps(np.abs, module='numpy')(absolute)
@_wraps(np.rint, module='numpy')
@jit
def rint(x: ArrayLike, /) -> Array:
  # Round to the nearest integer, ties to even.
  _check_arraylike('rint', x)
  dtype = dtypes.dtype(x)
  # Booleans and integers are already integral: only convert to float.
  if dtype == bool or dtypes.issubdtype(dtype, np.integer):
    return lax.convert_element_type(x, dtypes.float_)
  # Complex values are rounded component-wise.
  if dtypes.issubdtype(dtype, np.complexfloating):
    rounded_real = rint(lax.real(x))
    rounded_imag = rint(lax.imag(x))
    return lax.complex(rounded_real, rounded_imag)
  return lax.round(x, lax.RoundingMethod.TO_NEAREST_EVEN)
@_wraps(np.sign, module='numpy')
@jit
def sign(x: ArrayLike, /) -> Array:
  _check_arraylike('sign', x)
  if not dtypes.issubdtype(dtypes.dtype(x), np.complexfloating):
    return lax.sign(x)
  # Complex input: the result is sign(real part), falling back to
  # sign(imaginary part) where the real part is zero, with a zero imaginary
  # component.
  real_part = lax.real(x)
  deciding = _where(real_part != 0, real_part, lax.imag(x))
  return lax.complex(lax.sign(deciding), _constant_like(real_part, 0))
@_wraps(np.copysign, module='numpy')
@jit
def copysign(x1: ArrayLike, x2: ArrayLike, /) -> Array:
  # Magnitude of x1 combined with the sign bit of x2. Raises TypeError for
  # complex inputs.
  x1, x2 = _promote_args_inexact("copysign", x1, x2)
  if dtypes.issubdtype(dtypes.dtype(x1), np.complexfloating):
    raise TypeError("copysign does not support complex-valued inputs")
  take_negative = signbit(x2).astype(bool)
  magnitude = lax.abs(x1)
  return _where(take_negative, -magnitude, magnitude)
@_wraps(np.true_divide, module='numpy')
@partial(jit, inline=True)
def true_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array:
  # Division after promoting both operands to a common inexact dtype.
  numerator, denominator = _promote_args_inexact("true_divide", x1, x2)
  return lax.div(numerator, denominator)
# `divide` is an alias of `true_divide`.
divide = true_divide
@_wraps(np.floor_divide, module='numpy')
@jit
def floor_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array:
  x1, x2 = _promote_args_numeric("floor_divide", x1, x2)
  dtype = dtypes.dtype(x1)

  if dtypes.issubdtype(dtype, np.integer):
    # lax.div truncates; subtract one where the signs differ and the division
    # is inexact so that the result is floored instead.
    quotient = lax.div(x1, x2)
    needs_fixup = logical_and(lax.sign(x1) != lax.sign(x2),
                              lax.rem(x1, x2) != 0)
    # TODO(mattjj): investigate why subtracting a scalar was causing promotion
    return _where(needs_fixup, quotient - 1, quotient)

  if dtypes.issubdtype(dtype, np.complexfloating):
    # floor(Re(x1 / x2)), with numerator and denominator both scaled by the
    # ratio of the smaller to the larger component of x2.
    x1_re = lax.real(x1)
    x1_im = lax.imag(x1)
    x2_re = lax.real(x2)
    x2_im = lax.imag(x2)
    real_dominates = lax.ge(lax.abs(x2_re), lax.abs(x2_im))
    scale_re = _where(real_dominates, lax.full_like(x2_im, 1), lax.div(x2_re, x2_im))
    scale_im = _where(real_dominates, lax.div(x2_im, x2_re), _lax_const(x2_im, 1))
    numerator = lax.add(lax.mul(x1_re, scale_re), lax.mul(x1_im, scale_im))
    denominator = lax.add(lax.mul(x2_re, scale_re), lax.mul(x2_im, scale_im))
    floored = lax.floor(lax.div(numerator, denominator))
    return lax.convert_element_type(floored, dtype)

  # Remaining dtypes: quotient half of the shared divmod helper.
  return _float_divmod(x1, x2)[0]
@_wraps(np.divmod, module='numpy')
@jit
def divmod(x1: ArrayLike, x2: ArrayLike, /) -> Tuple[Array, Array]:
  # Returns (floor quotient, remainder). Integer dtypes reuse floor_divide and
  # remainder; everything else goes through _float_divmod.
  x1, x2 = _promote_args_numeric("divmod", x1, x2)
  if not dtypes.issubdtype(dtypes.dtype(x1), np.integer):
    return _float_divmod(x1, x2)
  return floor_divide(x1, x2), remainder(x1, x2)
def _float_divmod(x1: ArrayLike, x2: ArrayLike) -> Tuple[Array, Array]:
  """Return ``(quotient, remainder)`` where the remainder takes the sign of ``x2``.

  See float_divmod in floatobject.c of CPython.
  """
  remainder_ = lax.rem(x1, x2)
  quotient = lax.div(lax.sub(x1, remainder_), x2)

  # lax.rem takes the sign of the dividend; shift by one period where the
  # remainder is non-zero and its sign differs from the divisor's.
  needs_fixup = lax.bitwise_and(remainder_ != 0,
                                lax.sign(x2) != lax.sign(remainder_))
  remainder_ = lax.select(needs_fixup, remainder_ + x2, remainder_)
  quotient = lax.select(needs_fixup, quotient - _constant_like(quotient, 1), quotient)

  return lax.round(quotient), remainder_


@partial(jit, inline=True)
def _power(x1: ArrayLike, x2: ArrayLike) -> Array:
  """Elementwise ``x1 ** x2``; integer dtypes use binary exponentiation."""
  x1, x2 = _promote_args_numeric("power", x1, x2)
  if not dtypes.issubdtype(dtypes.dtype(x1), np.integer):
    return lax.pow(x1, x2)

  # Integer power => use binary exponentiation.
  # TODO(phawkins): add integer pow support to XLA.
  num_bits = 6  # Anything more would overflow for any x1 > 1
  zero = _constant_like(x2, 0)
  one = _constant_like(x2, 1)
  # Initialize the accumulator carefully such that pow(0, x2) is zero for
  # x2 != 0
  zero_base_nonzero_exp = lax.bitwise_and(lax.eq(x1, zero), lax.ne(x2, zero))
  result = _where(zero_base_nonzero_exp, zero, one)
  for _ in range(num_bits):
    # Multiply in the current power of the base when the low exponent bit is
    # set, then square the base and shift the exponent right by one.
    result = _where(lax.bitwise_and(x2, one), lax.mul(result, x1), result)
    x1 = lax.mul(x1, x1)
    x2 = lax.shift_right_logical(x2, one)
  return result
@_wraps(np.power, module='numpy')
def power(x1: ArrayLike, x2: ArrayLike, /) -> Array:
  _check_arraylike("power", x1, x2)

  # Special case for concrete integer scalars: use binary exponentiation.
  # Using lax.pow may be imprecise for floating-point values; the goal of this
  # code path is to make sure we end up with a precise output for the common
  # pattern ``x ** 2`` or similar.
  if not isinstance(core.get_aval(x2), core.ConcreteArray):
    return _power(x1, x2)
  try:
    exponent = operator.index(x2)  # type: ignore[arg-type]
  except TypeError:
    # Concrete but not usable as an integer index: take the general path.
    return _power(x1, x2)
  base, = _promote_dtypes_numeric(x1)
  return lax.integer_pow(base, exponent)
@custom_jvp
@_wraps(np.logaddexp, module='numpy')
@jit
def logaddexp(x1: ArrayLike, x2: ArrayLike, /) -> Array:
  # log(exp(x1) + exp(x2)), evaluated as max + log1p(exp(...)).
  x1, x2 = _promote_args_inexact("logaddexp", x1, x2)
  larger = lax.max(x1, x2)

  if dtypes.issubdtype(x1.dtype, np.floating):
    diff = lax.sub(x1, x2)
    diff_is_nan = lax_internal._isnan(diff)
    plain_sum = lax.add(x1, x2)  # NaNs or infinities of the same sign.
    stable = lax.add(larger, lax.log1p(lax.exp(lax.neg(lax.abs(diff)))))
    return lax.select(diff_is_nan, plain_sum, stable)

  # Non-floating (complex) path: x1 + x2 - 2 * max is used as the exponent,
  # and the imaginary part of the result is wrapped with _wrap_between(., pi).
  exponent = lax.sub(lax.add(x1, x2), lax.mul(larger, _constant_like(larger, 2)))
  out = lax.add(larger, lax.log1p(lax.exp(exponent)))
  return lax.complex(lax.real(out), _wrap_between(lax.imag(out), np.pi))
def _wrap_between(x, _a):
  """Wraps `x` between `[-a, a]`."""
  half_width = _constant_like(x, _a)
  full_width = _constant_like(x, 2 * _a)
  zero = _constant_like(x, 0)
  # Shift by `a`, reduce modulo `2a`, move negative remainders up by one
  # period, then shift back.
  shifted = lax.rem(lax.add(x, half_width), full_width)
  shifted = lax.select(lax.lt(shifted, zero), lax.add(shifted, full_width), shifted)
  return lax.sub(shifted, half_width)


@logaddexp.defjvp
def _logaddexp_jvp(primals, tangents):
  """JVP rule for logaddexp: ``t1 * exp(x1 - out) + t2 * exp(x2 - out)``.

  Operands and the output are passed through ``_replace_inf`` before the
  subtraction.
  """
  x1, x2 = primals
  t1, t2 = tangents
  x1, x2, t1, t2 = _promote_args_inexact("logaddexp_jvp", x1, x2, t1, t2)
  primal_out = logaddexp(x1, x2)
  safe_out = _replace_inf(primal_out)
  weight1 = exp(lax.sub(_replace_inf(x1), safe_out))
  weight2 = exp(lax.sub(_replace_inf(x2), safe_out))
  tangent_out = lax.add(lax.mul(t1, weight1), lax.mul(t2, weight2))
  return primal_out, tangent_out
@custom_jvp
@_wraps(np.logaddexp2, module='numpy')
@jit
def logaddexp2(x1: ArrayLike, x2: ArrayLike, /) -> Array:
  # log2(2**x1 + 2**x2), evaluated as max + log1p(2**(...)) / log(2).
  x1, x2 = _promote_args_inexact("logaddexp2", x1, x2)
  larger = lax.max(x1, x2)
  ln2 = _constant_like(x1, np.log(2))

  if dtypes.issubdtype(x1.dtype, np.floating):
    diff = lax.sub(x1, x2)
    diff_is_nan = lax_internal._isnan(diff)
    plain_sum = lax.add(x1, x2)  # NaNs or infinities of the same sign.
    correction = lax.div(lax.log1p(exp2(lax.neg(lax.abs(diff)))), ln2)
    return lax.select(diff_is_nan, plain_sum, lax.add(larger, correction))

  # Non-floating (complex) path: the imaginary part of the result is wrapped
  # with _wrap_between(., pi / log(2)).
  exponent = lax.sub(lax.add(x1, x2), lax.mul(larger, _constant_like(larger, 2)))
  out = lax.add(larger, lax.div(lax.log1p(exp2(exponent)), ln2))
  return lax.complex(lax.real(out), _wrap_between(lax.imag(out), np.pi / np.log(2)))
@logaddexp2.defjvp
def _logaddexp2_jvp(primals, tangents):
  """JVP rule for logaddexp2: ``t1 * 2**(x1 - out) + t2 * 2**(x2 - out)``.

  Operands and the output are passed through ``_replace_inf`` before the
  subtraction.
  """
  x1, x2 = primals
  t1, t2 = tangents
  x1, x2, t1, t2 = _promote_args_inexact("logaddexp2_jvp", x1, x2, t1, t2)
  primal_out = logaddexp2(x1, x2)
  safe_out = _replace_inf(primal_out)
  weight1 = exp2(lax.sub(_replace_inf(x1), safe_out))
  weight2 = exp2(lax.sub(_replace_inf(x2), safe_out))
  tangent_out = lax.add(lax.mul(t1, weight1), lax.mul(t2, weight2))
  return primal_out, tangent_out
@_wraps(np.log2, module='numpy')
@partial(jit, inline=True)
def log2(x: ArrayLike, /) -> Array:
  # Base-2 logarithm computed as log(x) / log(2).
  x, = _promote_args_inexact("log2", x)
  natural_log = lax.log(x)
  log_of_base = lax.log(_constant_like(x, 2))
  return lax.div(natural_log, log_of_base)
@_wraps(np.log10, module='numpy')
@partial(jit, inline=True)
def log10(x: ArrayLike, /) -> Array:
  # Base-10 logarithm computed as log(x) / log(10).
  x, = _promote_args_inexact("log10", x)
  natural_log = lax.log(x)
  log_of_base = lax.log(_constant_like(x, 10))
  return lax.div(natural_log, log_of_base)
@_wraps(np.exp2, module='numpy')
@partial(jit, inline=True)
def exp2(x: ArrayLike, /) -> Array:
  # 2**x computed as exp(log(2) * x).
  x, = _promote_args_inexact("exp2", x)
  log_of_base = lax.log(_constant_like(x, 2))
  return lax.exp(lax.mul(log_of_base, x))
@_wraps(np.signbit, module='numpy')
@jit
def signbit(x: ArrayLike, /) -> Array:
  x, = _promote_args("signbit", x)
  dtype = dtypes.dtype(x)

  # Integers: the sign bit is set exactly when the value is negative.
  if dtypes.issubdtype(dtype, np.integer):
    return lax.lt(x, _constant_like(x, 0))
  # Booleans: always False.
  if dtypes.issubdtype(dtype, np.bool_):
    return lax.full_like(x, False, dtype=np.bool_)
  if not dtypes.issubdtype(dtype, np.floating):
    raise ValueError(
        "jax.numpy.signbit is not well defined for %s" % dtype)

  # TPU supports BF16 but not S16 types, so as a workaround, convert BF16 to
  # F32.
  if dtype == dtypes.bfloat16:
    dtype = np.dtype('float32')
    x = lax.convert_element_type(x, dtype)

  info = dtypes.finfo(dtype)
  if info.bits not in _INT_DTYPES:
    raise NotImplementedError(
        "jax.numpy.signbit only supports 16, 32, and 64-bit types.")
  # Reinterpret the float as a same-width signed integer and shift past the
  # exponent and mantissa fields; the result is non-zero iff the top bit is
  # set.
  as_int = lax.bitcast_convert_type(x, _INT_DTYPES[info.bits])
  sign_shift = info.nexp + info.nmant
  return lax.convert_element_type(as_int >> sign_shift, np.bool_)
def _normalize_float(x):
  """Scale up tiny values and return their bits plus an exponent offset.

  Entries with ``abs(x) < finfo.tiny`` are multiplied by ``2**nmant``.

  Returns:
    A pair ``(bits, offset)``: the (possibly rescaled) values bitcast to the
    same-width signed integer type, and ``-nmant`` for rescaled entries or
    ``0`` otherwise.
  """
  info = dtypes.finfo(dtypes.dtype(x))
  int_type = _INT_DTYPES[info.bits]
  is_tiny = lax.abs(x) < info.tiny
  rescaled = _where(is_tiny, x * _lax_const(x, 1 << info.nmant), x)
  exponent_offset = _where(is_tiny, int_type(-info.nmant), int_type(0))
  return lax.bitcast_convert_type(rescaled, int_type), exponent_offset
@_wraps(np.ldexp, module='numpy')
@jit
def ldexp(x1: ArrayLike, x2: ArrayLike, /) -> Array:
  # x1 * 2**x2, assembled by rewriting the exponent field of x1's bits.
  # Raises ValueError when x1 is complex or x2 is not an integer/bool dtype.
  _check_arraylike("ldexp", x1, x2)
  x1_dtype = dtypes.dtype(x1)
  x2_dtype = dtypes.dtype(x2)
  if (dtypes.issubdtype(x1_dtype, np.complexfloating)
      or dtypes.issubdtype(x2_dtype, np.inexact)):
    raise ValueError(f"ldexp not supported for input types {(x1_dtype, x2_dtype)}")

  x1, x2 = _promote_shapes("ldexp", x1, x2)

  dtype = dtypes.canonicalize_dtype(dtypes.to_inexact_dtype(x1_dtype))
  info = dtypes.finfo(dtype)
  int_type = _INT_DTYPES[info.bits]
  nmant = info.nmant
  exp_mask = (1 << info.nexp) - 1
  bias = exp_mask >> 1

  x1 = lax.convert_element_type(x1, dtype)
  exponent = lax.convert_element_type(x2, int_type)

  # Total unbiased exponent of the result: requested shift plus the exponent
  # already stored in x1 (including the offset for rescaled tiny inputs).
  bits, tiny_offset = _normalize_float(x1)
  exponent += tiny_offset + ((bits >> nmant) & exp_mask) - bias

  # find underflow/overflow before denormalization
  underflows = less(exponent, -(bias + nmant))
  overflows = greater(exponent, bias)

  scale = lax.full_like(bits, 1, dtype=dtype)

  # denormals
  is_denormal = less(exponent, -bias + 1)
  exponent = _where(is_denormal, exponent + nmant, exponent)
  scale = _where(is_denormal, scale / (1 << nmant), scale)

  # Clear the exponent field and write the new biased exponent into it.
  exponent = lax.convert_element_type(exponent, np.int32)
  bits &= ~(exp_mask << nmant)
  bits |= ((lax.convert_element_type(exponent, int_type) + bias) << nmant)

  result = lax.convert_element_type(scale, dtype) * lax.bitcast_convert_type(bits, dtype)

  # underflow
  result = _where(underflows, lax.full_like(result, 0, dtype=dtype), result)
  # overflow
  result = _where(overflows, lax.sign(x1) * lax.full_like(result, np.inf), result)
  # ldexp(x1, x2) = x1 for x1 = inf, -inf, nan, 0
  return _where(isinf(x1) | isnan(x1) | (x1 == 0), x1, result)
@_wraps(np.frexp, module='numpy')
@jit
def frexp(x: ArrayLike, /) -> Tuple[Array, Array]:
  # Split x into (mantissa, int32 exponent). Raises TypeError for complex
  # inputs.
  _check_arraylike("frexp", x)
  x, = _promote_dtypes_inexact(x)
  if dtypes.issubdtype(x.dtype, np.complexfloating):
    raise TypeError("frexp does not support complex-valued inputs")

  dtype = dtypes.dtype(x)
  info = dtypes.finfo(dtype)
  nmant = info.nmant
  exp_mask = (1 << info.nexp) - 1
  bias = exp_mask >> 1

  bits, exponent = _normalize_float(x)
  # Add the unbiased stored exponent plus one, then overwrite the exponent
  # field of the bits with `bias - 1`.
  exponent += ((bits >> nmant) & exp_mask) - bias + 1
  bits &= ~(exp_mask << nmant)
  bits |= (bias - 1) << nmant
  mantissa = lax.bitcast_convert_type(bits, dtype)

  # inf, nan and 0 are returned unchanged with a zero exponent.
  passthrough = isinf(x) | isnan(x) | (x == 0)
  exponent = _where(passthrough, lax_internal._zeros(exponent), exponent)
  return _where(passthrough, x, mantissa), lax.convert_element_type(exponent, np.int32)
@_wraps(np.remainder, module='numpy')
@jit
def remainder(x1: ArrayLike, x2: ArrayLike, /) -> Array:
  # Remainder that takes the sign of the divisor x2.
  x1, x2 = _promote_args_numeric("remainder", x1, x2)
  zero = _constant_like(x1, 0)
  if dtypes.issubdtype(x2.dtype, np.integer):
    # Replace zero integer divisors by one before calling lax.rem.
    x2 = _where(x2 == 0, lax_internal._ones(x2), x2)
  truncated = lax.rem(x1, x2)
  is_nonzero = lax.ne(truncated, zero)
  signs_differ = lax.ne(lax.lt(truncated, zero), lax.lt(x2, zero))
  # lax.rem takes the sign of the dividend; add the divisor back where the
  # remainder is non-zero and has the opposite sign from the divisor.
  needs_shift = lax.bitwise_and(signs_differ, is_nonzero)
  return lax.select(needs_shift, lax.add(truncated, x2), truncated)
# Public `mod`: `remainder` passed through `_wraps` against np.mod.
mod = _wraps(np.mod, module='numpy')(remainder)
@_wraps(np.fmod, module='numpy')
@jit
def fmod(x1: ArrayLike, x2: ArrayLike, /) -> Array:
  # Remainder as computed by lax.rem on the numerically promoted operands.
  _check_arraylike("fmod", x1, x2)
  integer_result = dtypes.issubdtype(dtypes.result_type(x1, x2), np.integer)
  if integer_result:
    # Replace zero integer divisors by one before calling lax.rem.
    x2 = _where(x2 == 0, lax_internal._ones(x2), x2)
  dividend, divisor = _promote_args_numeric("fmod", x1, x2)
  return lax.rem(dividend, divisor)
@_wraps(np.square, module='numpy')
@partial(jit, inline=True)
def square(x: ArrayLike, /) -> Array:
  # x**2 via lax.integer_pow on the numerically promoted input.
  _check_arraylike("square", x)
  promoted, = _promote_dtypes_numeric(x)
  return lax.integer_pow(promoted, 2)
@_wraps(np.deg2rad, module='numpy')
@partial(jit, inline=True)
def deg2rad(x: ArrayLike, /) -> Array:
  # Degrees to radians: multiply by pi / 180.
  x, = _promote_args_inexact("deg2rad", x)
  radians_per_degree = _lax_const(x, np.pi / 180)
  return lax.mul(x, radians_per_degree)
@_wraps(np.rad2deg, module='numpy')
@partial(jit, inline=True)
def rad2deg(x: ArrayLike, /) -> Array:
  # Radians to degrees: multiply by 180 / pi.
  x, = _promote_args_inexact("rad2deg", x)
  degrees_per_radian = _lax_const(x, 180 / np.pi)
  return lax.mul(x, degrees_per_radian)
# `degrees` and `radians` are aliases of `rad2deg` and `deg2rad`.
degrees = rad2deg
radians = deg2rad
@_wraps(np.conjugate, module='numpy')
@partial(jit, inline=True)
def conjugate(x: ArrayLike, /) -> Array:
  # Complex conjugate; non-complex inputs are returned as arrays unchanged.
  _check_arraylike("conjugate", x)
  if np.iscomplexobj(x):
    return lax.conj(x)
  return _asarray(x)
# `conj` is an alias of `conjugate`.
conj = conjugate
@_wraps(np.imag)
@partial(jit, inline=True)
def imag(val: ArrayLike, /) -> Array:
  # Imaginary part; non-complex inputs yield zeros shaped like the input.
  _check_arraylike("imag", val)
  if np.iscomplexobj(val):
    return lax.imag(val)
  return lax.full_like(val, 0)
@_wraps(np.real)
@partial(jit, inline=True)
def real(val: ArrayLike, /) -> Array:
  # Real part; non-complex inputs are returned as arrays unchanged.
  _check_arraylike("real", val)
  if np.iscomplexobj(val):
    return lax.real(val)
  return _asarray(val)
@_wraps(np.modf, module='numpy', skip_params=['out'])
@jit
def modf(x: ArrayLike, /, out=None) -> Tuple[Array, Array]:
  # Returns (fractional part, integral part). The integral part is rounded
  # toward zero: floor for x >= 0, ceil otherwise.
  _check_arraylike("modf", x)
  x, = _promote_dtypes_inexact(x)
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.modf is not supported.")
  is_nonnegative = lax.ge(x, lax_internal._zero(x))
  integral = _where(is_nonnegative, floor(x), ceil(x))
  fractional = x - integral
  return fractional, integral
@_wraps(np.isfinite, module='numpy')
@jit
def isfinite(x: ArrayLike, /) -> Array:
  _check_arraylike("isfinite", x)
  dtype = dtypes.dtype(x)
  if dtypes.issubdtype(dtype, np.floating):
    return lax.is_finite(x)
  if dtypes.issubdtype(dtype, np.complexfloating):
    # Finite only when both components are finite.
    real_finite = lax.is_finite(real(x))
    imag_finite = lax.is_finite(imag(x))
    return lax.bitwise_and(real_finite, imag_finite)
  # All remaining dtypes are reported as finite.
  return lax.full_like(x, True, dtype=np.bool_)
@_wraps(np.isinf, module='numpy')
@jit
def isinf(x: ArrayLike, /) -> Array:
  _check_arraylike("isinf", x)
  dtype = dtypes.dtype(x)
  if dtypes.issubdtype(dtype, np.floating):
    return lax.eq(lax.abs(x), _constant_like(x, np.inf))
  if dtypes.issubdtype(dtype, np.complexfloating):
    # Infinite when either component is infinite.
    real_part = lax.real(x)
    imag_part = lax.imag(x)
    real_inf = lax.eq(lax.abs(real_part), _constant_like(real_part, np.inf))
    imag_inf = lax.eq(lax.abs(imag_part), _constant_like(imag_part, np.inf))
    return lax.bitwise_or(real_inf, imag_inf)
  # All remaining dtypes are reported as not infinite.
  return lax.full_like(x, False, dtype=np.bool_)
def _isposneginf(infinity: float, x: ArrayLike, out) -> Array:
  """Elementwise test of ``x`` against one specific signed infinity.

  Args:
    infinity: the value compared against (``np.inf`` or ``-np.inf``).
    x: input array-like.
    out: must be None; accepted only for signature compatibility.

  Raises:
    NotImplementedError: if ``out`` is not None.
    ValueError: if ``x`` has a complex dtype.
  """
  if out is not None:
    raise NotImplementedError("The 'out' argument to isneginf/isposinf is not supported.")
  dtype = dtypes.dtype(x)
  if dtypes.issubdtype(dtype, np.complexfloating):
    raise ValueError("isposinf/isneginf are not well defined for complex types")
  if dtypes.issubdtype(dtype, np.floating):
    return lax.eq(x, _constant_like(x, infinity))
  # Remaining dtypes are reported as not infinite.
  return lax.full_like(x, False, dtype=np.bool_)


isposinf: UnOp = _wraps(np.isposinf, skip_params=['out'])(
  lambda x, /, out=None: _isposneginf(np.inf, x, out)
)

isneginf: UnOp = _wraps(np.isneginf, skip_params=['out'])(
  lambda x, /, out=None: _isposneginf(-np.inf, x, out)
)
@_wraps(np.isnan, module='numpy')
@jit
def isnan(x: ArrayLike, /) -> Array:
  _check_arraylike("isnan", x)
  # Implemented as the self-inequality test `x != x`.
  differs_from_itself = lax.ne(x, x)
  return differs_from_itself
@_wraps(np.heaviside, module='numpy')
@jit
def heaviside(x1: ArrayLike, x2: ArrayLike, /) -> Array:
  # Heaviside step function: 0 where x1 < 0, 1 where x1 > 0, x2 where
  # x1 == 0, and NaN where x1 is NaN.
  _check_arraylike("heaviside", x1, x2)
  x1, x2 = _promote_dtypes_inexact(x1, x2)
  zero = _lax_const(x1, 0)
  one = _lax_const(x1, 1)
  nan = _lax_const(x1, np.nan)
  # A NaN x1 is neither <, > nor == 0. Test for equality explicitly so that
  # NaN propagates (as np.heaviside does) instead of falling through to x2.
  at_zero_else_nan = _where(lax.eq(x1, zero), x2, nan)
  return _where(lax.lt(x1, zero), zero,
                _where(lax.gt(x1, zero), one, at_zero_else_nan))
@_wraps(np.hypot, module='numpy')
@jit
def hypot(x1: ArrayLike, x2: ArrayLike, /) -> Array:
  # sqrt(x1**2 + x2**2), computed as max * sqrt(1 + (min / max)**2) so the
  # squares cannot overflow.
  _check_arraylike("hypot", x1, x2)
  x1, x2 = _promote_dtypes_inexact(x1, x2)
  x1 = lax.abs(x1)
  x2 = lax.abs(x2)

  # Record infinities before max/min. Otherwise hypot(inf, inf) evaluates
  # inf * sqrt(1 + (inf / inf)**2) = nan, and a NaN partner wipes out an inf
  # through max/min. The result is +inf whenever either input is infinite.
  inf = _lax_const(x1, np.inf)
  either_inf = bitwise_or(lax.eq(x1, inf), lax.eq(x2, inf))

  larger, smaller = maximum(x1, x2), minimum(x1, x2)
  larger_is_zero = larger == 0
  # Divide by one instead of zero; that branch is discarded by the select.
  safe_larger = lax.select(larger_is_zero, lax_internal._ones(larger), larger)
  ratio = lax.div(smaller, safe_larger)
  scaled = larger * lax.sqrt(1 + lax.square(ratio))
  result = lax.select(larger_is_zero, larger, scaled)
  return _where(either_inf, inf, result)
@_wraps(np.reciprocal, module='numpy')
@partial(jit, inline=True)
def reciprocal(x: ArrayLike, /) -> Array:
  # 1 / x via lax.integer_pow(x, -1) on the inexact-promoted input.
  _check_arraylike("reciprocal", x)
  promoted, = _promote_dtypes_inexact(x)
  return lax.integer_pow(promoted, -1)
@_wraps(np.sinc, update_doc=False)
@jit
def sinc(x: ArrayLike, /) -> Array:
  # sin(pi * x) / (pi * x), with the x == 0 entries taken from
  # _sinc_maclaurin.
  _check_arraylike("sinc", x)
  x, = _promote_dtypes_inexact(x)
  at_zero = lax.eq(x, _lax_const(x, 0))
  pi_x = lax.mul(_lax_const(x, np.pi), x)
  # Use 1 in place of 0 as the divisor; that branch is discarded by the outer
  # select.
  safe_pi_x = _where(at_zero, _lax_const(x, 1), pi_x)
  value_at_zero = _sinc_maclaurin(0, pi_x)
  value_elsewhere = lax.div(lax.sin(safe_pi_x), safe_pi_x)
  return _where(at_zero, value_at_zero, value_elsewhere)
@partial(custom_jvp, nondiff_argnums=(0,))
def _sinc_maclaurin(k, x):
  # compute the kth derivative of x -> sin(x)/x evaluated at zero (since we
  # compute the monomial term in the jvp rule)
  # TODO(mattjj): see https://github.com/google/jax/issues/10750
  zeros = x * 0
  if k % 2:
    # Odd k: returns zeros.
    return zeros
  coefficient = (-1) ** (k // 2) / (k + 1)
  return zeros + _lax_const(x, coefficient)


@_sinc_maclaurin.defjvp
def _sinc_maclaurin_jvp(k, primals, tangents):
  # The tangent is the (k + 1)-th derivative term times the input tangent.
  (x,), (t,) = primals, tangents
  primal_out = _sinc_maclaurin(k, x)
  tangent_out = _sinc_maclaurin(k + 1, x) * t
  return primal_out, tangent_out