jax.numpy.frompyfunc
- jax.numpy.frompyfunc(func, /, nin, nout, *, identity=None)
Create a JAX ufunc from an arbitrary JAX-compatible scalar function.
- Parameters:
func (Callable[[...], Any]) – a callable that takes nin scalar arguments and returns nout outputs.
nin (int) – integer specifying the number of scalar inputs.
nout (int) – integer specifying the number of scalar outputs.
identity (Any | None) – (optional) a scalar specifying the identity of the operation, if any.
- Returns:
A jax.numpy.ufunc wrapper of func.
- Return type:
jax.numpy.ufunc
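A minimal usage sketch, assuming a JAX release that provides jax.numpy.frompyfunc and the reduce, accumulate and outer methods on the returned jax.numpy.ufunc. The scalar function scalar_add is a hypothetical helper written for illustration; identity=0 is passed only because 0 is the identity of addition.

```python
import jax.numpy as jnp

# Hypothetical scalar function: it must be traceable by JAX,
# i.e. built from operations that work on JAX scalars.
def scalar_add(x, y):
    return x + y

# Two scalar inputs, one scalar output; 0 is the identity of addition.
add = jnp.frompyfunc(scalar_add, nin=2, nout=1, identity=0)

x = jnp.arange(5)

# Elementwise call with broadcasting: values [10, 11, 12, 13, 14]
elementwise = add(x, 10)

# Fold along axis 0: value 10
total = add.reduce(x)

# Running fold along axis 0: values [0, 1, 3, 6, 10]
running = add.accumulate(x)

# All pairwise combinations: values [[0, 1, 2], [1, 2, 3], [2, 3, 4]]
table = add.outer(jnp.arange(3), jnp.arange(3))

print(elementwise, total, running, table, sep="\n")
```

For plain addition, jnp.add, jnp.sum and jnp.cumsum already exist; the point of frompyfunc is to give a custom scalar function the same ufunc interface (call, reduce, accumulate, outer) without writing each method by hand.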