jax.numpy.ufunc

class jax.numpy.ufunc(func, /, nin, nout, *, name=None, nargs=None, identity=None, update_doc=False)

Functions that operate element-by-element on whole arrays.

This is a class for LAX-backed implementations of NumPy universal functions (ufuncs).
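A minimal sketch of wrapping a scalar Python function in a ufunc and calling it on arrays that broadcast against each other. The function `scaled_add` and the variable names are hypothetical; the sketch uses only the `func`, `nin`, and `nout` arguments from the class signature.

```python
import jax.numpy as jnp

# Hypothetical element-wise function, written as if its inputs were scalars.
def scaled_add(x, y):
    return x + 2 * y

# func is positional-only; here nin=2 inputs and nout=1 output.
scaled_add_ufunc = jnp.ufunc(scaled_add, 2, 1)

x = jnp.arange(3.0)    # shape (3,)
y = jnp.ones((2, 1))   # shape (2, 1), broadcasts against x
result = scaled_add_ufunc(x, y)

print(isinstance(scaled_add_ufunc, jnp.ufunc))  # True
print(result.shape)                              # (2, 3)
```

The wrapped function only ever sees scalars; the ufunc applies it element by element over the broadcast shape of its inputs.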

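Parameters:
func (Callable[..., Any]): function to apply element-wise to the inputs.
nin (int): number of input arguments.
nout (int): number of outputs.
name (str | None): name of the ufunc; if None, the name of func is used.
nargs (int | None): number of arguments; if None, nin is used.
identity (Any): identity value of the operation (for example, 0 for addition), or None if there is none.
update_doc (bool): if True, use the docstring of func as the docstring of the ufunc.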
__init__(func, /, nin, nout, *, name=None, nargs=None, identity=None, update_doc=False)
Parameters: the same as those of the class signature.
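A short sketch of the keyword-only constructor arguments `name` and `identity`. The function `add` and the name `"my_add"` are hypothetical; the sketch assumes that, when `name` is omitted, the ufunc takes its name from `func`.

```python
import jax.numpy as jnp

# Hypothetical binary function; 0 is a true identity for it.
def add(x, y):
    return x + y

# Without `name`, the ufunc is named after func.
unnamed = jnp.ufunc(add, 2, 1)
# `name` and `identity` must be passed as keyword arguments.
named = jnp.ufunc(add, 2, 1, name="my_add", identity=0)

print(unnamed.__name__)  # add
print(named.__name__)    # my_add
print(named.identity)    # 0
```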

Methods

__init__(func, /, nin, nout, *[, name, ...])

Construct a ufunc; takes the same parameters as the class.

accumulate(a[, axis, dtype, out])

Accumulate the result of applying the operator to all elements.

at(a, indices[, b, inplace])

Apply the operator to the elements of a selected by indices (with b as the second operand, if given) and return the updated array.

outer(A, B, /, **kwargs)

Apply the ufunc op to all pairs (a, b) with a in A and b in B.

reduce(a[, axis, dtype, out, keepdims, ...])

Reduce an array's dimension by one by applying the ufunc along one axis.

reduceat(a, indices[, axis, dtype, out])

Performs a (local) reduce with specified slices over a single axis.
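A minimal sketch of `accumulate` on a user-defined binary ufunc. The function `add` and the name `add_ufunc` are hypothetical; the sketch assumes the NumPy-style default `axis=0`.

```python
import jax.numpy as jnp

# Hypothetical binary function wrapped as a ufunc with identity 0.
def add(x, y):
    return x + y

add_ufunc = jnp.ufunc(add, 2, 1, identity=0)

# Running sum along the default axis (0).
running = add_ufunc.accumulate(jnp.array([1, 2, 3, 4]))
print(running.tolist())      # [1, 3, 6, 10]

# Running sum along each row of a 2-D array.
m = jnp.array([[1, 2], [3, 4]])
row_running = add_ufunc.accumulate(m, axis=1)
print(row_running.tolist())  # [[1, 3], [3, 7]]
```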
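A sketch of `at` with a repeated index. JAX arrays are immutable, so the sketch passes `inplace=False` and uses the returned array, leaving `x` unchanged. It assumes NumPy's unbuffered `ufunc.at` semantics, where a repeated index is updated once per occurrence; `add` and `add_ufunc` are hypothetical names.

```python
import jax.numpy as jnp

# Hypothetical binary function wrapped as a ufunc.
def add(x, y):
    return x + y

add_ufunc = jnp.ufunc(add, 2, 1, identity=0)

x = jnp.zeros(4)
idx = jnp.array([0, 0, 2])  # index 0 appears twice

# Request a new array rather than an in-place update.
updated = add_ufunc.at(x, idx, 1.0, inplace=False)
print(updated.tolist())  # [2.0, 0.0, 1.0, 0.0]
print(x.tolist())        # [0.0, 0.0, 0.0, 0.0]
```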
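A sketch of `outer`, which applies the operator to every pair drawn from its two inputs, so the result shape is `A.shape + B.shape`. The function `add` and the name `add_ufunc` are hypothetical.

```python
import jax.numpy as jnp

# Hypothetical binary function wrapped as a ufunc.
def add(x, y):
    return x + y

add_ufunc = jnp.ufunc(add, 2, 1, identity=0)

a = jnp.arange(3)  # [0, 1, 2]
b = jnp.arange(4)  # [0, 1, 2, 3]

# table[i, j] == a[i] + b[j]
table = add_ufunc.outer(a, b)
print(table.shape)  # (3, 4)
```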
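A sketch of `reduce` along each axis of a 2-D array, including `keepdims`. The function `add` and the name `add_ufunc` are hypothetical; the sketch assumes the NumPy-style default `axis=0`.

```python
import jax.numpy as jnp

# Hypothetical binary function wrapped as a ufunc with identity 0.
def add(x, y):
    return x + y

add_ufunc = jnp.ufunc(add, 2, 1, identity=0)

m = jnp.arange(6).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]

col_sums = add_ufunc.reduce(m, axis=0)  # removes axis 0
row_sums = add_ufunc.reduce(m, axis=1)  # removes axis 1
kept = add_ufunc.reduce(m, axis=1, keepdims=True)  # keeps axis 1 with length 1

print(col_sums.tolist())  # [3, 5, 7]
print(row_sums.tolist())  # [3, 12]
print(kept.shape)         # (2, 1)
```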
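A sketch of `reduceat`, assuming NumPy's semantics: each entry of `indices` starts a slice that runs up to the next entry (or to the end of the axis), and each slice is reduced separately. The function `add` and the name `add_ufunc` are hypothetical.

```python
import jax.numpy as jnp

# Hypothetical binary function wrapped as a ufunc.
def add(x, y):
    return x + y

add_ufunc = jnp.ufunc(add, 2, 1, identity=0)

x = jnp.arange(8)              # [0, 1, 2, 3, 4, 5, 6, 7]
starts = jnp.array([0, 4, 6])  # slices x[0:4], x[4:6], x[6:]

segment_sums = add_ufunc.reduceat(x, starts)
print(segment_sums.tolist())   # [6, 9, 13]
```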

Attributes

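identity

Identity value of the operation (the identity constructor parameter), or None if none was given.

nargs

Number of arguments, as given by the nargs constructor parameter.

nin

Number of inputs.

nout

Number of outputs.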
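A sketch that reads the attributes of a user-defined ufunc. The names `add` and `add_ufunc` are hypothetical. Because `nargs` is not passed here, the sketch just prints whatever value the installed JAX version derives for it. It also assumes the attributes are read-only properties, so assigning to one should raise `AttributeError`.

```python
import jax.numpy as jnp

# Hypothetical binary function wrapped as a ufunc with identity 0.
def add(x, y):
    return x + y

add_ufunc = jnp.ufunc(add, 2, 1, identity=0)

print(add_ufunc.nin)       # 2
print(add_ufunc.nout)      # 1
print(add_ufunc.identity)  # 0
print(add_ufunc.nargs)

# Assumption: attributes are read-only, so assignment should fail.
try:
    add_ufunc.nin = 3
    nin_is_read_only = False
except AttributeError:
    nin_is_read_only = True
print(nin_is_read_only)    # True
```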