jax.experimental.optimizers module

Optimizers for use with JAX.

This module contains some convenient optimizer definitions, specifically initialization and update functions, which can be used with ndarrays or arbitrarily nested tuples/lists/dicts of ndarrays.

An optimizer is modeled as an (init_fun, update_fun, get_params) triple of functions, where the component functions have these signatures:

init_fun(params)

Args:
  params: pytree representing the initial parameters.

Returns:
  A pytree representing the initial optimizer state, which includes the
  initial parameters and may also include auxiliary values like initial
  momentum. The optimizer state pytree structure generally differs from that
  of `params`.

update_fun(step, grads, opt_state)

Args:
  step: integer representing the step index.
  grads: a pytree with the same structure as `get_params(opt_state)`
    representing the gradients to be used in updating the optimizer state.
  opt_state: a pytree representing the optimizer state to be updated.

Returns:
  A pytree with the same structure as the `opt_state` argument representing
  the updated optimizer state.

get_params(opt_state)

Args:
  opt_state: pytree representing an optimizer state.

Returns:
  A pytree representing the parameters extracted from `opt_state`, such that
  the invariant `params == get_params(init_fun(params))` holds true.

Notice that an optimizer implementation has a lot of flexibility in the form of opt_state: it just has to be a pytree of JaxTypes (so that it can be passed to the JAX transforms defined in api.py) and it has to be consumable by update_fun and get_params.
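
For example, the get_params invariant can be checked on a small nested parameter pytree. This is a minimal sketch with illustrative parameter values; it uses the momentum optimizer so that the state carries more than the parameters. The import tries the newer jax.example_libraries location first and falls back to the jax.experimental path documented here.

```python
import jax
import jax.numpy as jnp

try:
    from jax.example_libraries import optimizers  # newer JAX releases
except ImportError:
    from jax.experimental import optimizers  # older releases, as documented here

# Illustrative nested parameters: a dict holding an array and a 1-tuple.
params = {"w": jnp.array([1.0, -2.0]), "b": (jnp.array(0.5),)}

init_fun, update_fun, get_params = optimizers.momentum(step_size=0.1, mass=0.9)
opt_state = init_fun(params)

# The invariant params == get_params(init_fun(params)), checked leaf by leaf.
recovered = get_params(opt_state)
same = jax.tree_util.tree_all(
    jax.tree_util.tree_map(lambda a, b: bool(jnp.array_equal(a, b)), params, recovered)
)
print(same)  # True

# The state is a different pytree: momentum also stores a velocity per leaf.
n_param_leaves = len(jax.tree_util.tree_leaves(params))
n_state_leaves = len(jax.tree_util.tree_leaves(opt_state))
print(n_param_leaves, n_state_leaves)  # 2 4
```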

Example Usage:

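import jax
import jax.numpy as jnp
from jax.experimental import optimizers

learning_rate, num_steps = 0.1, 100
params = {"w": jnp.ones(3), "b": jnp.zeros(())}  # illustrative parameters

def loss_fn(params):  # illustrative quadratic loss
  return jnp.sum(params["w"] ** 2) + params["b"] ** 2

opt_init, opt_update, get_params = optimizers.sgd(learning_rate)
opt_state = opt_init(params)

def step(i, opt_state):
  value, grads = jax.value_and_grad(loss_fn)(get_params(opt_state))
  opt_state = opt_update(i, grads, opt_state)  # i indexes the step-size schedule
  return value, opt_state

for i in range(num_steps):
  value, opt_state = step(i, opt_state)
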
class jax.experimental.optimizers.JoinPoint(subtree)

Bases: object

Marks the boundary between two joined (nested) pytrees.

class jax.experimental.optimizers.Optimizer(init_fn, update_fn, params_fn)

Bases: tuple

property init_fn

Alias for field number 0

property params_fn

Alias for field number 2

property update_fn

Alias for field number 1
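
Because Optimizer is a named tuple, the triple returned by a constructor such as adam can be unpacked positionally or accessed through these field names. A short sketch (the step size is arbitrary; the import falls back from the newer jax.example_libraries location to the jax.experimental path documented here):

```python
try:
    from jax.example_libraries import optimizers  # newer JAX releases
except ImportError:
    from jax.experimental import optimizers  # older releases, as documented here

opt = optimizers.adam(1e-3)

# An Optimizer is a named tuple: it unpacks positionally...
init_fn, update_fn, params_fn = opt

# ...and the same functions are reachable through the field names.
print(isinstance(opt, optimizers.Optimizer))  # True
print(opt.init_fn is init_fn, opt.update_fn is update_fn, opt.params_fn is params_fn)  # True True True
print(opt[1] is opt.update_fn)  # True
```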

class jax.experimental.optimizers.OptimizerState(packed_state, tree_def, subtree_defs)

Bases: tuple

property packed_state

Alias for field number 0

property subtree_defs

Alias for field number 2

property tree_def

Alias for field number 1
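
The three OptimizerState fields can be inspected directly. A sketch, assuming Adam stores (parameters, first moment, second moment) for every leaf, so each flattened per-leaf state holds three arrays; the parameter shapes are illustrative:

```python
import jax
import jax.numpy as jnp

try:
    from jax.example_libraries import optimizers  # newer JAX releases
except ImportError:
    from jax.experimental import optimizers  # older releases, as documented here

params = {"w": jnp.zeros((2, 3)), "b": jnp.zeros(3)}  # illustrative shapes
opt_init, opt_update, get_params = optimizers.adam(1e-3)
opt_state = opt_init(params)

# tree_def: the structure of the parameter pytree.
print(opt_state.tree_def == jax.tree_util.tree_structure(params))  # True

# packed_state: one flattened state per parameter leaf; for Adam, [x, m, v].
print([len(leaf_state) for leaf_state in opt_state.packed_state])  # [3, 3]

# subtree_defs: one structure per leaf, used to rebuild each per-leaf state.
print(len(opt_state.subtree_defs))  # 2
```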

jax.experimental.optimizers.adagrad(step_size, momentum=0.9)

Construct optimizer triple for Adagrad.

Adaptive Subgradient Methods for Online Learning and Stochastic Optimization: http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf

Parameters
  • step_size – positive scalar, or a callable representing a step size schedule that maps the iteration index to a positive scalar.

  • momentum – optional, a positive scalar value for momentum (default 0.9).

Returns

An (init_fun, update_fun, get_params) triple.
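
Adagrad accepts either a scalar step size or a schedule. A minimal sketch on an illustrative quadratic loss, assuming a hypothetical 1/sqrt(i + 1) schedule; it only checks that the loss goes down over a few steps:

```python
import jax
import jax.numpy as jnp

try:
    from jax.example_libraries import optimizers  # newer JAX releases
except ImportError:
    from jax.experimental import optimizers  # older releases, as documented here

def loss_fn(x):  # illustrative quadratic loss
    return jnp.sum(x ** 2)

def schedule(i):  # hypothetical decaying step size
    return 0.1 / jnp.sqrt(i + 1.0)

opt_init, opt_update, get_params = optimizers.adagrad(schedule, momentum=0.9)
opt_state = opt_init(jnp.array([1.0, -2.0]))

initial_loss = float(loss_fn(get_params(opt_state)))
for i in range(10):
    grads = jax.grad(loss_fn)(get_params(opt_state))
    opt_state = opt_update(i, grads, opt_state)
final_loss = float(loss_fn(get_params(opt_state)))
print(final_loss < initial_loss)  # True
```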

jax.experimental.optimizers.adam(step_size, b1=0.9, b2=0.999, eps=1e-08)

Construct optimizer triple for Adam.

Parameters
  • step_size – positive scalar, or a callable representing a step size schedule that maps the iteration index to a positive scalar.

  • b1 – optional, a positive scalar value for beta_1, the exponential decay rate for the first moment estimates (default 0.9).

  • b2 – optional, a positive scalar value for beta_2, the exponential decay rate for the second moment estimates (default 0.999).

  • eps – optional, a positive scalar value for epsilon, a small constant for numerical stability (default 1e-8).

Returns

An (init_fun, update_fun, get_params) triple.
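
On the first step, the bias-corrected Adam update reduces to roughly step_size * sign(g) for each entry, largely independent of the gradient's scale (as long as |g| is much larger than eps). A sketch that checks this, assuming the standard bias-corrected form of Adam; the gradient values are illustrative:

```python
import jax.numpy as jnp

try:
    from jax.example_libraries import optimizers  # newer JAX releases
except ImportError:
    from jax.experimental import optimizers  # older releases, as documented here

opt_init, opt_update, get_params = optimizers.adam(step_size=0.1)
opt_state = opt_init(jnp.array([1.0, -2.0]))

# Illustrative gradients whose magnitudes differ by six orders of magnitude.
grads = jnp.array([1000.0, -0.001])
opt_state = opt_update(0, grads, opt_state)

# After bias correction, each entry moves by about step_size against its gradient sign.
new_params = get_params(opt_state)
print(new_params)  # approximately [0.9, -1.9]
```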

jax.experimental.optimizers.adamax(step_size, b1=0.9, b2=0.999, eps=1e-08)

Construct optimizer triple for AdaMax (a variant of Adam based on infinity norm).

Parameters
  • step_size – positive scalar, or a callable representing a step size schedule that maps the iteration index to a positive scalar.

  • b1 – optional, a positive scalar value for beta_1, the exponential decay rate for the first moment estimates (default 0.9).

  • b2 – optional, a positive scalar value for beta_2, the exponential decay rate for the second moment estimates (default 0.999).

  • eps – optional, a positive scalar value for epsilon, a small constant for numerical stability (default 1e-8).

Returns

An (init_fun, update_fun, get_params) triple.
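
A typical training loop compiles the update with jax.jit; OptimizerState is a pytree, so it can be passed through jit directly. A sketch using AdaMax on an illustrative quadratic loss (the step size and step count are arbitrary choices):

```python
import jax
import jax.numpy as jnp

try:
    from jax.example_libraries import optimizers  # newer JAX releases
except ImportError:
    from jax.experimental import optimizers  # older releases, as documented here

def loss_fn(x):  # illustrative quadratic loss
    return jnp.sum(x ** 2)

opt_init, opt_update, get_params = optimizers.adamax(step_size=0.05)

@jax.jit
def train_step(i, opt_state):
    # The step index i is traced, so one compiled function serves every step.
    grads = jax.grad(loss_fn)(get_params(opt_state))
    return opt_update(i, grads, opt_state)

opt_state = opt_init(jnp.array([1.0, -2.0]))
initial_loss = float(loss_fn(get_params(opt_state)))
for i in range(5):
    opt_state = train_step(i, opt_state)
final_loss = float(loss_fn(get_params(opt_state)))
print(final_loss < initial_loss)  # True
```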

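jax.experimental.optimizers.clip_grads(grad_tree, max_norm)

Clip gradients stored as a pytree of arrays to maximum norm max_norm: if the global L2 norm over all leaves exceeds max_norm, every leaf is scaled by the same factor max_norm / norm; otherwise the gradients are returned unchanged.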
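
For example, a gradient pytree whose global L2 norm is 5 is rescaled to norm 1 when max_norm=1.0, while a larger max_norm leaves it unchanged. A sketch with illustrative values:

```python
import jax.numpy as jnp

try:
    from jax.example_libraries import optimizers  # newer JAX releases
except ImportError:
    from jax.experimental import optimizers  # older releases, as documented here

grads = {"w": jnp.array([3.0, 0.0]), "b": jnp.array(4.0)}  # global norm 5

# Norm exceeds max_norm: all leaves are scaled by 1 / 5.
clipped = optimizers.clip_grads(grads, max_norm=1.0)
print(optimizers.l2_norm(clipped))  # approximately 1.0
print(clipped["w"], clipped["b"])  # approximately [0.6, 0.0] and 0.8

# Norm already within max_norm: gradients pass through unchanged.
small = optimizers.clip_grads(grads, max_norm=10.0)
```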

jax.experimental.optimizers.constant(step_size)

Return a schedule that maps every iteration index to step_size.

Return type

Callable[[int], float]
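
A constant schedule returns the same step size at every iteration, so passing it to an optimizer behaves like passing the bare scalar. A sketch comparing both with sgd (values illustrative):

```python
import jax.numpy as jnp

try:
    from jax.example_libraries import optimizers  # newer JAX releases
except ImportError:
    from jax.experimental import optimizers  # older releases, as documented here

schedule = optimizers.constant(0.1)
print(schedule(0), schedule(10_000))  # 0.1 0.1

# Passing the schedule or the bare scalar to sgd yields identical updates.
x0 = jnp.array([1.0, -2.0])
grads = jnp.array([0.5, 0.5])
results = []
for step_size in (schedule, 0.1):
    opt_init, opt_update, get_params = optimizers.sgd(step_size)
    results.append(get_params(opt_update(0, grads, opt_init(x0))))
print(bool(jnp.array_equal(results[0], results[1])))  # True
```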

jax.experimental.optimizers.exponential_decay(step_size, decay_steps, decay_rate)
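
A sketch of the decay this schedule is assumed to produce, step_size * decay_rate ** (i / decay_steps): with decay_rate=0.5 the step size halves every decay_steps iterations and decays smoothly in between. The argument values are illustrative:

```python
try:
    from jax.example_libraries import optimizers  # newer JAX releases
except ImportError:
    from jax.experimental import optimizers  # older releases, as documented here

schedule = optimizers.exponential_decay(step_size=1.0, decay_steps=100, decay_rate=0.5)

# Halves every 100 steps, and decays continuously (not in stairs) in between.
for i in (0, 50, 100, 200):
    print(i, float(schedule(i)))  # approximately 1.0, 0.7071, 0.5, 0.25
```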

jax.experimental.optimizers.inverse_time_decay(step_size, decay_steps, decay_rate, staircase=False)
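
A sketch contrasting the two modes, assuming the schedule computes step_size / (1 + decay_rate * i / decay_steps), with i / decay_steps rounded down when staircase=True. The argument values are illustrative:

```python
try:
    from jax.example_libraries import optimizers  # newer JAX releases
except ImportError:
    from jax.experimental import optimizers  # older releases, as documented here

smooth = optimizers.inverse_time_decay(1.0, decay_steps=10, decay_rate=1.0)
stairs = optimizers.inverse_time_decay(1.0, decay_steps=10, decay_rate=1.0, staircase=True)

for i in (0, 5, 10, 15):
    # smooth: 1 / (1 + i / 10); staircase: 1 / (1 + floor(i / 10)).
    print(i, float(smooth(i)), float(stairs(i)))
```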

jax.experimental.optimizers.l2_norm(tree)

Compute the L2 norm of a pytree of arrays, treating all leaves together as one flattened vector. Useful for weight decay.
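
For example, l2_norm combines every leaf, whatever its shape, and squaring it gives a simple weight-decay penalty. The penalty helper and the weight_decay value below are hypothetical:

```python
import jax.numpy as jnp

try:
    from jax.example_libraries import optimizers  # newer JAX releases
except ImportError:
    from jax.experimental import optimizers  # older releases, as documented here

tree = {"a": jnp.array([3.0, 0.0]), "b": [jnp.array([[4.0]])]}
print(optimizers.l2_norm(tree))  # approximately 5.0, i.e. sqrt(3^2 + 4^2)

weight_decay = 1e-4  # hypothetical coefficient

def penalty(params):  # hypothetical L2 penalty to add to a data loss
    return 0.5 * weight_decay * optimizers.l2_norm(params) ** 2
```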

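jax.experimental.optimizers.make_schedule(scalar_or_schedule)

Parameters

scalar_or_schedule (Union[float, Callable[[int], float]]) – a scalar step size, which is wrapped in a constant schedule, or a callable schedule, which is returned as is.

Return type

Callable[[int], float]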
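
make_schedule is how a step_size argument that may be either a scalar or a callable can be normalized into a schedule. A short sketch; the decay function is a hypothetical example:

```python
try:
    from jax.example_libraries import optimizers  # newer JAX releases
except ImportError:
    from jax.experimental import optimizers  # older releases, as documented here

# A scalar becomes a constant schedule.
fixed = optimizers.make_schedule(0.01)
print(fixed(0), fixed(500))  # 0.01 0.01

# A callable is passed through unchanged.
def decay(i):  # hypothetical schedule
    return 0.01 / (i + 1)

print(optimizers.make_schedule(decay) is decay)  # True
```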

jax.experimental.optimizers.momentum(step_size, mass)

Construct optimizer triple for SGD with momentum.

Parameters
  • step_size (Callable[[int], float]) – positive scalar, or a callable representing a step size schedule that maps the iteration index to a positive scalar.

  • mass (float) – positive scalar representing the momentum coefficient.

Returns

An (init_fun, update_fun, get_params) triple.
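
A worked two-step example, assuming the heavy-ball form v = mass * v + g followed by x = x - step_size * v, with a constant gradient fed directly to the update function instead of computed from a loss:

```python
import jax.numpy as jnp

try:
    from jax.example_libraries import optimizers  # newer JAX releases
except ImportError:
    from jax.experimental import optimizers  # older releases, as documented here

opt_init, opt_update, get_params = optimizers.momentum(step_size=0.1, mass=0.9)
opt_state = opt_init(jnp.array(0.0))

g = jnp.array(1.0)  # constant gradient, fed in directly for illustration
opt_state = opt_update(0, g, opt_state)  # v = 1.0, x = 0.0 - 0.1 * 1.0 = -0.1
opt_state = opt_update(1, g, opt_state)  # v = 0.9 * 1.0 + 1.0 = 1.9, x = -0.1 - 0.19 = -0.29
x = get_params(opt_state)
print(float(x))  # approximately -0.29
```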

jax.experimental.optimizers.nesterov(step_size, mass)

Construct optimizer triple for SGD with Nesterov momentum.

Parameters
  • step_size (Callable[[int], float]) – positive scalar, or a callable representing a step size schedule that maps the iteration index to a positive scalar.

  • mass (float) – positive scalar representing the momentum coefficient.

Returns

An (init_fun, update_fun, get_params) triple.
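
A worked two-step example, assuming the look-ahead form v = mass * v + g followed by x = x - step_size * (mass * v + g), with a constant gradient fed directly to the update function:

```python
import jax.numpy as jnp

try:
    from jax.example_libraries import optimizers  # newer JAX releases
except ImportError:
    from jax.experimental import optimizers  # older releases, as documented here

opt_init, opt_update, get_params = optimizers.nesterov(step_size=0.1, mass=0.9)
opt_state = opt_init(jnp.array(0.0))

g = jnp.array(1.0)  # constant gradient, fed in directly for illustration
opt_state = opt_update(0, g, opt_state)  # v = 1.0, x = 0.0 - 0.1 * (0.9 + 1.0) = -0.19
opt_state = opt_update(1, g, opt_state)  # v = 1.9, x = -0.19 - 0.1 * (1.71 + 1.0) = -0.461
x = get_params(opt_state)
print(float(x))  # approximately -0.461
```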

jax.experimental.optimizers.optimizer(opt_maker)

Decorator to make an optimizer defined for arrays generalize to containers.

With this decorator, you can write init, update, and get_params functions that each operate only on single arrays, and convert them to corresponding functions that operate on pytrees of parameters. See the optimizers defined in optimizers.py for examples.

Parameters

opt_maker (Callable[…, Tuple[Callable[[Any], Any], Callable[[int, Any, Any], Any], Callable[[Any], Any]]]) –

a function that returns an (init_fun, update_fun, get_params) triple of functions that might only work with ndarrays, as per

init_fun :: ndarray -> OptStatePytree ndarray
update_fun :: (int, ndarray, OptStatePytree ndarray) -> OptStatePytree ndarray
get_params :: OptStatePytree ndarray -> ndarray

Return type

Callable[…, Optimizer]

Returns

An (init_fun, update_fun, get_params) triple of functions that work on arbitrary pytrees, as per

init_fun :: ParameterPytree ndarray -> OptimizerState
update_fun :: (int, ParameterPytree ndarray, OptimizerState) -> OptimizerState
get_params :: OptimizerState -> ParameterPytree ndarray

The OptimizerState pytree type used by the returned functions is isomorphic to ParameterPytree (OptStatePytree ndarray), but may store the state instead as e.g. a partially-flattened data structure for performance.
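
For example, a sign-based SGD variant can be written for a single array and lifted to pytrees with the decorator. sign_sgd is a hypothetical optimizer used only for illustration; it is not part of this module:

```python
import jax.numpy as jnp

try:
    from jax.example_libraries import optimizers  # newer JAX releases
except ImportError:
    from jax.experimental import optimizers  # older releases, as documented here

@optimizers.optimizer
def sign_sgd(step_size):  # hypothetical optimizer, not part of this module
    step_size = optimizers.make_schedule(step_size)

    def init(x0):  # the per-array state is just the array itself
        return x0

    def update(i, g, x):  # move each entry by step_size against the gradient's sign
        return x - step_size(i) * jnp.sign(g)

    def get_params(x):
        return x

    return init, update, get_params

# The decorated constructor now accepts whole pytrees.
opt_init, opt_update, get_params = sign_sgd(0.1)
params = {"w": jnp.array([1.0, -2.0]), "b": jnp.array(0.5)}
grads = {"w": jnp.array([3.0, -0.5]), "b": jnp.array(0.0)}
new_params = get_params(opt_update(0, grads, opt_init(params)))
print(new_params["w"], new_params["b"])  # approximately [0.9, -1.9] and 0.5
```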

jax.experimental.optimizers.pack_optimizer_state(marked_pytree)

Converts a marked pytree to an OptimizerState.

The inverse of unpack_optimizer_state. Converts a marked pytree with the leaves of the outer pytree represented as JoinPoints back into an OptimizerState. This function is intended to be useful when deserializing optimizer states.

Parameters

marked_pytree – A pytree containing JoinPoint leaves that hold more pytrees.

Returns

An equivalent OptimizerState to the input argument.
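
A round-trip sketch: unpacking an OptimizerState (for instance before handing it to a serializer of your choice) and packing it again recovers the same structure and values. The momentum optimizer, parameters and gradients are illustrative:

```python
import jax
import jax.numpy as jnp

try:
    from jax.example_libraries import optimizers  # newer JAX releases
except ImportError:
    from jax.experimental import optimizers  # older releases, as documented here

params = {"w": jnp.array([1.0, -2.0]), "b": jnp.array(0.5)}
grads = {"w": jnp.array([0.3, 0.3]), "b": jnp.array(1.0)}
opt_init, opt_update, get_params = optimizers.momentum(0.1, mass=0.9)
opt_state = opt_update(0, grads, opt_init(params))  # non-trivial velocities

# Unpack (e.g. before serializing), then pack again (e.g. after deserializing).
marked = optimizers.unpack_optimizer_state(opt_state)
restored = optimizers.pack_optimizer_state(marked)

print(restored.tree_def == opt_state.tree_def)  # True
original_leaves = jax.tree_util.tree_leaves(opt_state)
restored_leaves = jax.tree_util.tree_leaves(restored)
same = len(original_leaves) == len(restored_leaves) and all(
    bool(jnp.array_equal(a, b)) for a, b in zip(original_leaves, restored_leaves)
)
print(same)  # True
```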

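jax.experimental.optimizers.piecewise_constant(boundaries, values)
Parameters
  • boundaries (Any) – a sequence of iteration indices at which the step size changes.

  • values (Any) – a sequence of step sizes, one element longer than boundaries; values[0] applies before the first boundary and values[-1] after the last.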
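
For example, with boundaries [100, 200] and values [1.0, 0.5, 0.1], the step size is 1.0 early on, 0.5 between the two boundaries, and 0.1 after the second. A sketch; it deliberately avoids querying exactly at a boundary, where the inclusive/exclusive convention is not spelled out above:

```python
try:
    from jax.example_libraries import optimizers  # newer JAX releases
except ImportError:
    from jax.experimental import optimizers  # older releases, as documented here

schedule = optimizers.piecewise_constant(boundaries=[100, 200], values=[1.0, 0.5, 0.1])

for i in (0, 150, 250):
    print(i, float(schedule(i)))  # approximately 1.0, 0.5, 0.1
```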

jax.experimental.optimizers.polynomial_decay(step_size, decay_steps, final_step_size, power=1.0)
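
A sketch, assuming the default power=1.0 gives a linear ramp from step_size to final_step_size over decay_steps iterations, after which the step size stays at final_step_size. The argument values are illustrative:

```python
try:
    from jax.example_libraries import optimizers  # newer JAX releases
except ImportError:
    from jax.experimental import optimizers  # older releases, as documented here

schedule = optimizers.polynomial_decay(step_size=1.0, decay_steps=10, final_step_size=0.1)

# Linear from 1.0 down to 0.1 over 10 steps, then held at 0.1.
for i in (0, 5, 10, 20):
    print(i, float(schedule(i)))  # approximately 1.0, 0.55, 0.1, 0.1
```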

jax.experimental.optimizers.rmsprop(step_size, gamma=0.9, eps=1e-08)

Construct optimizer triple for RMSProp.

Parameters
  • step_size – positive scalar, or a callable representing a step size schedule that maps the iteration index to a positive scalar.

  • gamma – decay rate of the running average of squared gradients (default 0.9).

  • eps – small constant added for numerical stability (default 1e-8).

Returns

An (init_fun, update_fun, get_params) triple.
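
A sketch of a training loop with RMSProp on an illustrative pytree of parameters and quadratic loss. With the small step size chosen here, every step moves each entry toward the minimum without overshooting, so the recorded losses decrease monotonically:

```python
import jax
import jax.numpy as jnp

try:
    from jax.example_libraries import optimizers  # newer JAX releases
except ImportError:
    from jax.experimental import optimizers  # older releases, as documented here

def loss_fn(params):  # illustrative quadratic loss over a dict of arrays
    return jnp.sum(params["w"] ** 2) + params["b"] ** 2

opt_init, opt_update, get_params = optimizers.rmsprop(step_size=0.01, gamma=0.9)
opt_state = opt_init({"w": jnp.array([1.0, -2.0]), "b": jnp.array(0.5)})

losses = []
for i in range(10):
    loss, grads = jax.value_and_grad(loss_fn)(get_params(opt_state))
    losses.append(float(loss))
    opt_state = opt_update(i, grads, opt_state)

decreasing = all(a > b for a, b in zip(losses, losses[1:]))
print(decreasing)  # True
```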

jax.experimental.optimizers.rmsprop_momentum(step_size, gamma=0.9, eps=1e-08, momentum=0.9)

Construct optimizer triple for RMSProp with momentum.

This optimizer is separate from the rmsprop optimizer because it keeps track of additional state: a momentum buffer for each parameter.

Parameters
  • step_size – positive scalar, or a callable representing a step size schedule that maps the iteration index to a positive scalar.

  • gamma – decay rate of the running average of squared gradients (default 0.9).

  • eps – small constant added for numerical stability (default 1e-8).

  • momentum – momentum coefficient (default 0.9).

Returns

An (init_fun, update_fun, get_params) triple.
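
The extra state is visible in the optimizer state itself. A sketch, assuming rmsprop keeps (parameters, running mean of squared gradients) per leaf and rmsprop_momentum additionally keeps a momentum buffer; the parameter values are illustrative:

```python
import jax.numpy as jnp

try:
    from jax.example_libraries import optimizers  # newer JAX releases
except ImportError:
    from jax.experimental import optimizers  # older releases, as documented here

x0 = jnp.array([1.0, -2.0])
plain = optimizers.rmsprop(0.01)
with_mom = optimizers.rmsprop_momentum(0.01, momentum=0.9)

# Count the arrays stored for the single parameter leaf.
n_plain = len(plain.init_fn(x0).packed_state[0])
n_mom = len(with_mom.init_fn(x0).packed_state[0])
print(n_plain, n_mom)  # 2 3
```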

jax.experimental.optimizers.sgd(step_size)

Construct optimizer triple for stochastic gradient descent.

Parameters

step_size – positive scalar, or a callable representing a step size schedule that maps the iteration index to a positive scalar.

Returns

An (init_fun, update_fun, get_params) triple.
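
A worked example with a step-size schedule instead of a scalar. The 1/(i + 1) decay is a hypothetical choice, and the gradient is fed in directly rather than computed from a loss:

```python
import jax.numpy as jnp

try:
    from jax.example_libraries import optimizers  # newer JAX releases
except ImportError:
    from jax.experimental import optimizers  # older releases, as documented here

def schedule(i):  # hypothetical decay: 0.1, 0.05, 0.0333, ...
    return 0.1 / (i + 1)

opt_init, opt_update, get_params = optimizers.sgd(schedule)
opt_state = opt_init(jnp.array(1.0))

g = jnp.array(2.0)  # gradient fed in directly for illustration
opt_state = opt_update(0, g, opt_state)  # x = 1.0 - 0.1 * 2.0 = 0.8
opt_state = opt_update(1, g, opt_state)  # x = 0.8 - 0.05 * 2.0 = 0.7
x = get_params(opt_state)
print(float(x))  # approximately 0.7
```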

jax.experimental.optimizers.sm3(step_size, momentum=0.9)

Construct optimizer triple for SM3.

Memory-Efficient Adaptive Optimization for Large-Scale Learning. https://arxiv.org/abs/1901.11150

Parameters
  • step_size – positive scalar, or a callable representing a step size schedule that maps the iteration index to a positive scalar.

  • momentum – optional, a positive scalar value for momentum (default 0.9).

Returns

An (init_fun, update_fun, get_params) triple.
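
A sketch applying SM3 to a matrix-shaped parameter with an illustrative quadratic loss. As described in the paper, its memory saving comes from keeping accumulators over rows and columns rather than one per entry; this sketch only checks that the loss decreases over a few steps:

```python
import jax
import jax.numpy as jnp

try:
    from jax.example_libraries import optimizers  # newer JAX releases
except ImportError:
    from jax.experimental import optimizers  # older releases, as documented here

def loss_fn(w):  # illustrative quadratic loss on a matrix parameter
    return jnp.sum(w ** 2)

opt_init, opt_update, get_params = optimizers.sm3(step_size=0.1, momentum=0.9)
opt_state = opt_init(jnp.ones((4, 3)))

initial_loss = float(loss_fn(get_params(opt_state)))
for i in range(5):
    grads = jax.grad(loss_fn)(get_params(opt_state))
    opt_state = opt_update(i, grads, opt_state)
final_loss = float(loss_fn(get_params(opt_state)))
print(final_loss < initial_loss)  # True
```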

jax.experimental.optimizers.unpack_optimizer_state(opt_state)

Converts an OptimizerState to a marked pytree.

Converts an OptimizerState to a marked pytree with the leaves of the outer pytree represented as JoinPoints to avoid losing information. This function is intended to be useful when serializing optimizer states.

Parameters

opt_state – An OptimizerState

Returns

A pytree with JoinPoint leaves that contain a second level of pytrees.
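
The marked pytree mirrors the parameter pytree, with a JoinPoint at each leaf holding that leaf's optimizer state. A sketch, assuming momentum stores (parameters, velocity) per leaf; the parameter values are illustrative:

```python
import jax.numpy as jnp

try:
    from jax.example_libraries import optimizers  # newer JAX releases
except ImportError:
    from jax.experimental import optimizers  # older releases, as documented here

params = {"w": jnp.array([1.0, -2.0]), "b": jnp.array(0.5)}
opt_init, opt_update, get_params = optimizers.momentum(0.1, mass=0.9)
marked = optimizers.unpack_optimizer_state(opt_init(params))

# The outer pytree has the same keys as params; each leaf is a JoinPoint.
print(sorted(marked.keys()))  # ['b', 'w']
print(isinstance(marked["w"], optimizers.JoinPoint))  # True

# The inner pytree is that leaf's state: here a (params, velocity) pair.
x, velocity = marked["w"].subtree
print(x, velocity)  # approximately [1.0, -2.0] and [0.0, 0.0]
```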