jax.experimental.optimizers module

Optimizers for use with JAX.

This module contains some convenient optimizer definitions, specifically initialization and update functions, which can be used with ndarrays or arbitrarily nested tuples/lists/dicts of ndarrays.

An optimizer is modeled as an (init_fun, update_fun, get_params) triple of functions, where the component functions have these signatures:

init_fun(params)

Args:
  params: pytree representing the initial parameters.

Returns:
  A pytree representing the initial optimizer state, which includes the
  initial parameters and may also include auxiliary values like initial
  momentum. The optimizer state pytree structure generally differs from that
  of `params`.
update_fun(step, grads, opt_state)

Args:
  step: integer representing the step index.
  grads: a pytree with the same structure as `get_params(opt_state)`
    representing the gradients to be used in updating the optimizer state.
  opt_state: a pytree representing the optimizer state to be updated.

Returns:
  A pytree with the same structure as the `opt_state` argument representing
  the updated optimizer state.
get_params(opt_state)

Args:
  opt_state: pytree representing an optimizer state.

Returns:
  A pytree representing the parameters extracted from `opt_state`, such that
  the invariant `params == get_params(init_fun(params))` holds true.

Notice that an optimizer implementation has a lot of flexibility in the form of opt_state: it just has to be a pytree of JaxTypes (so that it can be passed to the JAX transforms defined in api.py), and it has to be consumable by update_fun and get_params.
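
For concreteness, here is a minimal training-loop sketch using the sgd optimizer defined below (the data and loss are made up for illustration; any of the optimizer constructors in this module follows the same pattern):

  import jax.numpy as jnp
  from jax import grad, jit
  from jax.experimental import optimizers

  # Toy least-squares problem; inputs, targets, and loss are illustrative.
  inputs = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
  targets = jnp.array([5.0, 11.0, 17.0])

  def loss(params):
      preds = jnp.dot(inputs, params)
      return jnp.mean((preds - targets) ** 2)

  opt_init, opt_update, get_params = optimizers.sgd(step_size=1e-2)
  opt_state = opt_init(jnp.zeros(2))  # params may be any pytree of ndarrays

  @jit
  def step(i, opt_state):
      params = get_params(opt_state)      # extract params from the opt state
      g = grad(loss)(params)              # grads mirror the params pytree
      return opt_update(i, g, opt_state)  # same structure as opt_state

  for i in range(100):
      opt_state = step(i, opt_state)

  trained_params = get_params(opt_state)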

class jax.experimental.optimizers.JoinPoint(subtree)[source]

Bases: object

Marks the boundary between two joined (nested) pytrees.

class jax.experimental.optimizers.OptimizerState(packed_state, tree_def, subtree_defs)

Bases: tuple

packed_state

Alias for field number 0

subtree_defs

Alias for field number 2

tree_def

Alias for field number 1

jax.experimental.optimizers.adagrad(step_size, momentum=0.9)[source]

Construct optimizer triple for Adagrad.

Adaptive Subgradient Methods for Online Learning and Stochastic Optimization: http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf

Parameters:
  • step_size – positive scalar, or a callable representing a step size schedule that maps the iteration index to a positive scalar.
  • momentum – optional, a positive scalar value for momentum.
Returns:

An (init_fun, update_fun, get_params) triple.

jax.experimental.optimizers.adam(step_size, b1=0.9, b2=0.999, eps=1e-08)[source]

Construct optimizer triple for Adam.

Parameters:
  • step_size – positive scalar, or a callable representing a step size schedule that maps the iteration index to a positive scalar.
  • b1 – optional, a positive scalar value for beta_1, the exponential decay rate for the first moment estimates (default 0.9).
  • b2 – optional, a positive scalar value for beta_2, the exponential decay rate for the second moment estimates (default 0.999).
  • eps – optional, a positive scalar value for epsilon, a small constant for numerical stability (default 1e-8).
Returns:

An (init_fun, update_fun, get_params) triple.
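
A short sketch of constructing the Adam triple (the hyperparameters shown are the documented defaults; step_size could equally be one of the schedule callables below):

  from jax.experimental import optimizers

  opt_init, opt_update, get_params = optimizers.adam(
      step_size=1e-3, b1=0.9, b2=0.999, eps=1e-8)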

jax.experimental.optimizers.clip_grads(grad_tree, max_norm)[source]

Clip gradients stored as a pytree of arrays to maximum norm max_norm.
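
A sketch of clipping a gradient pytree before an optimizer update (the gradient values are illustrative):

  import jax.numpy as jnp
  from jax.experimental import optimizers

  grads = {'w': 10.0 * jnp.ones((2, 2)), 'b': 10.0 * jnp.ones(2)}
  clipped = optimizers.clip_grads(grads, max_norm=1.0)
  # `clipped` has the same structure as `grads`, with the global l2 norm
  # scaled down to at most max_norm.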

jax.experimental.optimizers.constant(step_size)[source]
jax.experimental.optimizers.exponential_decay(step_size, decay_steps, decay_rate)[source]
jax.experimental.optimizers.inverse_time_decay(step_size, decay_steps, decay_rate, staircase=False)[source]
jax.experimental.optimizers.l2_norm(tree)[source]

Compute the l2 norm of a pytree of arrays. Useful for weight decay.
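
A sketch of using l2_norm for an l2 penalty (the task loss here is a placeholder):

  import jax.numpy as jnp
  from jax.experimental import optimizers

  params = {'w': jnp.ones((2, 2)), 'b': jnp.zeros(2)}

  def loss(params):
      task_loss = 0.0  # placeholder for the actual data term
      # l2_norm returns the global l2 norm over all leaves; squaring it
      # gives the usual sum-of-squares weight decay penalty.
      return task_loss + 1e-4 * optimizers.l2_norm(params) ** 2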

jax.experimental.optimizers.make_schedule(scalar_or_schedule)[source]
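
constant, exponential_decay, inverse_time_decay, piecewise_constant, and polynomial_decay each return a schedule, i.e. a callable mapping a step index to a step size; make_schedule wraps a bare scalar into such a callable, so optimizer constructors accept either form. A sketch:

  from jax.experimental import optimizers

  # A decaying schedule can be passed anywhere a step_size is accepted.
  schedule = optimizers.exponential_decay(
      step_size=1e-2, decay_steps=1000, decay_rate=0.5)
  opt_init, opt_update, get_params = optimizers.adam(step_size=schedule)

  # make_schedule normalizes scalars and callables into schedule callables.
  assert callable(optimizers.make_schedule(1e-2))
  assert callable(optimizers.make_schedule(schedule))
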
jax.experimental.optimizers.momentum(step_size, mass)[source]

Construct optimizer triple for SGD with momentum.

Parameters:
  • step_size – positive scalar, or a callable representing a step size schedule that maps the iteration index to a positive scalar.
  • mass – positive scalar representing the momentum coefficient.
Returns:

An (init_fun, update_fun, get_params) triple.

jax.experimental.optimizers.nesterov(step_size, mass)[source]

Construct optimizer triple for SGD with Nesterov momentum.

Parameters:
  • step_size – positive scalar, or a callable representing a step size schedule that maps the iteration index to a positive scalar.
  • mass – positive scalar representing the momentum coefficient.
Returns:

An (init_fun, update_fun, get_params) triple.

jax.experimental.optimizers.optimizer(opt_maker)[source]

Decorator to make an optimizer defined for arrays generalize to containers.

With this decorator, you can write init, update, and get_params functions that each operate only on single arrays, and convert them to corresponding functions that operate on pytrees of parameters. See the optimizers defined in optimizers.py for examples.

Parameters:
  • opt_maker – a function that returns an (init_fun, update_fun, get_params) triple of functions that might only work with ndarrays, as per

init_fun :: ndarray -> OptStatePytree ndarray
update_fun :: OptStatePytree ndarray -> OptStatePytree ndarray
get_params :: OptStatePytree ndarray -> ndarray

Returns:

An (init_fun, update_fun, get_params) triple of functions that work on arbitrary pytrees, as per

init_fun :: ParameterPytree ndarray -> OptimizerState
update_fun :: OptimizerState -> OptimizerState
get_params :: OptimizerState -> ParameterPytree ndarray

The OptimizerState pytree type used by the returned functions is isomorphic to ParameterPytree (OptStatePytree ndarray), but may store the state instead as e.g. a partially-flattened data structure for performance.
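
A sketch of a custom optimizer written per-array and lifted by the decorator (this mirrors the shape of the definitions in optimizers.py):

  import jax.numpy as jnp
  from jax.experimental import optimizers

  @optimizers.optimizer
  def simple_sgd(step_size):
      step_size = optimizers.make_schedule(step_size)

      def init(x0):
          return x0                    # per-array state is just the array

      def update(i, g, x):
          return x - step_size(i) * g  # one gradient step on a single ndarray

      def get_params(x):
          return x

      return init, update, get_params

  # The lifted triple now accepts arbitrary parameter pytrees.
  opt_init, opt_update, get_params = simple_sgd(1e-2)
  opt_state = opt_init({'w': jnp.ones(3), 'b': jnp.zeros(1)})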

jax.experimental.optimizers.pack_optimizer_state(marked_pytree)[source]

Converts a marked pytree to an OptimizerState.

The inverse of unpack_optimizer_state. Converts a marked pytree with the leaves of the outer pytree represented as JoinPoints back into an OptimizerState. This function is intended to be useful when deserializing optimizer states.

Parameters: marked_pytree – A pytree containing JoinPoint leaves that hold more pytrees.
Returns: An OptimizerState equivalent to the input argument.
jax.experimental.optimizers.piecewise_constant(boundaries, values)[source]
jax.experimental.optimizers.polynomial_decay(step_size, decay_steps, final_step_size, power=1.0)[source]
jax.experimental.optimizers.rmsprop(step_size, gamma=0.9, eps=1e-08)[source]

Construct optimizer triple for RMSProp.

Parameters:
  • step_size – positive scalar, or a callable representing a step size schedule that maps the iteration index to a positive scalar.
  • gamma – Decay parameter.
  • eps – Epsilon parameter.
Returns:

An (init_fun, update_fun, get_params) triple.

jax.experimental.optimizers.rmsprop_momentum(step_size, gamma=0.9, eps=1e-08, momentum=0.9)[source]

Construct optimizer triple for RMSProp with momentum.

This optimizer is separate from the rmsprop optimizer because it needs to keep track of additional parameters.

Parameters:
  • step_size – positive scalar, or a callable representing a step size schedule that maps the iteration index to a positive scalar.
  • gamma – Decay parameter.
  • eps – Epsilon parameter.
  • momentum – Momentum parameter.
Returns:

An (init_fun, update_fun, get_params) triple.

jax.experimental.optimizers.sgd(step_size)[source]

Construct optimizer triple for stochastic gradient descent.

Parameters: step_size – positive scalar, or a callable representing a step size schedule that maps the iteration index to a positive scalar.
Returns: An (init_fun, update_fun, get_params) triple.
jax.experimental.optimizers.sm3(step_size, momentum=0.9)[source]

Construct optimizer triple for SM3.

Memory-Efficient Adaptive Optimization for Large-Scale Learning: https://arxiv.org/abs/1901.11150

Parameters:
  • step_size – positive scalar, or a callable representing a step size schedule that maps the iteration index to a positive scalar.
  • momentum – optional, a positive scalar value for momentum.
Returns:

An (init_fun, update_fun, get_params) triple.

jax.experimental.optimizers.unpack_optimizer_state(opt_state)[source]

Converts an OptimizerState to a marked pytree.

Converts an OptimizerState to a marked pytree with the leaves of the outer pytree represented as JoinPoints to avoid losing information. This function is intended to be useful when serializing optimizer states.

Parameters: opt_state – An OptimizerState.
Returns: A pytree with JoinPoint leaves that contain a second level of pytrees.
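
A round-trip sketch for serializing an optimizer state (the actual on-disk format, e.g. pickle, is left to the caller):

  import jax.numpy as jnp
  from jax.experimental import optimizers

  opt_init, opt_update, get_params = optimizers.adam(1e-3)
  opt_state = opt_init({'w': jnp.ones(3)})

  # Unpack to a marked pytree (JoinPoint leaves) for serialization ...
  marked = optimizers.unpack_optimizer_state(opt_state)
  # ... and pack it back into an equivalent OptimizerState afterwards.
  restored = optimizers.pack_optimizer_state(marked)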