jax.numpy.ufunc

class jax.numpy.ufunc(func, /, nin, nout, *, name=None, nargs=None, identity=None, call=None, reduce=None, accumulate=None, at=None, reduceat=None)

Universal functions that operate element-by-element on arrays.

JAX implementation of numpy.ufunc.

This is a class for JAX-backed implementations of NumPy’s ufunc APIs. Most users will never need to instantiate ufunc, but rather will use the pre-defined ufuncs in jax.numpy.

For constructing your own ufuncs, see jax.numpy.frompyfunc().
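As a minimal sketch, assuming a JAX release that provides jax.numpy.frompyfunc() with the signature frompyfunc(func, /, nin, nout, *, identity=None), the snippet below wraps a scalar multiplication function. The name scalar_mul is hypothetical; the point is that the resulting object also gets the binary-ufunc methods such as reduce() and accumulate().

```python
import jax.numpy as jnp

# Hypothetical scalar function; the ufunc applies it element-wise.
def scalar_mul(x, y):
    return x * y

# Two inputs, one output; identity=1 because x * 1 == x.
mul_ufunc = jnp.frompyfunc(scalar_mul, 2, 1, identity=1)

x = jnp.array([1, 2, 3, 4, 5])
print(mul_ufunc(x, 2).tolist())          # [2, 4, 6, 8, 10]
print(int(mul_ufunc.reduce(x)))          # 120 (product of all elements)
print(mul_ufunc.accumulate(x).tolist())  # [1, 2, 6, 24, 120] (running product)
```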

Examples

Universal functions are functions that apply element-wise to broadcasted arrays, but they also come with a number of extra attributes and methods.

As an example, consider the function jax.numpy.add. The object acts as a function that applies addition to broadcasted arrays in an element-wise manner:

>>> import jax.numpy as jnp
>>> x = jnp.array([1, 2, 3, 4, 5])
>>> jnp.add(x, 1)
Array([2, 3, 4, 5, 6], dtype=int32)

Each ufunc object includes a number of attributes that describe its behavior:

>>> jnp.add.nin  # number of inputs
2
>>> jnp.add.nout  # number of outputs
1
>>> jnp.add.identity  # identity value, or None if no identity exists
0

Binary ufuncs like jax.numpy.add include a number of methods that apply the function to arrays in different ways.

The outer() method applies the function to every pair of values drawn from the two input arrays, analogous to an outer product:

>>> jnp.add.outer(x, x)
Array([[ 2,  3,  4,  5,  6],
       [ 3,  4,  5,  6,  7],
       [ 4,  5,  6,  7,  8],
       [ 5,  6,  7,  8,  9],
       [ 6,  7,  8,  9, 10]], dtype=int32)

The ufunc.reduce() method performs a reduction over the array. For example, jnp.add.reduce() is equivalent to jax.numpy.sum():

>>> jnp.add.reduce(x)
Array(15, dtype=int32)

The ufunc.accumulate() method performs a cumulative reduction over the array. For example, jnp.add.accumulate() is equivalent to jax.numpy.cumulative_sum():

>>> jnp.add.accumulate(x)
Array([ 1,  3,  6, 10, 15], dtype=int32)

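The ufunc.at() method applies the function at particular indices in the array; for jnp.add the computation is similar to jax.lax.scatter_add(). Because JAX arrays are immutable, at() does not support in-place updates: it must be called with inplace=False, and it returns an updated copy of the array: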

>>> jnp.add.at(x, 0, 100, inplace=False)
Array([101,   2,   3,   4,   5], dtype=int32)

Finally, the ufunc.reduceat() method performs a separate reduction over each slice of an array delimited by the specified indices; for jnp.add the operation is similar to jax.ops.segment_sum():

>>> jnp.add.reduceat(x, jnp.array([0, 2]))
Array([ 3, 12], dtype=int32)

In this case, the first element is x[0:2].sum(), and the second element is x[2:].sum().

Parameters:
  • func (Callable[..., Any])

  • nin (int)

  • nout (int)

  • name (str | None)

  • nargs (int | None)

  • identity (Any)

  • call (Callable[..., Any] | None)

  • reduce (Callable[..., Any] | None)

  • accumulate (Callable[..., Any] | None)

  • at (Callable[..., Any] | None)

  • reduceat (Callable[..., Any] | None)
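A minimal sketch of calling the constructor directly with a few of these parameters. The function _scalar_max and the name "my_max" are hypothetical; identity is left at its default of None, so no identity value is assumed.

```python
import jax.numpy as jnp

# Hypothetical element-wise function taking two inputs.
def _scalar_max(x, y):
    return jnp.maximum(x, y)

# func is positional-only; nin and nout describe its arity.
my_max = jnp.ufunc(_scalar_max, 2, 1, name="my_max")

print(my_max.__name__, my_max.nin, my_max.nout, my_max.identity)
# Calling the ufunc applies the function element-wise with broadcasting.
print(my_max(jnp.array([1, 5, 3]), 2).tolist())  # [2, 5, 3]
```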

__init__(func, /, nin, nout, *, name=None, nargs=None, identity=None, call=None, reduce=None, accumulate=None, at=None, reduceat=None)
Parameters:
  • func (Callable[..., Any])

  • nin (int)

  • nout (int)

  • name (str | None)

  • nargs (int | None)

  • identity (Any)

  • call (Callable[..., Any] | None)

  • reduce (Callable[..., Any] | None)

  • accumulate (Callable[..., Any] | None)

  • at (Callable[..., Any] | None)

  • reduceat (Callable[..., Any] | None)

Methods

  • __init__(func, /, nin, nout, *[, name, ...])

  • accumulate(a[, axis, dtype, out]): Accumulate operation derived from a binary ufunc.

  • at(a, indices[, b, inplace]): Update elements of an array via the specified unary or binary ufunc.

  • outer(A, B, /): Apply the function to all pairs of values in A and B.

  • reduce(a[, axis, dtype, out, keepdims, ...]): Reduction operation derived from a binary function.

  • reduceat(a, indices[, axis, dtype, out]): Reduce an array between specified indices via a binary ufunc.

Attributes

identity

nargs

nin

nout