jax.experimental.stax module
Stax is a small but flexible neural network specification library, written from scratch.
For an example of its use, see examples/resnet50.py.
-
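jax.experimental.stax.AvgPool(window_shape, strides=None, padding='VALID', spec=None)
Layer construction function for an average-pooling layer. If strides is None, a stride of 1 is used in every spatial dimension. spec is a layout string such as 'NCHW'; None means batch-first, channel-last inputs (e.g. NHWC).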
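A minimal usage sketch, assuming JAX is installed. The import tries jax.example_libraries.stax (where newer JAX releases moved this module) before the jax.experimental.stax path documented here. It pools a single 4x4 image with non-overlapping 2x2 windows, so each output is the mean of four inputs.

```python
import jax
import jax.numpy as jnp

try:
    from jax.example_libraries import stax  # newer JAX releases
except ImportError:
    from jax.experimental import stax  # older releases, as documented here

# 2x2 average pooling with stride 2 over NHWC inputs (spec=None).
init_fun, apply_fun = stax.AvgPool((2, 2), strides=(2, 2))

# Pooling layers have no parameters, so params is an empty tuple.
out_shape, params = init_fun(jax.random.PRNGKey(0), (1, 4, 4, 1))

x = jnp.arange(16.0).reshape(1, 4, 4, 1)  # one 4x4 single-channel image
y = apply_fun(params, x)

print(tuple(int(d) for d in out_shape))  # (1, 2, 2, 1)
print(y[0, :, :, 0])  # mean of each non-overlapping 2x2 window
```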
-
jax.experimental.stax.BatchNorm(axis=(0, 1, 2), epsilon=1e-05, center=True, scale=True, beta_init=<function zeros>, gamma_init=<function ones>)
Layer construction function for a batch normalization layer.
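A minimal sketch, assuming JAX is installed (the import tries jax.example_libraries.stax before the path documented here). The default axis=(0, 1, 2) suits NHWC images; for a (batch, features) input the sketch normalizes over the batch axis only. As I read the stax implementation, apply_fun standardizes with the statistics of the current batch and keeps no running averages, and params hold a beta (offset) and gamma (scale) per position along the axes not in axis.

```python
import jax
import jax.numpy as jnp

try:
    from jax.example_libraries import stax  # newer JAX releases
except ImportError:
    from jax.experimental import stax  # older releases, as documented here

# Normalize a (batch, features) input over the batch axis only.
init_fun, apply_fun = stax.BatchNorm(axis=(0,))

out_shape, params = init_fun(jax.random.PRNGKey(0), (8, 3))
beta, gamma = params  # one offset and one scale per feature

x = 5.0 * jax.random.normal(jax.random.PRNGKey(1), (8, 3)) + 2.0
y = apply_fun(params, x)

# beta starts at zeros and gamma at ones, so y is x standardized per
# feature using the mean and variance of this batch.
print(y.mean(axis=0), y.std(axis=0))
```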
-
jax.experimental.stax.Conv(out_chan, filter_shape, strides=None, padding='VALID', W_init=None, b_init=<function normal.<locals>.init>)
Layer construction function for a 2D convolution layer: GeneralConv with dimension_numbers fixed to ('NHWC', 'HWIO', 'NHWC').
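A minimal sketch, assuming JAX is installed (the import tries jax.example_libraries.stax before the path documented here). It builds an 8-channel 3x3 convolution over NHWC images and inspects the shapes init_fun produces: the kernel follows the HWIO layout and the bias broadcasts along the output channel axis.

```python
import jax
import jax.numpy as jnp

try:
    from jax.example_libraries import stax  # newer JAX releases
except ImportError:
    from jax.experimental import stax  # older releases, as documented here

# 8 output channels, 3x3 filters; 'SAME' padding keeps the spatial size.
init_fun, apply_fun = stax.Conv(8, (3, 3), strides=(1, 1), padding='SAME')

out_shape, params = init_fun(jax.random.PRNGKey(0), (2, 28, 28, 1))  # NHWC
W, b = params  # W is HWIO; b broadcasts against the NHWC output

y = apply_fun(params, jnp.ones((2, 28, 28, 1)))
print(tuple(int(d) for d in out_shape), W.shape, b.shape)
```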
-
jax.experimental.stax.Conv1DTranspose(out_chan, filter_shape, strides=None, padding='VALID', W_init=None, b_init=<function normal.<locals>.init>)
Layer construction function for a 1D transposed-convolution layer: GeneralConvTranspose with dimension_numbers fixed to ('NHC', 'HIO', 'NHC').
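A minimal sketch, assuming JAX is installed (the import tries jax.example_libraries.stax before the path documented here). Inputs are (batch, length, channels); with 'SAME' padding, a stride-2 transposed convolution doubles the sequence length.

```python
import jax
import jax.numpy as jnp

try:
    from jax.example_libraries import stax  # newer JAX releases
except ImportError:
    from jax.experimental import stax  # older releases, as documented here

# 4 output channels, width-3 filters, stride 2.
init_fun, apply_fun = stax.Conv1DTranspose(4, (3,), strides=(2,), padding='SAME')

out_shape, params = init_fun(jax.random.PRNGKey(0), (1, 10, 2))  # NHC
W, b = params  # W is HIO: (width, in_channels, out_channels)

y = apply_fun(params, jnp.ones((1, 10, 2)))
print(tuple(int(d) for d in out_shape), y.shape)
```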
-
jax.experimental.stax.ConvTranspose(out_chan, filter_shape, strides=None, padding='VALID', W_init=None, b_init=<function normal.<locals>.init>)
Layer construction function for a 2D transposed-convolution layer: GeneralConvTranspose with dimension_numbers fixed to ('NHWC', 'HWIO', 'NHWC').
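A minimal sketch, assuming JAX is installed (the import tries jax.example_libraries.stax before the path documented here). A 2x2 kernel with stride 2 and 'VALID' padding is a common way to upsample NHWC feature maps by a factor of two.

```python
import jax
import jax.numpy as jnp

try:
    from jax.example_libraries import stax  # newer JAX releases
except ImportError:
    from jax.experimental import stax  # older releases, as documented here

# 3 output channels, 2x2 filters, stride 2: doubles height and width.
init_fun, apply_fun = stax.ConvTranspose(3, (2, 2), strides=(2, 2), padding='VALID')

out_shape, params = init_fun(jax.random.PRNGKey(0), (1, 8, 8, 16))  # NHWC
W, b = params  # W is HWIO

y = apply_fun(params, jnp.ones((1, 8, 8, 16)))
print(tuple(int(d) for d in out_shape), y.shape)
```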
-
jax.experimental.stax.Dense(out_dim, W_init=<function variance_scaling.<locals>.init>, b_init=<function normal.<locals>.init>)
Layer construction function for a dense (fully-connected) layer.
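A minimal sketch, assuming JAX is installed (the import tries jax.example_libraries.stax before the path documented here). It shows the (init_fun, apply_fun) convention on the simplest layer: init_fun maps an input shape to an output shape plus parameters, and apply_fun computes inputs @ W + b. Passing -1 for the batch dimension leaves it unspecified until apply time.

```python
import jax
import jax.numpy as jnp

try:
    from jax.example_libraries import stax  # newer JAX releases
except ImportError:
    from jax.experimental import stax  # older releases, as documented here

# A dense layer mapping 784 input features to 10 outputs.
init_fun, apply_fun = stax.Dense(10)

# -1 leaves the batch dimension unspecified at initialization time.
out_shape, params = init_fun(jax.random.PRNGKey(0), (-1, 784))
W, b = params  # W has shape (784, 10), b has shape (10,)

x = jnp.ones((32, 784))
y = apply_fun(params, x)  # computes x @ W + b
print(out_shape, y.shape)  # (-1, 10) (32, 10)
```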
-
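jax.experimental.stax.Dropout(rate, mode='train')
Layer construction function for a dropout layer. Note that rate is the probability of keeping each element, not of dropping it; kept elements are scaled by 1/rate. apply_fun must be called with a PRNG key as the rng keyword argument. Dropout is applied only when mode='train'; any other mode returns the inputs unchanged.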
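A minimal sketch, assuming JAX is installed (the import tries jax.example_libraries.stax before the path documented here). It checks the keep-probability reading of rate empirically and shows that a non-'train' mode passes inputs through. The mode string 'test' is an arbitrary choice; as I read the implementation, any value other than 'train' behaves the same.

```python
import jax
import jax.numpy as jnp

try:
    from jax.example_libraries import stax  # newer JAX releases
except ImportError:
    from jax.experimental import stax  # older releases, as documented here

# In stax, rate is the probability of KEEPING each element.
init_fun, apply_fun = stax.Dropout(0.8)
_, params = init_fun(jax.random.PRNGKey(0), (-1, 1000))  # params == ()

x = jnp.ones((4, 1000))
# apply_fun needs a PRNG key passed as the rng keyword argument.
y = apply_fun(params, x, rng=jax.random.PRNGKey(1))

kept = y != 0
print(float(kept.mean()))  # fraction kept, close to rate
print(float(y.max()))      # kept entries are scaled by 1 / rate

# Any mode other than 'train' returns the inputs unchanged (rng passed anyway).
_, eval_apply = stax.Dropout(0.8, mode='test')
y_eval = eval_apply(params, x, rng=jax.random.PRNGKey(1))
```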
-
jax.experimental.stax.FanInConcat(axis=-1)
Layer construction function for a fan-in concatenation layer, which concatenates its sequence of inputs along axis.
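A minimal sketch, assuming JAX is installed (the import tries jax.example_libraries.stax before the path documented here). A fan-in layer takes a list of inputs, so its init_fun receives a list of shapes and returns the single concatenated shape.

```python
import jax
import jax.numpy as jnp

try:
    from jax.example_libraries import stax  # newer JAX releases
except ImportError:
    from jax.experimental import stax  # older releases, as documented here

# Concatenate a list of inputs along the last axis (the default axis=-1).
init_fun, apply_fun = stax.FanInConcat()

# init_fun receives one shape per input and returns the concatenated shape.
out_shape, params = init_fun(jax.random.PRNGKey(0), [(2, 3), (2, 5)])

y = apply_fun(params, [jnp.ones((2, 3)), jnp.zeros((2, 5))])
print(out_shape, y.shape)  # (2, 8) (2, 8)
```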
-
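jax.experimental.stax.GeneralConv(dimension_numbers, out_chan, filter_shape, strides=None, padding='VALID', W_init=None, b_init=<function normal.<locals>.init>)
Layer construction function for a general convolution layer. dimension_numbers is a tuple (lhs_spec, rhs_spec, out_spec) of layout strings for the input, kernel, and output, for example ('NHWC', 'HWIO', 'NHWC'). In each string, N is the batch axis, C the feature axis, I and O the kernel's input and output feature axes, and the remaining letters are spatial axes.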
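A minimal sketch, assuming JAX is installed (the import tries jax.example_libraries.stax before the path documented here). It uses channels-first layouts, which Conv does not cover, and shows that the kernel follows rhs_spec while the bias broadcasts along the C axis of out_spec.

```python
import jax
import jax.numpy as jnp

try:
    from jax.example_libraries import stax  # newer JAX releases
except ImportError:
    from jax.experimental import stax  # older releases, as documented here

# Channels-first layouts: NCHW inputs and outputs, OIHW kernels.
dimension_numbers = ('NCHW', 'OIHW', 'NCHW')
init_fun, apply_fun = stax.GeneralConv(dimension_numbers, 16, (3, 3), padding='SAME')

out_shape, params = init_fun(jax.random.PRNGKey(0), (2, 3, 32, 32))
W, b = params  # W follows rhs_spec; b has size 16 on the C axis of out_spec

y = apply_fun(params, jnp.ones((2, 3, 32, 32)))
print(tuple(int(d) for d in out_shape), W.shape, b.shape)
```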
-
jax.experimental.stax.GeneralConvTranspose(dimension_numbers, out_chan, filter_shape, strides=None, padding='VALID', W_init=None, b_init=<function normal.<locals>.init>)
Layer construction function for a general transposed-convolution layer. dimension_numbers has the same (lhs_spec, rhs_spec, out_spec) form as in GeneralConv.
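A minimal sketch, assuming JAX is installed (the import tries jax.example_libraries.stax before the path documented here). It upsamples channels-first feature maps; with 'SAME' padding and stride 2, each spatial dimension doubles.

```python
import jax
import jax.numpy as jnp

try:
    from jax.example_libraries import stax  # newer JAX releases
except ImportError:
    from jax.experimental import stax  # older releases, as documented here

# Channels-first layouts; the kernel's I axis matches the input channels.
dimension_numbers = ('NCHW', 'OIHW', 'NCHW')
init_fun, apply_fun = stax.GeneralConvTranspose(
    dimension_numbers, 8, (4, 4), strides=(2, 2), padding='SAME')

out_shape, params = init_fun(jax.random.PRNGKey(0), (1, 3, 16, 16))
W, b = params  # W is OIHW: (8, 3, 4, 4)

y = apply_fun(params, jnp.ones((1, 3, 16, 16)))
print(tuple(int(d) for d in out_shape), y.shape)
```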
-
jax.experimental.stax.MaxPool(window_shape, strides=None, padding='VALID', spec=None)
Layer construction function for a max-pooling layer.
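A minimal sketch, assuming JAX is installed (the import tries jax.example_libraries.stax before the path documented here). It takes the maximum over non-overlapping 2x2 windows, then shows that omitting strides gives a stride of 1, so the windows overlap and the output is larger.

```python
import jax
import jax.numpy as jnp

try:
    from jax.example_libraries import stax  # newer JAX releases
except ImportError:
    from jax.experimental import stax  # older releases, as documented here

# 2x2 max pooling with stride 2 over NHWC inputs.
init_fun, apply_fun = stax.MaxPool((2, 2), strides=(2, 2))
out_shape, params = init_fun(jax.random.PRNGKey(0), (1, 4, 4, 1))

x = jnp.arange(16.0).reshape(1, 4, 4, 1)
y = apply_fun(params, x)  # maximum of each 2x2 window

# strides=None means a stride of 1, so the 2x2 windows overlap.
init_s1, _ = stax.MaxPool((2, 2))
s1_shape, _ = init_s1(jax.random.PRNGKey(0), (1, 4, 4, 1))

print(tuple(int(d) for d in out_shape), tuple(int(d) for d in s1_shape))  # (1, 2, 2, 1) (1, 3, 3, 1)
```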
-
jax.experimental.stax.SumPool(window_shape, strides=None, padding='VALID', spec=None)
Layer construction function for a sum-pooling layer.
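A minimal sketch, assuming JAX is installed (the import tries jax.example_libraries.stax before the path documented here). It uses spec to pool a channels-first input; window_shape and strides list only the spatial dimensions, and the layer uses spec to place the batch and channel axes.

```python
import jax
import jax.numpy as jnp

try:
    from jax.example_libraries import stax  # newer JAX releases
except ImportError:
    from jax.experimental import stax  # older releases, as documented here

# spec names the input layout; here channels come before the spatial axes.
init_fun, apply_fun = stax.SumPool((2, 2), strides=(2, 2), spec='NCHW')

x = jnp.ones((1, 3, 4, 4))
out_shape, params = init_fun(jax.random.PRNGKey(0), x.shape)
y = apply_fun(params, x)  # each output sums a 2x2 window of ones

print(tuple(int(d) for d in out_shape))  # (1, 3, 2, 2)
```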
-
jax.experimental.stax.elementwise(fun, **fun_kwargs)
Layer that applies a scalar function fun elementwise to its inputs, passing fun_kwargs to fun as keyword arguments.
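A minimal sketch, assuming JAX is installed (the import tries jax.example_libraries.stax before the path documented here). It wraps jax.nn.leaky_relu and fixes its negative_slope keyword through fun_kwargs. The resulting layer has no parameters and leaves the shape unchanged.

```python
import jax
import jax.numpy as jnp

try:
    from jax.example_libraries import stax  # newer JAX releases
except ImportError:
    from jax.experimental import stax  # older releases, as documented here

# Wrap jax.nn.leaky_relu, fixing its negative_slope keyword argument.
init_fun, apply_fun = stax.elementwise(jax.nn.leaky_relu, negative_slope=0.1)

# Elementwise layers keep the input shape and have no parameters.
out_shape, params = init_fun(jax.random.PRNGKey(0), (-1, 3))

y = apply_fun(params, jnp.array([[-2.0, 0.0, 3.0]]))
print(y)
```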
-
jax.experimental.stax.parallel(*layers)
Combinator for composing layers in parallel.
The layer resulting from this combinator is often used with the FanOut and FanInSum layers.
- Parameters
*layers – a sequence of layers, each an (init_fun, apply_fun) pair.
- Returns
A new layer, meaning an (init_fun, apply_fun) pair, representing the parallel composition of the given sequence of layers. In particular, the returned layer takes a sequence of inputs, one per layer, and returns a sequence of outputs of the same length.
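A minimal sketch, assuming JAX is installed (the import tries jax.example_libraries.stax before the path documented here). FanOut, FanInSum, and Identity are other stax layers not listed in this section. The first part shows parallel mapping a list of inputs to a list of outputs. The second builds a residual block: FanOut copies the input, parallel sends one copy through Dense and passes the other through unchanged, and FanInSum adds the branches.

```python
import jax
import jax.numpy as jnp

try:
    from jax.example_libraries import stax  # newer JAX releases
except ImportError:
    from jax.experimental import stax  # older releases, as documented here

# parallel: one input shape and one input per branch.
par_init, par_apply = stax.parallel(stax.Dense(3), stax.Identity)
par_shapes, par_params = par_init(jax.random.PRNGKey(1), [(-1, 5), (-1, 2)])
outs = par_apply(par_params, [jnp.ones((2, 5)), jnp.ones((2, 2))])

# Residual block: y = Dense(x) + x.
init_fun, apply_fun = stax.serial(
    stax.FanOut(2),
    stax.parallel(stax.Dense(4), stax.Identity),
    stax.FanInSum,
)
out_shape, params = init_fun(jax.random.PRNGKey(0), (-1, 4))

x = jnp.ones((2, 4))
y = apply_fun(params, x)

# params has one entry per serial stage; the parallel stage holds one per branch.
W, b = params[1][0]
print(out_shape)  # (-1, 4)
```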
-
jax.experimental.stax.serial(*layers)
Combinator for composing layers in serial.
- Parameters
*layers – a sequence of layers, each an (init_fun, apply_fun) pair.
- Returns
A new layer, meaning an (init_fun, apply_fun) pair, representing the serial composition of the given sequence of layers.
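A minimal sketch, assuming JAX is installed (the import tries jax.example_libraries.stax before the path documented here). stax.Relu and stax.LogSoftmax are elementwise layers from the same module that are not listed in this section. The serial init_fun threads the shape through each layer and returns one params entry per layer; because params is a pytree, apply_fun can be passed to jax.jit.

```python
import jax
import jax.numpy as jnp

try:
    from jax.example_libraries import stax  # newer JAX releases
except ImportError:
    from jax.experimental import stax  # older releases, as documented here

# A small MLP classifier for flattened 28x28 inputs.
init_fun, apply_fun = stax.serial(
    stax.Dense(128), stax.Relu,
    stax.Dense(10), stax.LogSoftmax,
)

out_shape, params = init_fun(jax.random.PRNGKey(0), (-1, 28 * 28))

x = jnp.ones((32, 28 * 28))
log_probs = jax.jit(apply_fun)(params, x)
print(out_shape, len(params))  # (-1, 10) 4
```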
-
jax.experimental.stax.shape_dependent(make_layer)
Combinator that delays construction of a layer's (init_fun, apply_fun) pair until input shapes are known.
- Parameters
make_layer – a function that takes an input shape (a tuple of positive integers) and returns an (init_fun, apply_fun) pair.
- Returns
A new layer, meaning an (init_fun, apply_fun) pair, representing the same layer as returned by make_layer but with its construction delayed until input shapes are known.
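A minimal sketch, assuming JAX is installed (the import tries jax.example_libraries.stax before the path documented here). make_layer is a hypothetical helper that builds a Dense layer as wide as the input's last dimension. As I read the implementation, make_layer is called again inside apply_fun with inputs.shape, so it should depend only on dimensions that are the same at init and apply time.

```python
import jax
import jax.numpy as jnp

try:
    from jax.example_libraries import stax  # newer JAX releases
except ImportError:
    from jax.experimental import stax  # older releases, as documented here


def make_layer(input_shape):
    # Hypothetical helper: a Dense layer whose width matches the last input dimension.
    return stax.Dense(input_shape[-1])


init_fun, apply_fun = stax.shape_dependent(make_layer)

out_shape, params = init_fun(jax.random.PRNGKey(0), (3, 6))
W, b = params

y = apply_fun(params, jnp.ones((3, 6)))
print(out_shape, W.shape)  # (3, 6) (6, 6)
```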