jax.numpy.frompyfunc

jax.numpy.frompyfunc(func, /, nin, nout, *, identity=None)

Create a JAX ufunc from an arbitrary JAX-compatible scalar function.
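As a minimal sketch, assuming a JAX version that provides jnp.frompyfunc, the scalar functions below are hypothetical helpers chosen for illustration. scalar_hypot takes nin=2 inputs and produces nout=1 output, while scalar_divmod produces nout=2 outputs returned as a tuple. Calling the resulting ufunc applies the scalar function elementwise over the (broadcast) array inputs.

```python
import jax.numpy as jnp

# Hypothetical scalar function: nin=2 inputs, nout=1 output.
def scalar_hypot(x, y):
    return jnp.sqrt(x * x + y * y)

hypot = jnp.frompyfunc(scalar_hypot, nin=2, nout=1)

# Hypothetical scalar function: nin=2 inputs, nout=2 outputs returned as a tuple.
def scalar_divmod(x, y):
    return x // y, x % y

divmod_ufunc = jnp.frompyfunc(scalar_divmod, nin=2, nout=2)

a = jnp.array([3.0, 5.0, 8.0])
b = jnp.array([4.0, 12.0, 15.0])
print(hypot(a, b))  # elementwise: approximately [5., 13., 17.]

# With nout=2, the call returns one array per output.
q, r = divmod_ufunc(jnp.array([7, 9, 10]), jnp.array([2, 4, 3]))
print(q, r)  # [3 2 3] [1 1 1]
```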

Parameters:
  • func (Callable[..., Any]) – a callable that takes nin scalar arguments and returns nout outputs.

  • nin (int) – the number of scalar inputs that func accepts.

  • nout (int) – the number of scalar outputs that func returns.

  • identity (Optional[Any]) – (optional) a scalar specifying the identity element of the operation, if it has one.

Returns:

A jax.numpy.ufunc that wraps func.

Return type:

jax.numpy.ufunc
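Because the returned object is a ufunc, it offers more than elementwise calls: methods such as reduce, accumulate, and outer are available, and like their NumPy counterparts they are meant for binary functions (nin=2) with a single output (nout=1). The sketch below is hedged: scalar_mul is a hypothetical helper, and identity=1 simply declares the identity element of multiplication.

```python
import jax.numpy as jnp

# Hypothetical binary scalar function with a single output.
def scalar_mul(x, y):
    return x * y

# identity=1 declares the identity element of multiplication.
multiply = jnp.frompyfunc(scalar_mul, nin=2, nout=1, identity=1)

x = jnp.arange(1, 5)           # [1 2 3 4]
print(multiply.reduce(x))      # 24 (product of all elements)
print(multiply.accumulate(x))  # [ 1  2  6 24] (running product)
print(multiply.outer(x, x))    # 4x4 table with entries x[i] * x[j]
```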