jax.experimental.optix module

A composable gradient processing and optimization library for JAX.

The optix module implements a number of composable gradient transformations, typically used in the context of optimizing neural nets.

Each transformation defines:

  • init_fn: Params -> OptState, to initialize (possibly empty) sets of statistics (aka state)

  • update_fn: (Updates, OptState, Optional[Params]) -> (Updates, OptState), to transform a parameter update or gradient and update the state (see the sketch below)
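
For illustration, a minimal custom transformation can be written directly against this interface. The sketch below is not part of the module; the name scale_by_constant and its factor argument are purely illustrative:

  import jax
  from jax.experimental import optix

  # A toy stateless transformation; scale_by_constant is illustrative only.
  def scale_by_constant(factor):

    def init_fn(params):
      return ()  # no statistics to track

    def update_fn(updates, state, params=None):
      # Multiply every leaf of the update tree by the constant factor.
      return jax.tree_map(lambda g: factor * g, updates), state

    return optix.GradientTransformation(init_fn, update_fn)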

An (optional) chain utility can be used to build custom optimizers by chaining arbitrary sequences of transformations. For any sequence of transformations chain returns a single init_fn and update_fn.
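
For example, a hedged sketch of a gradient-clipped, Adam-style optimizer composed with chain (the learning_rate value is illustrative):

  from jax.experimental import optix

  # Clip by global norm, rescale with Adam statistics, then flip the sign and
  # scale by the learning rate (apply_updates then adds updates to params).
  learning_rate = 1e-3  # illustrative value
  opt = optix.chain(
      optix.clip_by_global_norm(1.0),
      optix.scale_by_adam(),
      optix.scale(-learning_rate))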

An (optional) apply_updates function can be used to eventually apply the transformed gradients to the set of parameters of interest.

Separating gradient transformations from the parameter update makes it possible to flexibly chain a sequence of transformations of the same gradients, as well as to combine multiple updates to the same parameters (e.g. in multi-task settings where the different tasks may benefit from different sets of gradient transformations).

Many popular optimizers can be implemented in optix as one-liners and, for convenience, we provide aliases for the most common ones.
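
For instance, the sgd alias with momentum corresponds, up to implementation details, to chaining trace and scale; the wrapper name my_sgd below is illustrative:

  from jax.experimental import optix

  # A momentum SGD built from primitives; roughly what optix.sgd provides.
  def my_sgd(learning_rate, momentum=0.0, nesterov=False):
    return optix.chain(
        optix.trace(decay=momentum, nesterov=nesterov),
        optix.scale(-learning_rate))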

Example Usage:

  opt = optix.sgd(learning_rate)
  OptData = collections.namedtuple('OptData', 'step state params')
  opt_data = OptData(0, opt.init(params), params)

  def step(opt_data):
    step, state, params = opt_data
    value, grads = jax.value_and_grad(loss_fn)(params)
    updates, state = opt.update(grads, state, params)
    params = optix.apply_updates(params, updates)
    return value, OptData(step + 1, state, params)

  for _ in range(steps):
    value, opt_data = step(opt_data)

class jax.experimental.optix.AddNoiseState[source]

Bases: tuple

State for adding gradient noise. Contains a count for annealing.

property count

Alias for field number 0

property rng_key

Alias for field number 1

class jax.experimental.optix.ApplyEvery[source]

Bases: tuple

Contains a counter and a gradient accumulator.

property count

Alias for field number 0

property grad_acc

Alias for field number 1

class jax.experimental.optix.ClipByGlobalNormState[source]

Bases: tuple

The clip_by_global_norm transformation is stateless.

class jax.experimental.optix.ClipState[source]

Bases: tuple

The clip transformation is stateless.

class jax.experimental.optix.GradientTransformation[source]

Bases: tuple

Optix optimizers consist of a pair of functions: (initialiser, update).

property init

Alias for field number 0

property update

Alias for field number 1

jax.experimental.optix.InitUpdate

alias of jax.experimental.optix.GradientTransformation

class jax.experimental.optix.ScaleByAdamState[source]

Bases: tuple

State for the Adam algorithm.

property count

Alias for field number 0

property mu

Alias for field number 1

property nu

Alias for field number 2

class jax.experimental.optix.ScaleByRStdDevState[source]

Bases: tuple

State for centered exponential moving average of squares of updates.

property mu

Alias for field number 0

property nu

Alias for field number 1

class jax.experimental.optix.ScaleByRmsState[source]

Bases: tuple

State for exponential root mean-squared (RMS)-normalized updates.

property nu

Alias for field number 0

class jax.experimental.optix.ScaleByScheduleState[source]

Bases: tuple

Maintains count for scale scheduling.

property count

Alias for field number 0

class jax.experimental.optix.ScaleState[source]

Bases: tuple

The scale transformation is stateless.

class jax.experimental.optix.TraceState[source]

Bases: tuple

Holds an aggregation of past updates.

property trace

Alias for field number 0

jax.experimental.optix.adam(learning_rate, b1=0.9, b2=0.999, eps=1e-08)[source]
Parameters
  • learning_rate (float) – the step size used to scale the updates.

  • b1 (float) – decay rate for the exponentially weighted average of grads.

  • b2 (float) – decay rate for the exponentially weighted average of squared grads.

  • eps (float) – term added to the denominator to improve numerical stability.

Return type

GradientTransformation

jax.experimental.optix.add_noise(eta, gamma, seed)[source]

Add gradient noise.

References

[Neelakantan et al, 2014](https://arxiv.org/abs/1511.06807)

Parameters
  • eta (float) – base variance of the gaussian noise added to the gradient.

  • gamma (float) – decay exponent for annealing of the variance.

  • seed (int) – seed for random number generation.

Return type

GradientTransformation

Returns

An (init_fn, update_fn) tuple.

jax.experimental.optix.apply_every(k=1)[source]

Accumulate gradients and apply them every k steps.

Parameters

k (int) – apply the update every k steps, otherwise accumulate the gradients.

Return type

GradientTransformation

Returns

An (init_fn, update_fn) tuple.
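
A hedged usage sketch, composing apply_every with a plain scaling step so that parameters only change every 4th step (the learning_rate value is illustrative):

  from jax.experimental import optix

  # Accumulate gradients for 4 steps, then apply them in a single update.
  learning_rate = 1e-1  # illustrative value
  opt = optix.chain(
      optix.apply_every(k=4),
      optix.scale(-learning_rate))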

jax.experimental.optix.apply_updates(params, updates)[source]

Applies an update to the corresponding parameters.

This is an (optional) utility function that applies an update and returns the updated parameters to the caller. The update itself is typically the result of applying any number of chainable transformations.

Parameters
  • params (Any) – a tree of parameters.

  • updates (Any) – a tree of updates; the tree structure and the shape of the leaf nodes must match that of params.

Return type

Any

Returns

Updated parameters, with same structure and shape as params.
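
Conceptually, and as a sketch only (not the library's exact code), this amounts to an element-wise addition over the two trees:

  import jax

  # Roughly equivalent to optix.apply_updates, assuming params and updates
  # are pytrees with matching structure.
  new_params = jax.tree_multimap(lambda p, u: p + u, params, updates)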

jax.experimental.optix.chain(*args)[source]

Applies a list of chainable update transformations.

Given a sequence of chainable transforms, chain returns an init_fn that constructs a state by concatenating the states of the individual transforms, and returns an update_fn which chains the update transformations feeding the appropriate state to each.

Parameters

*args – a sequence of chainable (init_fn, update_fn) transformation tuples.

Return type

GradientTransformation

Returns

A single (init_fn, update_fn) tuple.

jax.experimental.optix.clip(max_delta)[source]

Clip updates element-wise, to be between -max_delta and +max_delta.

Parameters

max_delta – the maximum absolute value for each element in the update.

Return type

GradientTransformation

Returns

An (init_fn, update_fn) tuple.

jax.experimental.optix.clip_by_global_norm(max_norm)[source]

Clip updates using their global norm.

References

[Pascanu et al, 2012](https://arxiv.org/abs/1211.5063)

Parameters

max_norm – the maximum global norm for an update.

Return type

GradientTransformation

Returns

An (init_fn, update_fn) tuple.
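
A hedged sketch of the clipping rule: the global norm is computed over all leaves, and updates are rescaled only when that norm exceeds max_norm (the helper name clip_updates_by_global_norm is illustrative):

  import jax
  import jax.numpy as jnp
  from jax.experimental import optix

  # Sketch of global-norm clipping (not the library's exact code).
  def clip_updates_by_global_norm(updates, max_norm):
    g_norm = optix.global_norm(updates)
    # Leave updates unchanged if the norm is small enough, else rescale.
    return jax.tree_map(
        lambda u: jnp.where(g_norm < max_norm, u, u * (max_norm / g_norm)),
        updates)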

jax.experimental.optix.global_norm(updates)[source]

Compute the global L2 norm across a tree of updates.

Parameters

updates (Any) – a tree of updates (e.g. gradients).

Return type

Any

jax.experimental.optix.noisy_sgd(learning_rate, eta=0.01, gamma=0.55, seed=0)[source]
Parameters
  • learning_rate (float) – the step size used to scale the updates.

  • eta (float) – base variance of the gaussian noise added to the gradient.

  • gamma (float) – decay exponent for annealing of the variance.

  • seed (int) – seed for random number generation.

Return type

GradientTransformation

jax.experimental.optix.rmsprop(learning_rate, decay=0.9, eps=1e-08, centered=False)[source]
Parameters
  • learning_rate (float) – the step size used to scale the updates.

  • decay (float) – decay rate for the exponentially weighted average of squared grads.

  • eps (float) – term added to the denominator to improve numerical stability.

  • centered (bool) – whether to centre the moving average of squared grads (i.e. use scale_by_stddev rather than scale_by_rms).

Return type

GradientTransformation

jax.experimental.optix.scale(step_size)[source]

Scale updates by some fixed scalar step_size.

Parameters

step_size (float) – a scalar corresponding to a fixed scaling factor for updates.

Return type

GradientTransformation

Returns

An (init_fn, update_fn) tuple.

jax.experimental.optix.scale_by_adam(b1=0.9, b2=0.999, eps=1e-08, eps_root=0.0)[source]

Rescale updates according to the Adam algorithm.

References

[Kingma et al, 2014](https://arxiv.org/abs/1412.6980)

Parameters
  • b1 (float) – decay rate for the exponentially weighted average of grads.

  • b2 (float) – decay rate for the exponentially weighted average of squared grads.

  • eps (float) – term added to the denominator to improve numerical stability.

  • eps_root (float) – term added to the denominator inside the square-root to improve numerical stability when backpropagating gradients through the rescaling.

Return type

GradientTransformation

Returns

An (init_fn, update_fn) tuple.
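
As a hedged, per-leaf sketch of the rescaling (bias-corrected first and second moments, following Kingma et al, 2014; the helper name adam_update is illustrative):

  import jax.numpy as jnp

  # Per-leaf Adam rescaling (sketch; count is the step number, starting at 1).
  def adam_update(g, mu, nu, count, b1=0.9, b2=0.999, eps=1e-8, eps_root=0.0):
    mu = b1 * mu + (1 - b1) * g            # first moment
    nu = b2 * nu + (1 - b2) * g ** 2       # second moment
    mu_hat = mu / (1 - b1 ** count)        # bias correction
    nu_hat = nu / (1 - b2 ** count)
    return mu_hat / (jnp.sqrt(nu_hat + eps_root) + eps), mu, nu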

jax.experimental.optix.scale_by_rms(decay=0.9, eps=1e-08)[source]

Rescale updates by the root of the exponential moving average of squared updates.

References

[Hinton](www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)

Parameters
  • decay (float) – decay rate for the exponentially weighted average of squared grads.

  • eps (float) – term added to the denominator to improve numerical stability.

Returns

An (init_fn, update_fn) tuple.
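
A hedged, per-leaf sketch of the rescaling (the helper name rms_update is illustrative):

  import jax.numpy as jnp

  # Per-leaf RMS rescaling (sketch).
  def rms_update(g, nu, decay=0.9, eps=1e-8):
    nu = decay * nu + (1 - decay) * g ** 2
    return g / jnp.sqrt(nu + eps), nu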

jax.experimental.optix.scale_by_schedule(step_size_fn)[source]

Scale updates using a custom schedule for the step_size.

Parameters

step_size_fn (Callable[[ndarray], ndarray]) – a function that takes an update count as input and proposes the step_size to multiply the updates by.

Returns

An (init_fn, update_fn) tuple.
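
A hedged usage sketch with an illustrative inverse-decay schedule:

  from jax.experimental import optix

  # Scale updates by 1/(1 + count), then by the (negative) learning rate.
  learning_rate = 1e-2  # illustrative value
  schedule_fn = lambda count: 1. / (1. + count)
  opt = optix.chain(
      optix.scale_by_schedule(schedule_fn),
      optix.scale(-learning_rate))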

jax.experimental.optix.scale_by_stddev(decay=0.9, eps=1e-08)[source]

Rescale updates by the root of the centered exponential moving average of squared updates.

References

[Hinton](www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)

Parameters
  • decay (float) – decay rate for the exponentially weighted average of squared grads.

  • eps (float) – term added to the denominator to improve numerical stability.

Return type

GradientTransformation

Returns

An (init_fn, update_fn) tuple.

jax.experimental.optix.sgd(learning_rate, momentum=0.0, nesterov=False)[source]
Parameters
  • learning_rate (float) – the step size used to scale the updates.

  • momentum (float) – the decay rate used by the trace of past updates.

  • nesterov (bool) – whether to use Nesterov momentum.

Return type

GradientTransformation

jax.experimental.optix.trace(decay, nesterov)[source]

Compute a trace of past updates.

Parameters
  • decay (float) – the decay rate for the tracing of past updates.

  • nesterov (bool) – whether to use Nesterov momentum.

Return type

GradientTransformation

Returns

An (init_fn, update_fn) tuple.
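
A hedged, per-leaf sketch of the trace update, with the Nesterov variant applying the decayed new trace on top of the raw gradient (the helper name trace_update is illustrative):

  # Per-leaf momentum trace (sketch).
  def trace_update(g, trace, decay, nesterov=False):
    new_trace = g + decay * trace
    update = g + decay * new_trace if nesterov else new_trace
    return update, new_trace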