# Source code for jax._src.numpy.ufuncs

# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Implements ufuncs for jax.numpy.
"""

from functools import partial
import operator
from textwrap import dedent
from typing import Any, Callable, Union, overload

import numpy as np

from jax._src import core
from jax._src import dtypes
from jax._src.api import jit
from jax._src.custom_derivatives import custom_jvp
from jax._src.lax import lax
from jax._src.typing import Array, ArrayLike
from jax._src.numpy.util import (
   check_arraylike, promote_args, promote_args_inexact,
   promote_args_numeric, promote_dtypes_inexact, promote_dtypes_numeric,
   promote_shapes, _where, _wraps, check_no_float0s)

# Short alias for lax's helper that builds a constant "like" its first operand.
_lax_const = lax._const

# Floating-point bit width -> signed integer type of the same width. Used when
# bitcasting floats to integers to manipulate their sign/exponent bits.
_INT_DTYPES = {np.dtype(int_t).itemsize * 8: int_t
               for int_t in (np.int16, np.int32, np.int64)}

# Signatures of the unary and binary array operations built in this module.
UnOp = Callable[[ArrayLike], Array]
BinOp = Callable[[ArrayLike, ArrayLike], Array]


def _constant_like(x, const):
  """Return `const` as a 0-d NumPy array whose dtype matches that of `x`."""
  target_dtype = dtypes.dtype(x)
  return np.array(const, dtype=target_dtype)


def _replace_inf(x: ArrayLike) -> Array:
  """Zero out entries whose real part is +inf; all other entries pass through."""
  is_pos_inf = isposinf(real(x))
  zeros = lax._zeros(x)
  return lax.select(is_pos_inf, zeros, x)


def _one_to_one_unop(
    numpy_fn: Callable[..., Any], lax_fn: UnOp,
    promote_to_inexact: bool = False, lax_doc: bool = False) -> UnOp:
  """Build a jitted jax.numpy unary function that applies `lax_fn`.

  Args:
    numpy_fn: the NumPy function being mirrored. Its ``__name__`` is used for
      argument checking and the ``__qualname__``, and it is passed to
      ``_wraps`` for the documentation.
    lax_fn: the lax operation applied to the promoted argument.
    promote_to_inexact: if True, promote the argument with
      ``promote_args_inexact``; otherwise use ``promote_args``.
    lax_doc: if True, pass ``lax_fn``'s docstring, minus its first
      paragraph, to ``_wraps`` as the ``lax_description``.

  Returns:
    The wrapped, jitted (inline) unary function.
  """
  promote = promote_args_inexact if promote_to_inexact else promote_args
  fn = lambda x, /: lax_fn(*promote(numpy_fn.__name__, x))
  fn.__qualname__ = f"jax.numpy.{numpy_fn.__name__}"
  fn = jit(fn, inline=True)
  if lax_doc:
    # Drop the summary paragraph of lax_fn's docstring. __doc__ is None when
    # docstrings are stripped (python -OO), so fall back to an empty string
    # rather than raising AttributeError at import time.
    doc = dedent('\n\n'.join((lax_fn.__doc__ or '').split('\n\n')[1:])).strip()
    return _wraps(numpy_fn, lax_description=doc, module='numpy')(fn)
  return _wraps(numpy_fn, module='numpy')(fn)


def _one_to_one_binop(
    numpy_fn: Callable[..., Any], lax_fn: BinOp,
    promote_to_inexact: bool = False, lax_doc: bool = False,
    promote_to_numeric: bool = False) -> BinOp:
  """Build a jitted jax.numpy binary function that applies `lax_fn`.

  Args:
    numpy_fn: the NumPy function being mirrored. Its ``__name__`` is used for
      argument checking and the ``__qualname__``, and it is passed to
      ``_wraps`` for the documentation.
    lax_fn: the lax operation applied to the promoted arguments.
    promote_to_inexact: promote with ``promote_args_inexact``. This takes
      precedence over ``promote_to_numeric``.
    lax_doc: if True, pass ``lax_fn``'s docstring, minus its first
      paragraph, to ``_wraps`` as the ``lax_description``.
    promote_to_numeric: promote with ``promote_args_numeric``. When neither
      flag is set, ``promote_args`` is used.

  Returns:
    The wrapped, jitted (inline) binary function.
  """
  if promote_to_inexact:
    promote = promote_args_inexact
  elif promote_to_numeric:
    promote = promote_args_numeric
  else:
    promote = promote_args
  fn = lambda x1, x2, /: lax_fn(*promote(numpy_fn.__name__, x1, x2))
  fn.__qualname__ = f"jax.numpy.{numpy_fn.__name__}"
  fn = jit(fn, inline=True)
  if lax_doc:
    # Drop the summary paragraph of lax_fn's docstring. __doc__ is None when
    # docstrings are stripped (python -OO), so fall back to an empty string
    # rather than raising AttributeError at import time.
    doc = dedent('\n\n'.join((lax_fn.__doc__ or '').split('\n\n')[1:])).strip()
    return _wraps(numpy_fn, lax_description=doc, module='numpy')(fn)
  return _wraps(numpy_fn, module='numpy')(fn)


def _maybe_bool_binop(
    numpy_fn: Callable[..., Any], lax_fn: BinOp, bool_lax_fn: BinOp,
    lax_doc: bool = False) -> BinOp:
  """Build a jitted binary function with a separate implementation for bools.

  After ``promote_args``, boolean operands are handled by `bool_lax_fn`, and
  every other dtype is handled by `lax_fn`.

  Args:
    numpy_fn: the NumPy function being mirrored (name, qualname, docs).
    lax_fn: operation used for non-boolean promoted dtypes.
    bool_lax_fn: operation used when the promoted dtype is ``bool``.
    lax_doc: if True, pass ``lax_fn``'s docstring, minus its first
      paragraph, to ``_wraps`` as the ``lax_description``.

  Returns:
    The wrapped, jitted (inline) binary function.
  """
  def fn(x1, x2, /):
    x1, x2 = promote_args(numpy_fn.__name__, x1, x2)
    return lax_fn(x1, x2) if x1.dtype != np.bool_ else bool_lax_fn(x1, x2)
  fn.__qualname__ = f"jax.numpy.{numpy_fn.__name__}"
  fn = jit(fn, inline=True)
  if lax_doc:
    # Drop the summary paragraph of lax_fn's docstring. __doc__ is None when
    # docstrings are stripped (python -OO), so fall back to an empty string
    # rather than raising AttributeError at import time.
    doc = dedent('\n\n'.join((lax_fn.__doc__ or '').split('\n\n')[1:])).strip()
    return _wraps(numpy_fn, lax_description=doc, module='numpy')(fn)
  return _wraps(numpy_fn, module='numpy')(fn)


def _comparison_op(numpy_fn: Callable[..., Any], lax_fn: BinOp) -> BinOp:
  """Build a jitted ordering comparison that mirrors `numpy_fn` via `lax_fn`.

  Complex operands are ordered lexicographically on the (real, imag) pair.
  The real parts decide the result unless they are equal, in which case the
  imaginary parts decide.
  """
  def fn(x1, x2, /):
    a, b = promote_args(numpy_fn.__name__, x1, x2)
    if not dtypes.issubdtype(dtypes.dtype(a), np.complexfloating):
      return lax_fn(a, b)
    re_a, re_b = lax.real(a), lax.real(b)
    real_parts_tie = lax.eq(re_a, re_b)
    return lax.select(real_parts_tie,
                      lax_fn(lax.imag(a), lax.imag(b)),
                      lax_fn(re_a, re_b))
  fn.__qualname__ = f"jax.numpy.{numpy_fn.__name__}"
  jitted = jit(fn, inline=True)
  return _wraps(numpy_fn, module='numpy')(jitted)

@overload
def _logical_op(np_op: Callable[..., Any], bitwise_op: UnOp) -> UnOp: ...
@overload
def _logical_op(np_op: Callable[..., Any], bitwise_op: BinOp) -> BinOp: ...
@overload
def _logical_op(np_op: Callable[..., Any], bitwise_op: Union[UnOp, BinOp]) -> Union[UnOp, BinOp]: ...

def _logical_op(np_op: Callable[..., Any], bitwise_op: Union[UnOp, BinOp]) -> Union[UnOp, BinOp]:
  """Build a logical function from a bitwise lax op.

  Each non-boolean argument is first reduced to an "is nonzero" mask.
  `bitwise_op` is then applied to the promoted masks.
  """
  def as_bool(x):
    # Boolean inputs pass through unchanged; anything else becomes `x != 0`.
    if dtypes.issubdtype(dtypes.dtype(x), np.bool_):
      return x
    return lax.ne(x, lax.full_like(x, shape=(), fill_value=0))

  @_wraps(np_op, update_doc=False, module='numpy')
  @partial(jit, inline=True)
  def op(*args):
    masks = [as_bool(x) for x in args]
    return bitwise_op(*promote_args(np_op.__name__, *masks))
  return op


# Unary functions backed by a single lax operation. Entries that pass
# promote_to_inexact=True promote their argument to an inexact dtype first;
# the others use standard promotion.
fabs = _one_to_one_unop(np.fabs, lax.abs, promote_to_inexact=True)
bitwise_not = _one_to_one_unop(np.bitwise_not, lax.bitwise_not)
invert = _one_to_one_unop(np.invert, lax.bitwise_not)
negative = _one_to_one_unop(np.negative, lax.neg)
positive = _one_to_one_unop(np.positive, lambda x: lax.asarray(x))
floor = _one_to_one_unop(np.floor, lax.floor, promote_to_inexact=True)
ceil = _one_to_one_unop(np.ceil, lax.ceil, promote_to_inexact=True)
exp = _one_to_one_unop(np.exp, lax.exp, promote_to_inexact=True)
log = _one_to_one_unop(np.log, lax.log, promote_to_inexact=True)
expm1 = _one_to_one_unop(np.expm1, lax.expm1, promote_to_inexact=True)
log1p = _one_to_one_unop(np.log1p, lax.log1p, promote_to_inexact=True)
sin = _one_to_one_unop(np.sin, lax.sin, promote_to_inexact=True)
cos = _one_to_one_unop(np.cos, lax.cos, promote_to_inexact=True)
tan = _one_to_one_unop(np.tan, lax.tan, promote_to_inexact=True)
arcsin = _one_to_one_unop(np.arcsin, lax.asin, promote_to_inexact=True)
arccos = _one_to_one_unop(np.arccos, lax.acos, promote_to_inexact=True)
arctan = _one_to_one_unop(np.arctan, lax.atan, promote_to_inexact=True)
sinh = _one_to_one_unop(np.sinh, lax.sinh, promote_to_inexact=True)
cosh = _one_to_one_unop(np.cosh, lax.cosh, promote_to_inexact=True)
arcsinh = _one_to_one_unop(np.arcsinh, lax.asinh, promote_to_inexact=True)
tanh = _one_to_one_unop(np.tanh, lax.tanh, promote_to_inexact=True)
arctanh = _one_to_one_unop(np.arctanh, lax.atanh, promote_to_inexact=True)
sqrt = _one_to_one_unop(np.sqrt, lax.sqrt, promote_to_inexact=True)
cbrt = _one_to_one_unop(np.cbrt, lax.cbrt, promote_to_inexact=True)

# Binary functions. `add` and `multiply` switch to bitwise or/and when the
# promoted dtype is bool. The rest apply a single lax operation after the
# dtype promotion selected by their keyword flags.
add = _maybe_bool_binop(np.add, lax.add, lax.bitwise_or)
bitwise_and = _one_to_one_binop(np.bitwise_and, lax.bitwise_and)
bitwise_or = _one_to_one_binop(np.bitwise_or, lax.bitwise_or)
bitwise_xor = _one_to_one_binop(np.bitwise_xor, lax.bitwise_xor)
left_shift = _one_to_one_binop(np.left_shift, lax.shift_left,
                               promote_to_numeric=True)
equal = _one_to_one_binop(np.equal, lax.eq)
multiply = _maybe_bool_binop(np.multiply, lax.mul, lax.bitwise_and)
not_equal = _one_to_one_binop(np.not_equal, lax.ne)
subtract = _one_to_one_binop(np.subtract, lax.sub)
arctan2 = _one_to_one_binop(np.arctan2, lax.atan2, promote_to_inexact=True)
minimum = _one_to_one_binop(np.minimum, lax.min)
maximum = _one_to_one_binop(np.maximum, lax.max)
float_power = _one_to_one_binop(np.float_power, lax.pow,
                                promote_to_inexact=True)
nextafter = _one_to_one_binop(np.nextafter, lax.nextafter,
                              promote_to_inexact=True, lax_doc=True)

# Ordering comparisons. Complex inputs compare lexicographically on
# (real, imag); see _comparison_op.
less = _comparison_op(np.less, lax.lt)
less_equal = _comparison_op(np.less_equal, lax.le)
greater = _comparison_op(np.greater, lax.gt)
greater_equal = _comparison_op(np.greater_equal, lax.ge)

# Logical operations: inputs are reduced to "nonzero" booleans, then combined
# with the matching bitwise lax operation.
logical_not: UnOp = _logical_op(np.logical_not, lax.bitwise_not)
logical_and: BinOp = _logical_op(np.logical_and, lax.bitwise_and)
logical_or: BinOp = _logical_op(np.logical_or, lax.bitwise_or)
logical_xor: BinOp = _logical_op(np.logical_xor, lax.bitwise_xor)

@_wraps(np.arccosh, module='numpy')
@jit
def arccosh(x: ArrayLike, /) -> Array:
  # arccosh is multi-valued for complex inputs, and lax.acosh uses a different
  # convention than np.arccosh. Complex results whose real part is negative
  # are therefore negated.
  x_inexact, = promote_args_inexact("arccosh", x)
  result = lax.acosh(x_inexact)
  if not dtypes.issubdtype(result.dtype, np.complexfloating):
    return result
  return _where(real(result) < 0, lax.neg(result), result)
@_wraps(np.right_shift, module='numpy')
@partial(jit, inline=True)
def right_shift(x1: ArrayLike, x2: ArrayLike, /) -> Array:
  # Unsigned integers use a logical shift (zeros shifted in). All other
  # promoted dtypes use an arithmetic shift (sign bit replicated).
  x1, x2 = promote_args_numeric(np.right_shift.__name__, x1, x2)
  if np.issubdtype(x1.dtype, np.unsignedinteger):
    return lax.shift_right_logical(x1, x2)
  return lax.shift_right_arithmetic(x1, x2)
@_wraps(np.absolute, module='numpy')
@partial(jit, inline=True)
def absolute(x: ArrayLike, /) -> Array:
  # Booleans and unsigned integers are already non-negative. They are returned
  # unchanged (as an array) instead of going through lax.abs.
  check_arraylike('absolute', x)
  dt = dtypes.dtype(x)
  already_nonnegative = dt == np.bool_ or dtypes.issubdtype(dt, np.unsignedinteger)
  if already_nonnegative:
    return lax.asarray(x)
  return lax.abs(x)
abs = _wraps(np.abs, module='numpy')(absolute)
@_wraps(np.rint, module='numpy')
@jit
def rint(x: ArrayLike, /) -> Array:
  # Round half to even. Bool and integer inputs are only converted to
  # dtypes.float_. Complex inputs are rounded component-wise.
  check_arraylike('rint', x)
  dtype = dtypes.dtype(x)
  is_integral = dtype == bool or dtypes.issubdtype(dtype, np.integer)
  if is_integral:
    return lax.convert_element_type(x, dtypes.float_)
  if not dtypes.issubdtype(dtype, np.complexfloating):
    return lax.round(x, lax.RoundingMethod.TO_NEAREST_EVEN)
  return lax.complex(rint(lax.real(x)), rint(lax.imag(x)))
@_wraps(np.sign, module='numpy')
@jit
def sign(x: ArrayLike, /) -> Array:
  # For complex inputs the result is sign(real(x)) + 0j. When the real part
  # is zero, the sign of the imaginary part is used instead.
  check_arraylike('sign', x)
  if not dtypes.issubdtype(dtypes.dtype(x), np.complexfloating):
    return lax.sign(x)
  re = lax.real(x)
  leading_part = _where(re != 0, re, lax.imag(x))
  return lax.complex(lax.sign(leading_part), _constant_like(re, 0))
@_wraps(np.copysign, module='numpy')
@jit
def copysign(x1: ArrayLike, x2: ArrayLike, /) -> Array:
  # Magnitude of x1 combined with the sign of x2, as reported by signbit.
  # Complex inputs are rejected.
  x1, x2 = promote_args_inexact("copysign", x1, x2)
  if dtypes.issubdtype(dtypes.dtype(x1), np.complexfloating):
    raise TypeError("copysign does not support complex-valued inputs")
  take_negative = signbit(x2).astype(bool)
  magnitude = lax.abs(x1)
  return _where(take_negative, lax.neg(magnitude), magnitude)
@_wraps(np.true_divide, module='numpy')
@partial(jit, inline=True)
def true_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array:
  # Both operands are promoted to an inexact dtype before dividing, so
  # integer inputs produce inexact results.
  numerator, denominator = promote_args_inexact("true_divide", x1, x2)
  return lax.div(numerator, denominator)
divide = true_divide
@_wraps(np.floor_divide, module='numpy')
@jit
def floor_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array:
  # Quotient rounded toward -inf.
  # - Integers: correct lax.div's result by one where needed.
  # - Complex: floor of Re(x1 / x2), cast back to the complex dtype.
  # - Other dtypes: reuse _float_divmod.
  x1, x2 = promote_args_numeric("floor_divide", x1, x2)
  dtype = dtypes.dtype(x1)

  if dtypes.issubdtype(dtype, np.integer):
    truncated = lax.div(x1, x2)
    # Step the quotient down by one when the operand signs differ and the
    # division leaves a nonzero remainder.
    needs_adjust = logical_and(lax.sign(x1) != lax.sign(x2),
                               lax.rem(x1, x2) != 0)
    # TODO(mattjj): investigate why subtracting a scalar was causing promotion
    return _where(needs_adjust, truncated - 1, truncated)

  if not dtypes.issubdtype(dtype, np.complexfloating):
    return _float_divmod(x1, x2)[0]

  a_re, a_im = lax.real(x1), lax.imag(x1)
  b_re, b_im = lax.real(x2), lax.imag(x2)
  # Re(a / b) = (a_re*b_re + a_im*b_im) / |b|**2. Both numerator and
  # denominator are divided by the larger-magnitude component of b
  # (Smith-style scaling).
  real_dominant = lax.ge(lax.abs(b_re), lax.abs(b_im))
  scale_re = _where(real_dominant, lax.full_like(b_im, 1), lax.div(b_re, b_im))
  scale_im = _where(real_dominant, lax.div(b_im, b_re), _lax_const(b_im, 1))
  numerator = lax.add(lax.mul(a_re, scale_re), lax.mul(a_im, scale_im))
  denominator = lax.add(lax.mul(b_re, scale_re), lax.mul(b_im, scale_im))
  floored = lax.floor(lax.div(numerator, denominator))
  return lax.convert_element_type(floored, dtype)
@_wraps(np.divmod, module='numpy')
@jit
def divmod(x1: ArrayLike, x2: ArrayLike, /) -> tuple[Array, Array]:
  # Returns (floor quotient, remainder). Integers reuse floor_divide and
  # remainder; other dtypes compute both at once with _float_divmod.
  x1, x2 = promote_args_numeric("divmod", x1, x2)
  if not dtypes.issubdtype(dtypes.dtype(x1), np.integer):
    return _float_divmod(x1, x2)
  return floor_divide(x1, x2), remainder(x1, x2)
def _float_divmod(x1: ArrayLike, x2: ArrayLike) -> tuple[Array, Array]:
  """Floating-point (quotient, remainder) following CPython's float_divmod.

  Mirrors ``float_divmod`` in CPython's floatobject.c. A nonzero remainder is
  moved to the divisor's sign, the quotient is adjusted to match, and the
  quotient is passed through lax.round.
  """
  trunc_rem = lax.rem(x1, x2)
  trunc_quot = lax.div(lax.sub(x1, trunc_rem), x2)
  # A nonzero remainder whose sign differs from the divisor's is shifted by
  # x2, and the quotient is lowered by one to compensate.
  fix = lax.bitwise_and(trunc_rem != 0, lax.sign(x2) != lax.sign(trunc_rem))
  rem_out = lax.select(fix, trunc_rem + x2, trunc_rem)
  quot_out = lax.select(fix, trunc_quot - _constant_like(trunc_quot, 1), trunc_quot)
  return lax.round(quot_out), rem_out
@_wraps(np.power, module='numpy')
def power(x1: ArrayLike, x2: ArrayLike, /) -> Array:
  check_arraylike("power", x1, x2)
  check_no_float0s("power", x1, x2)

  # We apply special cases, both for algorithmic and autodiff reasons:
  #  1. for *concrete* integer scalar powers (and arbitrary bases), we use
  #     unrolled binary exponentiation specialized on the exponent, which is
  #     more precise for e.g. x ** 2 when x is a float (algorithmic reason!);
  #  2. for integer bases and integer powers, use unrolled binary
  #     exponentiation with a fixed number of squaring steps (see
  #     _pow_int_int), assuming integer widths of at most 64 bits
  #     (algorithmic reason!);
  #  3. for integer powers and float/complex bases, we apply the lax primitive
  #     without any promotion of input types because in this case we want the
  #     function to be differentiable wrt its first argument at 0;
  #  4. for other cases, perform jnp dtype promotion on the arguments then
  #     apply lax.pow.

  # Case 1: concrete integer scalar powers:
  if isinstance(core.get_aval(x2), core.ConcreteArray):
    try:
      x2 = operator.index(x2)  # type: ignore[arg-type]
    except TypeError:
      pass
    else:
      x1, = promote_dtypes_numeric(x1)
      return lax.integer_pow(x1, x2)

  # Handle cases #2, #3 and #4 under a jit:
  return _power(x1, x2)
@partial(jit, inline=True)
def _power(x1: ArrayLike, x2: ArrayLike) -> Array:
  """Compute `power` when the exponent is not a concrete integer scalar.

  - Integer (or bool) promoted dtype: binary exponentiation (_pow_int_int).
  - Inexact base with an integer exponent: lax.pow on the unpromoted
    operands, for differentiability wrt the base at 0.
  - Everything else: promote dtypes, then lax.pow.
  """
  x1, x2 = promote_shapes("power", x1, x2)  # promotes shapes only, not dtypes
  base, exponent = promote_args_numeric("power", x1, x2)
  promoted = dtypes.dtype(base)

  if dtypes.issubdtype(promoted, np.integer) or dtypes.issubdtype(promoted, np.bool_):
    assert np.iinfo(promoted).bits <= 64  # _pow_int_int assumes <=64bit
    return _pow_int_int(base, exponent)

  if (dtypes.issubdtype(dtypes.dtype(x1), np.inexact) and
      dtypes.issubdtype(dtypes.dtype(x2), np.integer)):
    return lax.pow(x1, x2)

  return lax.pow(base, exponent)

# TODO(phawkins): add integer pow support to XLA.
def _pow_int_int(x1, x2):
  """Compute x1 ** x2 for same-dtype integer arrays by square-and-multiply.

  Only the low 6 bits of the exponent are consumed. Larger exponents would
  overflow for any x1 > 1 anyway.
  """
  num_steps = 6
  zero = _constant_like(x2, 0)
  one = _constant_like(x2, 1)
  # Start from 0 instead of 1 when the base is 0 and the exponent is nonzero.
  # This makes pow(0, x2) == 0 even if none of the consumed exponent bits is set.
  base_is_zero = lax.eq(x1, zero)
  acc = _where(lax.bitwise_and(base_is_zero, lax.ne(x2, zero)), zero, one)
  base, power_bits = x1, x2
  for _ in range(num_steps):
    low_bit = lax.bitwise_and(power_bits, one)
    acc = _where(low_bit, lax.mul(acc, base), acc)
    base = lax.mul(base, base)
    power_bits = lax.shift_right_logical(power_bits, one)
  return acc
@custom_jvp
@_wraps(np.logaddexp, module='numpy')
@jit
def logaddexp(x1: ArrayLike, x2: ArrayLike, /) -> Array:
  # log(exp(x1) + exp(x2)), evaluated as max + log1p(exp(-|x1 - x2|)) for
  # real floating inputs. Derivatives come from _logaddexp_jvp.
  x1, x2 = promote_args_inexact("logaddexp", x1, x2)
  amax = lax.max(x1, x2)
  if not dtypes.issubdtype(x1.dtype, np.floating):
    # Complex inputs: delta = x1 + x2 - 2*amax. The imaginary part of the
    # result is then wrapped back into [-pi, pi].
    delta = lax.sub(lax.add(x1, x2), lax.mul(amax, _constant_like(amax, 2)))
    out = lax.add(amax, lax.log1p(lax.exp(delta)))
    return lax.complex(lax.real(out), _wrap_between(lax.imag(out), np.pi))
  diff = lax.sub(x1, x2)
  stable = lax.add(amax, lax.log1p(lax.exp(lax.neg(lax.abs(diff)))))
  # A NaN difference comes from NaN inputs or infinities of the same sign;
  # x1 + x2 yields the right value (NaN or that infinity) in both cases.
  return lax.select(lax._isnan(diff), lax.add(x1, x2), stable)
def _wrap_between(x, _a):
  """Wraps `x` between `[-a, a]`.

  Shifts by `a`, reduces modulo `2a`, folds negative remainders up by `2a`,
  then shifts back by `a`.
  """
  half_width = _constant_like(x, _a)
  width = _constant_like(x, 2 * _a)
  shifted = lax.rem(lax.add(x, half_width), width)
  # lax.rem can return negative values; move those up by one period.
  is_negative = lax.lt(shifted, _constant_like(x, 0))
  shifted = lax.select(is_negative, lax.add(shifted, width), shifted)
  return lax.sub(shifted, half_width)


@logaddexp.defjvp
def _logaddexp_jvp(primals, tangents):
  """JVP rule for logaddexp: d out / d xi = exp(xi - out).

  +inf entries (by real part) are replaced with 0 via _replace_inf before
  subtracting. This presumably avoids inf - inf = NaN weights.
  """
  (x1, x2), (t1, t2) = primals, tangents
  x1, x2, t1, t2 = promote_args_inexact("logaddexp_jvp", x1, x2, t1, t2)
  primal_out = logaddexp(x1, x2)
  safe_out = _replace_inf(primal_out)
  w1 = exp(lax.sub(_replace_inf(x1), safe_out))
  w2 = exp(lax.sub(_replace_inf(x2), safe_out))
  tangent_out = lax.add(lax.mul(t1, w1), lax.mul(t2, w2))
  return primal_out, tangent_out
@custom_jvp
@_wraps(np.logaddexp2, module='numpy')
@jit
def logaddexp2(x1: ArrayLike, x2: ArrayLike, /) -> Array:
  # log2(2**x1 + 2**x2), evaluated as max + log1p(2**-|x1 - x2|) / ln(2) for
  # real floating inputs. Derivatives come from _logaddexp2_jvp.
  x1, x2 = promote_args_inexact("logaddexp2", x1, x2)
  amax = lax.max(x1, x2)
  ln2 = _constant_like(x1, np.log(2))
  if not dtypes.issubdtype(x1.dtype, np.floating):
    # Complex inputs: delta = x1 + x2 - 2*amax. The imaginary part of the
    # result is then wrapped into [-pi/ln2, pi/ln2].
    delta = lax.sub(lax.add(x1, x2), lax.mul(amax, _constant_like(amax, 2)))
    out = lax.add(amax, lax.div(lax.log1p(exp2(delta)), ln2))
    return lax.complex(lax.real(out), _wrap_between(lax.imag(out), np.pi / np.log(2)))
  diff = lax.sub(x1, x2)
  stable = lax.add(amax, lax.div(lax.log1p(exp2(lax.neg(lax.abs(diff)))), ln2))
  # A NaN difference comes from NaN inputs or infinities of the same sign;
  # x1 + x2 yields the right value in both cases.
  return lax.select(lax._isnan(diff), lax.add(x1, x2), stable)
@logaddexp2.defjvp
def _logaddexp2_jvp(primals, tangents):
  """JVP rule for logaddexp2: tangent weights are exp2(xi - out).

  +inf entries (by real part) are replaced with 0 via _replace_inf before
  subtracting, as in the logaddexp JVP rule.
  """
  (x1, x2), (t1, t2) = primals, tangents
  x1, x2, t1, t2 = promote_args_inexact("logaddexp2_jvp", x1, x2, t1, t2)
  primal_out = logaddexp2(x1, x2)
  safe_out = _replace_inf(primal_out)
  w1 = exp2(lax.sub(_replace_inf(x1), safe_out))
  w2 = exp2(lax.sub(_replace_inf(x2), safe_out))
  tangent_out = lax.add(lax.mul(t1, w1), lax.mul(t2, w2))
  return primal_out, tangent_out
@_wraps(np.log2, module='numpy')
@partial(jit, inline=True)
def log2(x: ArrayLike, /) -> Array:
  # Change of base, log2(x) = ln(x) / ln(2), computed in x's inexact dtype.
  x, = promote_args_inexact("log2", x)
  ln_base = lax.log(_constant_like(x, 2))
  return lax.div(lax.log(x), ln_base)
@_wraps(np.log10, module='numpy')
@partial(jit, inline=True)
def log10(x: ArrayLike, /) -> Array:
  # Change of base, log10(x) = ln(x) / ln(10), computed in x's inexact dtype.
  x, = promote_args_inexact("log10", x)
  ln_base = lax.log(_constant_like(x, 10))
  return lax.div(lax.log(x), ln_base)
@_wraps(np.exp2, module='numpy')
@partial(jit, inline=True)
def exp2(x: ArrayLike, /) -> Array:
  # Promote to an inexact dtype, then defer to lax.exp2.
  promoted, = promote_args_inexact("exp2", x)
  return lax.exp2(promoted)
@_wraps(np.signbit, module='numpy')
@jit
def signbit(x: ArrayLike, /) -> Array:
  # Integers: x < 0. Booleans: always False. Floats: bitcast to the signed
  # integer type of the same width and shift the sign bit down, so -0.0 and
  # negative NaNs report True.
  x, = promote_args("signbit", x)
  dtype = dtypes.dtype(x)
  if dtypes.issubdtype(dtype, np.integer):
    return lax.lt(x, _constant_like(x, 0))
  if dtypes.issubdtype(dtype, np.bool_):
    return lax.full_like(x, False, dtype=np.bool_)
  if not dtypes.issubdtype(dtype, np.floating):
    raise ValueError(
        "jax.numpy.signbit is not well defined for %s" % dtype)

  # TPU supports BF16 but not S16 types, so as a workaround, convert BF16 to
  # F32.
  if dtype == dtypes.bfloat16:
    dtype = np.dtype('float32')
    x = lax.convert_element_type(x, dtype)

  finfo = dtypes.finfo(dtype)
  int_type = _INT_DTYPES.get(finfo.bits)
  if int_type is None:
    raise NotImplementedError(
        "jax.numpy.signbit only supports 16, 32, and 64-bit types.")
  sign_shift = finfo.nexp + finfo.nmant  # position of the sign bit
  as_int = lax.bitcast_convert_type(x, int_type)
  return lax.convert_element_type(as_int >> sign_shift, np.bool_)
def _normalize_float(x):
  """Rescale values below the smallest normal number so exponent bits are usable.

  Entries with magnitude below ``finfo.tiny`` (including zero) are multiplied
  by 2**nmant.

  Returns:
    A pair. The first element is the (possibly scaled) values bitcast to the
    signed integer type of the same width. The second is an exponent
    correction in that integer type: -nmant where scaling happened, 0
    elsewhere.
  """
  finfo = dtypes.finfo(dtypes.dtype(x))
  int_type = _INT_DTYPES[finfo.bits]
  is_small = lax.abs(x) < finfo.tiny
  scale = _lax_const(x, 1 << finfo.nmant)
  scaled = _where(is_small, x * scale, x)
  exponent_fix = _where(is_small, int_type(-finfo.nmant), int_type(0))
  return lax.bitcast_convert_type(scaled, int_type), exponent_fix
@_wraps(np.ldexp, module='numpy')
@jit
def ldexp(x1: ArrayLike, x2: ArrayLike, /) -> Array:
  # x1 * 2**x2, computed by rewriting the exponent field of x1. x1 must not
  # be complex, and x2 must not be inexact.
  check_arraylike("ldexp", x1, x2)
  x1_dtype = dtypes.dtype(x1)
  x2_dtype = dtypes.dtype(x2)
  if (dtypes.issubdtype(x1_dtype, np.complexfloating)
      or dtypes.issubdtype(x2_dtype, np.inexact)):
    raise ValueError(f"ldexp not supported for input types {(x1_dtype, x2_dtype)}")

  x1, x2 = promote_shapes("ldexp", x1, x2)

  dtype = dtypes.canonicalize_dtype(dtypes.to_inexact_dtype(x1_dtype))
  finfo = dtypes.finfo(dtype)
  nmant = finfo.nmant
  int_type = _INT_DTYPES[finfo.bits]
  exp_mask = (1 << finfo.nexp) - 1  # all exponent bits set
  exp_bias = exp_mask >> 1

  x1 = lax.convert_element_type(x1, dtype)
  x2 = lax.convert_element_type(x2, int_type)

  # Add x1's own unbiased exponent (after normalizing small values) to x2.
  bits, exp_fix = _normalize_float(x1)
  new_exp = x2 + (exp_fix + ((bits >> nmant) & exp_mask) - exp_bias)

  # Detect underflow/overflow before handling subnormal results.
  underflows = less(new_exp, -(exp_bias + nmant))
  overflows = greater(new_exp, exp_bias)

  # Results below the normal exponent range are built with an exponent
  # raised by nmant, then multiplied by 2**-nmant.
  multiplier = lax.full_like(bits, 1, dtype=dtype)
  subnormal = less(new_exp, -exp_bias + 1)
  new_exp = _where(subnormal, new_exp + nmant, new_exp)
  multiplier = _where(subnormal, multiplier / (1 << nmant), multiplier)

  new_exp = lax.convert_element_type(new_exp, np.int32)
  # Clear the existing exponent field, then write the biased new exponent.
  bits = bits & ~(exp_mask << nmant)
  bits = bits | ((lax.convert_element_type(new_exp, int_type) + exp_bias) << nmant)

  result = lax.convert_element_type(multiplier, dtype) * lax.bitcast_convert_type(bits, dtype)

  result = _where(underflows, lax.full_like(result, 0, dtype=dtype), result)
  result = _where(overflows, lax.sign(x1) * lax.full_like(result, np.inf), result)
  # ldexp(x1, x2) = x1 for x1 = inf, -inf, nan, 0
  return _where(isinf(x1) | isnan(x1) | (x1 == 0), x1, result)
@_wraps(np.frexp, module='numpy')
@jit
def frexp(x: ArrayLike, /) -> tuple[Array, Array]:
  # Split x into (mantissa, exponent) by reading the exponent field and
  # overwriting it with bias - 1, so |mantissa| lies in [0.5, 1). inf, nan
  # and 0 map to (x, 0).
  check_arraylike("frexp", x)
  x, = promote_dtypes_inexact(x)
  if dtypes.issubdtype(x.dtype, np.complexfloating):
    raise TypeError("frexp does not support complex-valued inputs")

  dtype = dtypes.dtype(x)
  finfo = dtypes.finfo(dtype)
  nmant = finfo.nmant
  exp_mask = (1 << finfo.nexp) - 1
  exp_bias = exp_mask >> 1

  bits, exponent = _normalize_float(x)
  exponent = exponent + (((bits >> nmant) & exp_mask) - exp_bias + 1)
  bits = bits & ~(exp_mask << nmant)
  bits = bits | ((exp_bias - 1) << nmant)
  mantissa = lax.bitcast_convert_type(bits, dtype)

  special = isinf(x) | isnan(x) | (x == 0)
  exponent = _where(special, lax._zeros(exponent), exponent)
  return _where(special, x, mantissa), lax.convert_element_type(exponent, np.int32)
@_wraps(np.remainder, module='numpy')
@jit
def remainder(x1: ArrayLike, x2: ArrayLike, /) -> Array:
  # A nonzero result takes the sign of the divisor. For integers, a zero
  # divisor is replaced by one, so the result there is 0.
  x1, x2 = promote_args_numeric("remainder", x1, x2)
  zero = _constant_like(x1, 0)
  if dtypes.issubdtype(x2.dtype, np.integer):
    x2 = _where(x2 == 0, lax._ones(x2), x2)
  r = lax.rem(x1, x2)
  # Shift a nonzero lax.rem result by x2 when its sign disagrees with x2's.
  signs_differ = lax.ne(lax.lt(r, zero), lax.lt(x2, zero))
  needs_shift = lax.bitwise_and(signs_differ, lax.ne(r, zero))
  return lax.select(needs_shift, lax.add(r, x2), r)
mod = _wraps(np.mod, module='numpy')(remainder)
@_wraps(np.fmod, module='numpy')
@jit
def fmod(x1: ArrayLike, x2: ArrayLike, /) -> Array:
  # Remainder computed directly by lax.rem. When the result dtype is an
  # integer, a zero divisor is replaced by one, so those entries give 0.
  check_arraylike("fmod", x1, x2)
  integer_result = dtypes.issubdtype(dtypes.result_type(x1, x2), np.integer)
  if integer_result:
    x2 = _where(x2 == 0, lax._ones(x2), x2)
  a, b = promote_args_numeric("fmod", x1, x2)
  return lax.rem(a, b)
@_wraps(np.square, module='numpy')
@partial(jit, inline=True)
def square(x: ArrayLike, /) -> Array:
  # x * x via lax.integer_pow, after numeric dtype promotion.
  check_arraylike("square", x)
  numeric_x, = promote_dtypes_numeric(x)
  return lax.integer_pow(numeric_x, 2)
@_wraps(np.deg2rad, module='numpy')
@partial(jit, inline=True)
def deg2rad(x: ArrayLike, /) -> Array:
  # Degrees -> radians: multiply by pi/180 in x's inexact dtype.
  x_inexact, = promote_args_inexact("deg2rad", x)
  factor = _lax_const(x_inexact, np.pi / 180)
  return lax.mul(x_inexact, factor)


@_wraps(np.rad2deg, module='numpy')
@partial(jit, inline=True)
def rad2deg(x: ArrayLike, /) -> Array:
  # Radians -> degrees: multiply by 180/pi in x's inexact dtype.
  x_inexact, = promote_args_inexact("rad2deg", x)
  factor = _lax_const(x_inexact, 180 / np.pi)
  return lax.mul(x_inexact, factor)


degrees = rad2deg
radians = deg2rad
@_wraps(np.conjugate, module='numpy')
@partial(jit, inline=True)
def conjugate(x: ArrayLike, /) -> Array:
  # Non-complex inputs are their own conjugate and are returned as arrays.
  check_arraylike("conjugate", x)
  if not np.iscomplexobj(x):
    return lax.asarray(x)
  return lax.conj(x)
conj = conjugate
@_wraps(np.imag)
@partial(jit, inline=True)
def imag(val: ArrayLike, /) -> Array:
  # Non-complex inputs have an all-zero imaginary part (same shape and dtype).
  check_arraylike("imag", val)
  if not np.iscomplexobj(val):
    return lax.full_like(val, 0)
  return lax.imag(val)


@_wraps(np.real)
@partial(jit, inline=True)
def real(val: ArrayLike, /) -> Array:
  # Non-complex inputs are already their own real part.
  check_arraylike("real", val)
  if not np.iscomplexobj(val):
    return lax.asarray(val)
  return lax.real(val)
@_wraps(np.modf, module='numpy', skip_params=['out'])
@jit
def modf(x: ArrayLike, /, out=None) -> tuple[Array, Array]:
  # Returns (fractional part, integral part). The integral part truncates
  # toward zero: floor for x >= 0, ceil otherwise.
  check_arraylike("modf", x)
  x, = promote_dtypes_inexact(x)
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.modf is not supported.")
  integral = _where(lax.ge(x, lax._zero(x)), floor(x), ceil(x))
  fractional = x - integral
  return fractional, integral
@_wraps(np.isfinite, module='numpy')
@jit
def isfinite(x: ArrayLike, /) -> Array:
  # A complex value is finite when both components are. Dtypes that are
  # neither floating nor complex (e.g. ints, bools) are always finite.
  check_arraylike("isfinite", x)
  dtype = dtypes.dtype(x)
  if dtypes.issubdtype(dtype, np.complexfloating):
    return lax.bitwise_and(lax.is_finite(real(x)), lax.is_finite(imag(x)))
  if dtypes.issubdtype(dtype, np.floating):
    return lax.is_finite(x)
  return lax.full_like(x, True, dtype=np.bool_)
@_wraps(np.isinf, module='numpy')
@jit
def isinf(x: ArrayLike, /) -> Array:
  # A complex value is infinite when either component is +/-inf. Dtypes that
  # are neither floating nor complex are never infinite.
  check_arraylike("isinf", x)
  dtype = dtypes.dtype(x)
  if dtypes.issubdtype(dtype, np.floating):
    return lax.eq(lax.abs(x), _constant_like(x, np.inf))
  if not dtypes.issubdtype(dtype, np.complexfloating):
    return lax.full_like(x, False, dtype=np.bool_)
  re, im = lax.real(x), lax.imag(x)
  re_is_inf = lax.eq(lax.abs(re), _constant_like(re, np.inf))
  im_is_inf = lax.eq(lax.abs(im), _constant_like(im, np.inf))
  return lax.bitwise_or(re_is_inf, im_is_inf)
def _isposneginf(infinity: float, x: ArrayLike, out) -> Array:
  """Shared implementation of isposinf / isneginf.

  Args:
    infinity: ``np.inf`` for isposinf, ``-np.inf`` for isneginf.
    x: input array-like (arrays and scalars; lists are rejected, as in the
      other predicates such as isinf and isnan).
    out: unsupported; must be None.

  Returns:
    A boolean array that is True where ``x == infinity``. Inputs that are
    neither floating nor complex yield all-False.

  Raises:
    NotImplementedError: if ``out`` is given.
    TypeError: if ``x`` is not array-like (raised by check_arraylike).
    ValueError: for complex inputs.
  """
  if out is not None:
    raise NotImplementedError("The 'out' argument to isneginf/isposinf is not supported.")
  # Validate the input like every sibling predicate does. Without this check,
  # an int list silently returned all-False while a float list failed inside lax.
  check_arraylike("isposinf" if infinity > 0 else "isneginf", x)
  dtype = dtypes.dtype(x)
  if dtypes.issubdtype(dtype, np.floating):
    return lax.eq(x, _constant_like(x, infinity))
  elif dtypes.issubdtype(dtype, np.complexfloating):
    raise ValueError("isposinf/isneginf are not well defined for complex types")
  else:
    return lax.full_like(x, False, dtype=np.bool_)


isposinf: UnOp = _wraps(np.isposinf, skip_params=['out'])(
  lambda x, /, out=None: _isposneginf(np.inf, x, out)
)

isneginf: UnOp = _wraps(np.isneginf, skip_params=['out'])(
  lambda x, /, out=None: _isposneginf(-np.inf, x, out)
)
@_wraps(np.isnan, module='numpy')
@jit
def isnan(x: ArrayLike, /) -> Array:
  # NaN is the only value that compares unequal to itself.
  check_arraylike("isnan", x)
  is_self_unequal = lax.ne(x, x)
  return is_self_unequal
@_wraps(np.heaviside, module='numpy')
@jit
def heaviside(x1: ArrayLike, x2: ArrayLike, /) -> Array:
  # 0 where x1 < 0, 1 where x1 > 0, and x2 wherever neither comparison holds
  # (x1 == 0, or x1 is NaN).
  check_arraylike("heaviside", x1, x2)
  x1, x2 = promote_dtypes_inexact(x1, x2)
  zero = _lax_const(x1, 0)
  one = _lax_const(x1, 1)
  non_negative_branch = _where(lax.gt(x1, zero), one, x2)
  return _where(lax.lt(x1, zero), zero, non_negative_branch)
@_wraps(np.hypot, module='numpy')
@jit
def hypot(x1: ArrayLike, x2: ArrayLike, /) -> Array:
  # sqrt(x1**2 + x2**2), computed as big * sqrt(1 + (small / big)**2) so the
  # squared ratio is at most 1. Returns 0 where both inputs are 0.
  check_arraylike("hypot", x1, x2)
  x1, x2 = promote_dtypes_inexact(x1, x2)
  a, b = lax.abs(x1), lax.abs(x2)
  big, small = maximum(a, b), minimum(a, b)
  big_is_zero = big == 0
  # Divide by 1 instead of 0 where big == 0; that branch is discarded anyway.
  safe_big = lax.select(big_is_zero, lax._ones(big), big)
  scaled = big * lax.sqrt(1 + lax.square(lax.div(small, safe_big)))
  return lax.select(big_is_zero, big, scaled)
@_wraps(np.reciprocal, module='numpy')
@partial(jit, inline=True)
def reciprocal(x: ArrayLike, /) -> Array:
  # 1 / x in an inexact dtype, expressed as an integer power of -1.
  check_arraylike("reciprocal", x)
  x_inexact, = promote_dtypes_inexact(x)
  return lax.integer_pow(x_inexact, -1)
@_wraps(np.sinc, update_doc=False)
@jit
def sinc(x: ArrayLike, /) -> Array:
  # Normalized sinc, sin(pi x) / (pi x). The value at x == 0 comes from
  # _sinc_maclaurin, whose custom JVP supplies the derivatives there.
  check_arraylike("sinc", x)
  x, = promote_dtypes_inexact(x)
  pi_x = lax.mul(_lax_const(x, np.pi), x)
  at_zero = lax.eq(x, _lax_const(x, 0))
  # Substitute 1 at zero so the division below never evaluates 0 / 0.
  denom = _where(at_zero, _lax_const(x, 1), pi_x)
  away_from_zero = lax.div(lax.sin(denom), denom)
  return _where(at_zero, _sinc_maclaurin(0, pi_x), away_from_zero)
@partial(custom_jvp, nondiff_argnums=(0,))
def _sinc_maclaurin(k, x):
  """Return the k-th derivative of x -> sin(x)/x at zero, broadcast like `x`.

  Odd derivatives vanish, and even ones equal (-1)**(k // 2) / (k + 1). The
  value is independent of `x`; `x * 0` only gives it x's shape and dtype.
  The JVP rule below supplies the next-order term.
  """
  # TODO(mattjj): see https://github.com/google/jax/issues/10750
  zeros = x * 0
  if k % 2:
    return zeros
  return zeros + _lax_const(x, (-1) ** (k // 2) / (k + 1))

@_sinc_maclaurin.defjvp
def _sinc_maclaurin_jvp(k, primals, tangents):
  """The tangent at order k is the order-(k + 1) derivative times `t`."""
  [x], [t] = primals, tangents
  value = _sinc_maclaurin(k, x)
  slope = _sinc_maclaurin(k + 1, x)
  return value, slope * t