

class jax.numpy.ufunc(func, /, nin, nout, *, name=None, nargs=None, identity=None, call=None, reduce=None, accumulate=None, at=None, reduceat=None)

Universal functions which operate element-by-element on arrays.

JAX implementation of numpy.ufunc.

This is a class for JAX-backed implementations of NumPy’s ufunc APIs. Most users will never need to instantiate ufunc, but rather will use the pre-defined ufuncs in jax.numpy.

For constructing your own ufuncs, see jax.numpy.frompyfunc().
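As a minimal sketch of that route, the snippet below wraps a hypothetical `add_mod7` function (addition modulo 7, chosen only for illustration) with jax.numpy.frompyfunc(). It relies on the methods a ufunc derives by default from a binary function with nin=2 and nout=1; passing identity=0 gives reduce() a neutral starting value.

```python
import jax.numpy as jnp

def add_mod7(x, y):
    # Hypothetical example function: addition modulo 7.
    # It uses only array operators, so JAX can trace it.
    return (x + y) % 7

# Wrap the Python function as a binary ufunc with one output and identity 0
add_mod7_ufunc = jnp.frompyfunc(add_mod7, nin=2, nout=1, identity=0)

x = jnp.array([1, 2, 3, 4, 5])
called = add_mod7_ufunc(x, 5)           # element-wise: [6, 0, 1, 2, 3]
reduced = add_mod7_ufunc.reduce(x)      # (1 + 2 + 3 + 4 + 5) % 7 = 1
running = add_mod7_ufunc.accumulate(x)  # running totals mod 7: [1, 3, 6, 3, 1]

print(called.tolist(), int(reduced), running.tolist())
```

Modular addition is used because it is associative and has a clear identity, so the reduce and accumulate results are easy to check by hand.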


Universal functions are functions that apply element-wise to broadcasted arrays, but they also come with a number of extra attributes and methods.

As an example, consider the function jax.numpy.add. The object acts as a function that applies addition to broadcasted arrays in an element-wise manner:

>>> import jax.numpy as jnp
>>> x = jnp.array([1, 2, 3, 4, 5])
>>> jnp.add(x, 1)
Array([2, 3, 4, 5, 6], dtype=int32)
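To make the broadcasting part concrete, the sketch below adds a column of shape (3, 1) to a row of shape (3,); the values are arbitrary and chosen only so the result is easy to read.

```python
import jax.numpy as jnp

col = jnp.array([[0], [10], [20]])  # shape (3, 1)
row = jnp.array([1, 2, 3])          # shape (3,)

# Both inputs broadcast to shape (3, 3) before the element-wise add
grid = jnp.add(col, row)
print(grid.shape)     # (3, 3)
print(grid.tolist())  # [[1, 2, 3], [11, 12, 13], [21, 22, 23]]
```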

Each ufunc object includes a number of attributes that describe its behavior:

>>> jnp.add.nin  # number of inputs
2
>>> jnp.add.nout  # number of outputs
1
>>> jnp.add.identity  # identity value, or None if no identity exists
0

Binary ufuncs like jax.numpy.add include a number of methods that apply the function to arrays in different ways.

The ufunc.outer() method applies the function to every pair of values drawn from the two input arrays, analogous to an outer product:

>>> jnp.add.outer(x, x)
Array([[ 2,  3,  4,  5,  6],
       [ 3,  4,  5,  6,  7],
       [ 4,  5,  6,  7,  8],
       [ 5,  6,  7,  8,  9],
       [ 6,  7,  8,  9, 10]], dtype=int32)

The ufunc.reduce() method performs a reduction over the array. For example, jnp.add.reduce() is equivalent to jax.numpy.sum():

>>> jnp.add.reduce(x)
Array(15, dtype=int32)

The ufunc.accumulate() method performs a cumulative reduction over the array. For example, jnp.add.accumulate() is equivalent to jax.numpy.cumulative_sum():

>>> jnp.add.accumulate(x)
Array([ 1,  3,  6, 10, 15], dtype=int32)

The ufunc.at() method applies the function at particular indices in the array; for jnp.add the computation is similar to jax.lax.scatter_add():

>>> jnp.add.at(x, 0, 100, inplace=False)
Array([101,   2,   3,   4,   5], dtype=int32)

Finally, the ufunc.reduceat() method performs a separate reduction over each segment between specified indices of an array; for jnp.add the operation is similar to jax.ops.segment_sum():

>>> jnp.add.reduceat(x, jnp.array([0, 2]))
Array([ 3, 12], dtype=int32)

In this case, the first element is x[0:2].sum(), and the second element is x[2:].sum().

Parameters:

  • func (Callable[..., Any])

  • nin (int)

  • nout (int)

  • name (str | None)

  • nargs (int | None)

  • identity (Any)

  • call (Callable[..., Any] | None)

  • reduce (Callable[..., Any] | None)

  • accumulate (Callable[..., Any] | None)

  • at (Callable[..., Any] | None)

  • reduceat (Callable[..., Any] | None)

__init__(func, /, nin, nout, *, name=None, nargs=None, identity=None, call=None, reduce=None, accumulate=None, at=None, reduceat=None)

Parameters:

  • func (Callable[..., Any])

  • nin (int)

  • nout (int)

  • name (str | None)

  • nargs (int | None)

  • identity (Any | None)

  • call (Callable[..., Any] | None)

  • reduce (Callable[..., Any] | None)

  • accumulate (Callable[..., Any] | None)

  • at (Callable[..., Any] | None)

  • reduceat (Callable[..., Any] | None)


Methods:

  • __init__(func, /, nin, nout, *[, name, ...])

  • accumulate(a[, axis, dtype, out]): Accumulate operation derived from a binary ufunc.

  • at(a, indices[, b, inplace]): Update elements of an array via the specified unary or binary ufunc.

  • outer(A, B, /): Apply the function to all pairs of values in A and B.

  • reduce(a[, axis, dtype, out, keepdims, ...]): Reduction operation derived from a binary function.

  • reduceat(a, indices[, axis, dtype, out]): Reduce an array between specified indices via a binary ufunc.




