jax.numpy.frompyfunc
- jax.numpy.frompyfunc(func, /, nin, nout, *, identity=None)
Create a JAX ufunc from an arbitrary JAX-compatible scalar function.
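Here is a minimal sketch of the basic workflow, assuming a recent JAX release that provides `jnp.frompyfunc`. The scalar function is written as though its arguments were scalars, and the resulting ufunc applies it elementwise with NumPy-style broadcasting. The name `hinge` is a hypothetical helper, not part of JAX. Because JAX traces `func`, it should use `jax.numpy` operations rather than Python `if` statements on its arguments.

```python
import jax.numpy as jnp

# Hypothetical scalar function: inside the ufunc, x and margin are 0-d values.
def hinge(x, margin):
    # Use jnp operations (not a Python `if`) so the function can be traced.
    return jnp.maximum(0.0, margin - x)

# Two scalar inputs, one scalar output.
hinge_ufunc = jnp.frompyfunc(hinge, nin=2, nout=1)

scores = jnp.array([[0.5, 2.0], [1.0, -1.0]])
# The scalar margin 1.0 is broadcast against the (2, 2) array.
result = hinge_ufunc(scores, 1.0)
print(result)
```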
- Parameters:
func (Callable[..., Any]) – a callable that takes nin scalar arguments and returns nout outputs.
nin (int) – integer specifying the number of scalar inputs.
nout (int) – integer specifying the number of scalar outputs.
identity (Optional[Any]) – (optional) a scalar specifying the identity of the operation, if any (for example, 0 for addition).
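The next sketch illustrates `nout` and `identity` under the same assumptions. The names `min_max` and `mul` are hypothetical and chosen only for illustration. When `nout=2`, the ufunc returns a tuple of arrays. With `identity=1`, a product reduction has a neutral starting value, which lets NumPy-style `reduce` skip the elements excluded by a `where` mask.

```python
import jax.numpy as jnp

# nout=2: the scalar function returns a pair, so the ufunc returns a tuple.
def min_max(x, y):
    return jnp.minimum(x, y), jnp.maximum(x, y)

min_max_ufunc = jnp.frompyfunc(min_max, nin=2, nout=2)
lo, hi = min_max_ufunc(jnp.array([1, 5, 3]), jnp.array([4, 2, 3]))
print(lo, hi)

# identity=1: the neutral element of multiplication.
def mul(x, y):
    return x * y

mul_ufunc = jnp.frompyfunc(mul, nin=2, nout=1, identity=1)
vals = jnp.array([2, 3, 4, 5])
mask = jnp.array([True, False, True, False])
# Masked-out entries are skipped; the identity seeds the running product.
masked_prod = mul_ufunc.reduce(vals, where=mask)
print(masked_prod)
```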
- Returns:
A jax.numpy.ufunc wrapper of func.
- Return type:
wrapped
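The returned object supports NumPy-style ufunc methods in addition to elementwise calls. The sketch below assumes a JAX version in which `jax.numpy.ufunc` provides `outer`, `reduce`, `accumulate`, and `at`, and it wraps the standard-library `operator.add`. JAX arrays are immutable, so `at` is called with `inplace=False` and returns a new array.

```python
import operator

import jax.numpy as jnp

# Wrap a plain Python binary operator; 0 is the identity of addition.
add = jnp.frompyfunc(operator.add, nin=2, nout=1, identity=0)

x = jnp.arange(4)                         # [0, 1, 2, 3]
table = add.outer(x, x)                   # pairwise sums, shape (4, 4)
total = add.reduce(x)                     # 0 + 1 + 2 + 3
running = add.accumulate(x)               # running (prefix) sums
bumped = add.at(x, 1, 10, inplace=False)  # add 10 at index 1, return a copy

print(table)
print(total, running, bumped)
```

The call to `at` leaves `x` unchanged, and the updated values appear only in `bumped`.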