jax.scipy.optimize.minimize

jax.scipy.optimize.minimize(fun, x0, args=(), *, method, tol=None, options=None)

Minimization of scalar function of one or more variables.

The API of this function matches SciPy's, with some minor deviations:

  • Gradients of fun are calculated automatically using JAX’s autodiff support when required.

  • The method argument is required. You must specify a solver.

  • Various optional arguments in the SciPy interface have not yet been implemented.

  • Optimization results may differ from SciPy due to differences in the line search implementation.
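
A minimal sketch of a basic call, assuming a recent JAX release. The quadratic objective is a made-up example. No gradient is passed because JAX derives it. Leaving out method fails with an ordinary Python TypeError, since method is keyword-only and has no default.

```python
import jax.numpy as jnp
from jax.scipy.optimize import minimize

# Illustrative objective (not from the API): a convex quadratic with its minimum at x = [1, -2].
def quadratic(x):
    return (x[0] - 1.0) ** 2 + 10.0 * (x[1] + 2.0) ** 2

x0 = jnp.zeros(2)  # 1-D initial guess of shape (n,) with n = 2

# No gradient is supplied; minimize differentiates quadratic with JAX autodiff.
result = minimize(quadratic, x0, method="BFGS")
print(result.x)

# method is a required keyword-only argument, so omitting it raises TypeError.
try:
    minimize(quadratic, x0)
except TypeError as err:
    print("TypeError:", err)
```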

minimize supports jit() compilation. It does not yet support differentiation or arguments in the form of multi-dimensional arrays, but support for both is planned.
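
One way to combine these two points is sketched below. The objective, the helper name fit, and the 2x2 target are hypothetical. The call is placed inside a function decorated with jax.jit. Because x must be 1-D, the matrix-shaped parameter is flattened before the call and reshaped inside fun.

```python
import jax
import jax.numpy as jnp
from jax.scipy.optimize import minimize

# Hypothetical objective over a 2x2 parameter matrix W. minimize only accepts
# a 1-D x, so W arrives flattened and is reshaped inside the objective.
def objective(w_flat, target):
    W = w_flat.reshape(target.shape)
    return jnp.sum((W - target) ** 2)

@jax.jit
def fit(target):
    x0 = jnp.zeros(target.size)  # flattened 1-D initial guess of shape (n,)
    result = minimize(objective, x0, args=(target,), method="BFGS")
    return result.x.reshape(target.shape)  # restore the matrix shape

target = jnp.array([[1.0, -2.0], [0.5, 3.0]])
W_opt = fit(target)
print(W_opt)
```

The whole solve is traced and compiled once per input shape and dtype, so later calls to fit with other 2x2 targets reuse the compiled code.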

Parameters:
  • fun (Callable) – the objective function to be minimized, fun(x, *args) -> float, where x is a 1-D array with shape (n,) and args is a tuple of the fixed parameters needed to completely specify the function. fun must support differentiation.

  • x0 (jax.Array) – initial guess. Array of real elements of shape (n,), where n is the number of independent variables.

  • args (tuple) – extra arguments passed to the objective function.

  • method (str) – solver type. Currently only "BFGS" is supported.

  • tol (float | None) – tolerance for termination. For detailed control, use solver-specific options.

  • options (Mapping[str, Any] | None) –

    a dictionary of solver options. All methods accept the following generic options:

    • maxiter (int): Maximum number of iterations to perform. Depending on the method, each iteration may use several function evaluations.
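
A sketch of how args and options fit together, using an illustrative least-squares problem. The matrix A and vector b are made-up data. They are passed through args, so fun is called as fun(x, A, b). The options dictionary uses only the documented maxiter key. For this data, A @ [2, 3] equals b exactly, so the minimizer is [2, 3].

```python
import jax.numpy as jnp
from jax.scipy.optimize import minimize

# fun(x, *args): x is the 1-D parameter vector; A and b are fixed parameters.
def least_squares(x, A, b):
    r = A @ x - b
    return 0.5 * jnp.dot(r, r)

# Illustrative data chosen so that A @ [2, 3] == b exactly.
A = jnp.array([[3.0, 1.0], [1.0, 2.0], [0.0, 1.0]])
b = jnp.array([9.0, 8.0, 3.0])
x0 = jnp.zeros(2)

result = minimize(
    least_squares,
    x0,
    args=(A, b),               # must be a tuple
    method="BFGS",
    options={"maxiter": 100},  # generic option: cap on solver iterations
)
print(result.x)
```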

Return type:

OptimizeResults

Returns:

An OptimizeResults object.
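
A sketch of reading the result. It assumes the field names of the OptimizeResults NamedTuple in current JAX releases: x, fun, jac, hess_inv, success, status, nit, nfev and njev. Check them against your installed version. The quadratic objective is illustrative only.

```python
import jax.numpy as jnp
from jax.scipy.optimize import minimize

def quadratic(x):
    # Illustrative objective with its minimum at x = [1, -2].
    return (x[0] - 1.0) ** 2 + 10.0 * (x[1] + 2.0) ** 2

result = minimize(quadratic, jnp.zeros(2), method="BFGS")

# OptimizeResults is a NamedTuple, so its fields are plain attributes.
print("x        :", result.x)         # solution, shape (n,)
print("fun      :", result.fun)       # objective value at result.x
print("jac      :", result.jac)       # gradient at result.x, shape (n,)
print("hess_inv :", result.hess_inv)  # BFGS inverse-Hessian estimate, shape (n, n)
print("success  :", result.success)   # boolean array: converged and not failed
print("status   :", result.status)    # integer solver status code
print("nit, nfev, njev:", result.nit, result.nfev, result.njev)
```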