jax.numpy.ufunc

class jax.numpy.ufunc(func, /, nin, nout, *, name=None, nargs=None, identity=None, update_doc=False)

Functions that operate element-by-element on whole arrays.

This is a class for LAX-backed implementations of NumPy ufuncs.

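Parameters:
  • func (Callable[…, Any]): the function applied element-wise

  • nin (int): number of input arguments

  • nout (int): number of outputs

  • name (str | None): name of the ufunc

  • nargs (int | None): number of arguments

  • identity (Any): identity value of the operation (for example, 0 for addition), or None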
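A minimal sketch of the constructor, assuming a JAX version in which `jnp.ufunc` accepts the parameters listed above. The scalar function `add_fn` and the name `"my_add"` are hypothetical, chosen only for illustration. Calling the resulting object applies the wrapped function element-wise, broadcasting its inputs against each other.

```python
import jax.numpy as jnp

# Hypothetical scalar function used for illustration; not part of JAX.
def add_fn(x, y):
    return x + y

# Two inputs (nin=2), one output (nout=1); identity=0 because x + 0 == x.
my_add = jnp.ufunc(add_fn, 2, 1, name="my_add", identity=0)

x = jnp.arange(3)              # shape (3,):   [0, 1, 2]
y = jnp.array([[10], [20]])    # shape (2, 1): broadcasts against x
z = my_add(x, y)               # element-wise result with shape (2, 3)

print(z.tolist())              # [[10, 11, 12], [20, 21, 22]]
```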

Methods

__init__(func, /, nin, nout, *[, name, ...])

Initialize self; takes the constructor parameters listed above.

accumulate(a[, axis, dtype, out])

Accumulate the result of applying the operator to all elements.

at(a, indices[, b, inplace])

Performs unbuffered in place operation on operand 'a' for elements specified by 'indices'.

outer(A, B, /, **kwargs)

Apply the ufunc op to all pairs (a, b) with a in A and b in B.

reduce(a[, axis, dtype, out, keepdims, ...])

Reduces an array's dimension by one by applying the ufunc along one axis.

reduceat(a, indices[, axis, dtype, out])

Performs a (local) reduce with specified slices over a single axis.
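The methods above can be exercised on a small hypothetical binary ufunc, `my_add`, that wraps plain addition. This is a sketch assuming each method follows its NumPy counterpart for such a ufunc. Because JAX arrays are immutable, `at` must be called with `inplace=False`, and the updated array is returned rather than written in place.

```python
import jax.numpy as jnp

# Hypothetical binary ufunc wrapping plain addition (illustration only).
my_add = jnp.ufunc(lambda x, y: x + y, 2, 1, name="my_add", identity=0)

a = jnp.array([1, 2, 3, 4])
m = jnp.array([[1, 2, 3], [4, 5, 6]])

# accumulate: running result along axis 0.
acc = my_add.accumulate(a)
print(acc.tolist())   # [1, 3, 6, 10]

# reduce: collapse axis 1 of a 2-D array.
red = my_add.reduce(m, axis=1)
print(red.tolist())   # [6, 15]

# outer: apply to every pair (A[i], B[j]); result shape is A.shape + B.shape.
out = my_add.outer(jnp.array([1, 2, 3]), jnp.array([10, 20]))
print(out.tolist())   # [[11, 21], [12, 22], [13, 23]]

# reduceat: reduce the slices [0:4] and [4:8] of arange(8).
seg = my_add.reduceat(jnp.arange(8), jnp.array([0, 4]))
print(seg.tolist())   # [6, 22]

# at: index 0 appears twice, so it is updated twice (unbuffered semantics).
zeros = jnp.zeros(3, dtype=jnp.int32)
upd = my_add.at(zeros, jnp.array([0, 0, 2]), 1, inplace=False)
print(upd.tolist())   # [2, 0, 1]
```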

Attributes

identity

nargs

nin

nout
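The attributes mirror the constructor arguments. A minimal sketch, again using a hypothetical `my_add` ufunc with `nargs` passed explicitly for illustration; with `identity=0`, `reduce` on an empty input is assumed to return the identity value instead of failing, as NumPy's `add.reduce` does.

```python
import jax.numpy as jnp

# Hypothetical ufunc; nargs is passed explicitly here for illustration.
my_add = jnp.ufunc(lambda x, y: x + y, 2, 1, name="my_add", nargs=2, identity=0)

# Read-only attributes reflecting the constructor arguments.
print(my_add.nin, my_add.nout, my_add.nargs, my_add.identity)  # 2 1 2 0

# With identity=0, reducing an empty array yields the identity value.
empty = jnp.array([], dtype=jnp.int32)
total = my_add.reduce(empty)
print(int(total))  # 0
```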