jax.lax.conv_general_dilated_local


jax.lax.conv_general_dilated_local(lhs, rhs, window_strides, padding, filter_shape, lhs_dilation=None, rhs_dilation=None, dimension_numbers=None, precision=None)

General n-dimensional unshared convolution operator with optional dilation.

Also known as a locally connected layer, this operation is equivalent to a convolution in which a separate (unshared) rhs kernel is used at each output spatial location. The docstring below is adapted from jax.lax.conv_general_dilated.
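A minimal sketch of these semantics, assuming the default ('NCHW', 'OIHW', 'NCHW') layout, stride 1 and 'VALID' padding. The reference loop is a hypothetical illustration, not how JAX implements the operator: for each output location (i, j) it flattens the receptive window in (C, fH, fW) order, as described under rhs below, and applies the kernel slice stored for that location. Shapes and PRNG seeds are arbitrary.

```python
import jax
import jax.numpy as jnp
from jax import lax

HIGHEST = lax.Precision.HIGHEST

key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
n, c, h, w = 2, 3, 5, 5          # lhs layout NCHW (the default)
o, fh, fw = 4, 3, 3              # output channels and receptive window size
ho, wo = h - fh + 1, w - fw + 1  # output spatial size for stride 1, 'VALID' padding

x = jax.random.normal(key_x, (n, c, h, w))
# rhs layout OIHW: H and W index *output* locations, and I = C * fh * fw.
kernels = jax.random.normal(key_w, (o, c * fh * fw, ho, wo))

out = lax.conv_general_dilated_local(
    x, kernels, window_strides=(1, 1), padding="VALID",
    filter_shape=(fh, fw), precision=HIGHEST)

# Hypothetical reference loop: a separate kernel for each output location (i, j).
ref = jnp.zeros((n, o, ho, wo))
for i in range(ho):
    for j in range(wo):
        # Flatten the window in (C, fh, fw) order, matching the unfolded I dimension.
        patch = x[:, :, i:i + fh, j:j + fw].reshape(n, c * fh * fw)
        ref = ref.at[:, :, i, j].set(
            jnp.dot(patch, kernels[:, :, i, j].T, precision=HIGHEST))

print(out.shape)  # (2, 4, 3, 3)
print(jnp.allclose(out, ref, atol=1e-4))
```

Passing precision=lax.Precision.HIGHEST to both computations keeps the comparison from depending on the backend's default matmul precision.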

Parameters:
  • lhs (jax.typing.ArrayLike) – a rank n+2 dimensional input array.

  • rhs (jax.typing.ArrayLike) – a rank n+2 dimensional array of kernel weights. Unlike in regular CNNs, its spatial coordinates (H, W, …) correspond to output spatial locations, while input spatial locations are fused with the input channel locations in the single I dimension, in the order of "C" + ''.join(c for c in rhs_spec if c not in 'OI'), where rhs_spec = dimension_numbers[1]. For example, if rhs_spec == "WHIO", the unfolded kernel shape is "[output W][output H]{I[receptive window W][receptive window H]}O".

  • window_strides (Sequence[int]) – a sequence of n integers, representing the inter-window strides.

  • padding (str | Sequence[tuple[int, int]]) – either the string ‘SAME’, the string ‘VALID’, or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension.

  • filter_shape (Sequence[int]) – a sequence of n integers, representing the receptive window spatial shape in the order specified by rhs_spec = dimension_numbers[1].

  • lhs_dilation (Sequence[int] | None) – None, or a sequence of n integers, giving the dilation factor to apply in each spatial dimension of lhs. LHS dilation is also known as transposed convolution.

  • rhs_dilation (Sequence[int] | None) – None, or a sequence of n integers, giving the dilation factor to apply in each input spatial dimension of rhs. RHS dilation is also known as atrous convolution.

  • dimension_numbers (convolution.ConvGeneralDilatedDimensionNumbers | None) – either None, a ConvDimensionNumbers object, or a 3-tuple (lhs_spec, rhs_spec, out_spec), where each element is a string of length n+2.

  • precision (lax.PrecisionLike) – Optional. Either None, which means the default precision for the backend, a lax.Precision enum value (Precision.DEFAULT, Precision.HIGH or Precision.HIGHEST), or a tuple of two lax.Precision enums indicating the precision of lhs and rhs.
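As a sketch of how these parameters fit together, assuming TensorFlow-style ('NHWC', 'HWIO', 'NHWC') layouts: broadcasting one shared HWIO kernel to every output location, after flattening its (C, fH, fW) axes into I in the "C" + "HW" order described for rhs, should reproduce lax.conv_general_dilated with the same window_strides, padding, rhs_dilation and precision. Shapes are arbitrary; the 4 x 4 output follows from 'SAME' padding with stride 2 on an 8 x 8 input.

```python
import jax
import jax.numpy as jnp
from jax import lax

dn = ("NHWC", "HWIO", "NHWC")
HIGHEST = lax.Precision.HIGHEST

key_x, key_k = jax.random.split(jax.random.PRNGKey(1))
x = jax.random.normal(key_x, (1, 8, 8, 3))   # N, H, W, C
k = jax.random.normal(key_k, (3, 3, 3, 5))   # shared kernel: fH, fW, C, O

# Options shared by both calls.
conv_kwargs = dict(window_strides=(2, 2), padding="SAME",
                   rhs_dilation=(2, 2), dimension_numbers=dn, precision=HIGHEST)

shared = lax.conv_general_dilated(x, k, **conv_kwargs)

# Unfold the kernel so that I is ordered as (C, fH, fW), i.e. "C" + "HW".
k_unfolded = jnp.transpose(k, (2, 0, 1, 3)).reshape(3 * 3 * 3, 5)
# Reuse the same kernel at each of the 4 x 4 output locations (rhs layout HWIO).
rhs_local = jnp.broadcast_to(k_unfolded, (4, 4, 27, 5))

local = lax.conv_general_dilated_local(x, rhs_local, filter_shape=(3, 3), **conv_kwargs)

print(local.shape)  # (1, 4, 4, 5)
print(jnp.allclose(local, shared, atol=1e-4))
```

In general the per-location kernels differ; tying them together here is only a way to check the unfolded I ordering against the shared-weight operator.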

Return type:

jax.Array

Returns:

An array containing the unshared convolution result.

When dimension_numbers is given as strings, each character identifies, by its position:

  • the batch dimensions in lhs and the output with the character ‘N’,

  • the feature dimensions in lhs and the output with the character ‘C’,

  • the input and output feature dimensions in rhs with the characters ‘I’ and ‘O’ respectively, and

  • spatial dimension correspondences between lhs, rhs, and the output using any distinct characters.

For example, to indicate dimension numbers consistent with the jax.lax.conv function with two spatial dimensions, one could use (‘NCHW’, ‘OIHW’, ‘NCHW’). As another example, to indicate dimension numbers consistent with the TensorFlow Conv2D operation, one could use (‘NHWC’, ‘HWIO’, ‘NHWC’). When dimension_numbers is specified with strings, window strides are associated with spatial dimension character labels in the order in which those labels appear in the rhs_spec string: window_strides[0] is matched with the dimension corresponding to the first character in rhs_spec that is not ‘I’ or ‘O’, and so on.
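A small shape-only sketch of this ordering rule, assuming lhs in 'NHWC' and the "WHIO" rhs_spec used as an example under rhs: because W appears before H in rhs_spec, filter_shape and window_strides are both given as (W, H), and the unfolded I dimension is ordered as ("C", "W", "H"). The all-ones inputs are arbitrary; they make every output equal to the window size C * fW * fH = 6.

```python
import jax.numpy as jnp
from jax import lax

dn = ("NHWC", "WHIO", "NHWC")   # spatial order is set by rhs_spec: W first, then H
x = jnp.ones((1, 6, 10, 2))     # lhs: N=1, H=6, W=10, C=2

filter_shape = (3, 1)           # (fW, fH), in rhs_spec order
window_strides = (2, 1)         # (stride along W, stride along H)

# 'VALID' output sizes: W_out = (10 - 3) // 2 + 1 = 4, H_out = (6 - 1) // 1 + 1 = 6
w_out, h_out = 4, 6
in_features = 2 * 3 * 1         # I = C * fW * fH
rhs = jnp.ones((w_out, h_out, in_features, 3))  # rhs layout WHIO: W_out, H_out, I, O=3

out = lax.conv_general_dilated_local(
    x, rhs, window_strides, "VALID", filter_shape, dimension_numbers=dn)

print(out.shape)  # (1, 6, 4, 3)
```

The output follows out_spec ('NHWC'), so H_out = 6 comes before W_out = 4 even though rhs stores them in the opposite order.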

If dimension_numbers is None, the default is (‘NCHW’, ‘OIHW’, ‘NCHW’) (for a 2D convolution).