jax.scipy.optimize.minimize
- jax.scipy.optimize.minimize(fun, x0, args=(), *, method, tol=None, options=None)
Minimization of a scalar function of one or more variables.
The API for this function matches SciPy with some minor deviations:
- Gradients of fun are calculated automatically using JAX’s autodiff support when required.
- The method argument is required. You must specify a solver.
- Various optional arguments in the SciPy interface have not yet been implemented.
- Optimization results may differ from SciPy due to differences in the line search implementation.
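A minimal sketch of the first two deviations, assuming a toy quadratic objective (objective and target are illustrative names, not part of the API). No gradient is passed to minimize, and because method is a keyword-only argument with no default in the signature, leaving it out fails at call time.

```python
import jax.numpy as jnp
from jax.scipy.optimize import minimize

# Illustrative convex objective whose minimizer is known in advance: x* = target.
target = jnp.array([1.0, -2.0, 3.0])

def objective(x):
    return jnp.sum((x - target) ** 2)

x0 = jnp.zeros(3)

# No gradient function is supplied; minimize differentiates `objective` with JAX autodiff.
result = minimize(objective, x0, method="BFGS")
print(result.x)    # should be close to target
print(result.fun)  # should be close to 0

# `method` has no default, so omitting it raises a TypeError.
try:
    minimize(objective, x0)
except TypeError as err:
    print("method is required:", err)
```

Since the solver stops at a tolerance, compare the result with the known minimizer using a tolerance rather than exact equality.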
minimize supports jit() compilation. It does not yet support differentiation or arguments in the form of multi-dimensional arrays, but support for both is planned.
- Parameters:
  - fun (Callable) – the objective function to be minimized, fun(x, *args) -> float, where x is a 1-D array with shape (n,) and args is a tuple of the fixed parameters needed to completely specify the function. fun must support differentiation.
  - x0 (Array) – initial guess. Array of real elements of size (n,), where n is the number of independent variables.
  - args (tuple) – extra arguments passed to the objective function.
  - method (str) – solver type. Currently only "BFGS" is supported.
  - tol (Optional[float]) – tolerance for termination. For detailed control, use solver-specific options.
  - options (Optional[Mapping[str, Any]]) – a dictionary of solver options. All methods accept the following generic options:
    - maxiter (int): Maximum number of iterations to perform. Depending on the method, each iteration may use several function evaluations.
- Return type:
  OptimizeResults
- Returns:
  An OptimizeResults object.