jax.numpy.ufunc
- class jax.numpy.ufunc(func, /, nin, nout, *, name=None, nargs=None, identity=None, call=None, reduce=None, accumulate=None, at=None, reduceat=None)
Universal functions which operate element-by-element on arrays.
JAX implementation of numpy.ufunc.
This is a class for JAX-backed implementations of NumPy's ufunc APIs. Most users will never need to instantiate ufunc, but rather will use the pre-defined ufuncs in jax.numpy.
For constructing your own ufuncs, see jax.numpy.frompyfunc().
Examples
Universal functions are functions that apply element-wise to broadcasted arrays, but they also come with a number of extra attributes and methods.
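Because jax.numpy.frompyfunc() is the suggested route for building your own ufuncs, here is a minimal sketch of it. The scalar function add_pair is hypothetical (it simply reimplements addition so the results are easy to check), and we assume frompyfunc accepts nin, nout and identity as keywords; the wrapped object should then expose the same call, reduce and accumulate methods that jnp.add demonstrates below.

```python
import jax.numpy as jnp


def add_pair(x, y):
    # Hypothetical scalar-level binary operation; the ufunc wrapper is
    # assumed to handle broadcasting and vectorization over whole arrays.
    return x + y


# Wrap the Python function as a jax.numpy.ufunc with 2 inputs, 1 output,
# and identity 0 (the value a reduction starts from).
my_add = jnp.frompyfunc(add_pair, nin=2, nout=1, identity=0)

x = jnp.array([1, 2, 3, 4, 5])
called = my_add(x, 1)               # element-wise call with broadcasting
reduced = my_add.reduce(x)          # reduction built from add_pair
accumulated = my_add.accumulate(x)  # cumulative reduction built from add_pair

print(called, reduced, accumulated)
```

Since add_pair duplicates addition, each result can be compared directly against jnp.add, jnp.sum and jnp.cumsum.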
As an example, consider the function jax.numpy.add. The object acts as a function that applies addition to broadcasted arrays in an element-wise manner:
>>> import jax.numpy as jnp
>>> x = jnp.array([1, 2, 3, 4, 5])
>>> jnp.add(x, 1)
Array([2, 3, 4, 5, 6], dtype=int32)
Each ufunc object includes a number of attributes that describe its behavior:
>>> jnp.add.nin  # number of inputs
2
>>> jnp.add.nout  # number of outputs
1
>>> jnp.add.identity  # identity value, or None if no identity exists
0
Binary ufuncs like jax.numpy.add include a number of methods to apply the function to arrays in different manners.
The outer() method applies the function to every pair of values drawn from the two input arrays (an outer product):
>>> jnp.add.outer(x, x)
Array([[ 2,  3,  4,  5,  6],
       [ 3,  4,  5,  6,  7],
       [ 4,  5,  6,  7,  8],
       [ 5,  6,  7,  8,  9],
       [ 6,  7,  8,  9, 10]], dtype=int32)
The ufunc.reduce() method performs a reduction over the array. For example, jnp.add.reduce() is equivalent to jnp.sum:
>>> jnp.add.reduce(x)
Array(15, dtype=int32)
The ufunc.accumulate() method performs a cumulative reduction over the array. For example, jnp.add.accumulate() is equivalent to jax.numpy.cumulative_sum():
>>> jnp.add.accumulate(x)
Array([ 1,  3,  6, 10, 15], dtype=int32)
The ufunc.at() method applies the function at particular indices in the array; for jnp.add the computation is similar to jax.lax.scatter_add():
>>> jnp.add.at(x, 0, 100, inplace=False)
Array([101,   2,   3,   4,   5], dtype=int32)
And the ufunc.reduceat() method performs a number of reduce operations between specified indices of an array; for jnp.add the operation is similar to jax.ops.segment_sum():
>>> jnp.add.reduceat(x, jnp.array([0, 2]))
Array([ 3, 12], dtype=int32)
In this case, the first element is x[0:2].sum(), and the second element is x[2:].sum().
- __init__(func, /, nin, nout, *, name=None, nargs=None, identity=None, call=None, reduce=None, accumulate=None, at=None, reduceat=None)
- Parameters:
func (Callable[..., Any])
nin (int)
nout (int)
name (str | None)
nargs (int | None)
identity (Any | None)
call (Callable[..., Any] | None)
reduce (Callable[..., Any] | None)
accumulate (Callable[..., Any] | None)
at (Callable[..., Any] | None)
reduceat (Callable[..., Any] | None)
Methods
- __init__(func, /, nin, nout, *[, name, ...])
- accumulate(a[, axis, dtype, out]): Accumulate operation derived from binary ufunc.
- at(a, indices[, b, inplace]): Update elements of an array via the specified unary or binary ufunc.
- outer(A, B, /): Apply the function to all pairs of values in A and B.
- reduce(a[, axis, dtype, out, keepdims, ...]): Reduction operation derived from a binary function.
- reduceat(a, indices[, axis, dtype, out]): Reduce an array between specified indices via a binary ufunc.
Attributes
identity
nargs
nin
nout