jax.lax.conv_general_dilated_patches

jax.lax.conv_general_dilated_patches(lhs, filter_shape, window_strides, padding, lhs_dilation=None, rhs_dilation=None, dimension_numbers=None, precision=None)

Extract patches subject to the receptive field of conv_general_dilated.

Runs the input through a convolution with the given parameters. The kernel of the convolution is constructed so that the output channel dimension “C” contains flattened image patches: a single “C” dimension represents, for example, the three dimensions “chw” collapsed together. The order of these dimensions is "c" + "".join(c for c in rhs_spec if c not in "OI"), where rhs_spec == dimension_numbers[1]. The size of this “C” dimension is therefore the size of each patch, i.e. np.prod(filter_shape) * lhs.shape[lhs_spec.index("C")], where lhs_spec == dimension_numbers[0].
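The layout described above can be checked on a small example. This is a minimal sketch assuming the default NCHW/OIHW layout, a toy 2x3x5x5 input, and 3x3 patches; the second half contracts the patches with a flattened kernel (an illustrative einsum, not how the library is implemented) to show that the patches cover exactly the receptive field of conv_general_dilated. The random kernel and the tolerances are arbitrary choices for illustration.

```python
import numpy as np
import jax.numpy as jnp
from jax import lax

# Toy input in the default NCHW layout: N=2, C=3, 5x5 spatial.
lhs = jnp.arange(2 * 3 * 5 * 5, dtype=jnp.float32).reshape(2, 3, 5, 5)

# Every 3x3 patch, stride 1, no padding.
patches = lax.conv_general_dilated_patches(
    lhs, filter_shape=(3, 3), window_strides=(1, 1), padding="VALID")
print(patches.shape)  # (2, 27, 3, 3): 27 = np.prod((3, 3)) * 3 input channels

# Split the flattened C dimension back into (c, h, w): "c" + "HW" for the default "OIHW".
p = np.asarray(patches).reshape(2, 3, 3, 3, 3, 3)  # (N, c, h, w, H_out, W_out)
x_np = np.asarray(lhs)
for i in range(3):
    for j in range(3):
        # Patch element (c, i, j) at output location (y, x) equals lhs[n, c, y + i, x + j].
        np.testing.assert_array_equal(p[:, :, i, j], x_np[:, :, i:i + 3, j:j + 3])

# Contracting the patches with a flattened OIHW kernel reproduces conv_general_dilated.
rng = np.random.default_rng(0)
kernel = jnp.asarray(rng.standard_normal((4, 3, 3, 3)).astype(np.float32))  # O=4, I=3, 3x3
via_patches = jnp.einsum("nkyx,ok->noyx", patches, kernel.reshape(4, -1),
                         precision=lax.Precision.HIGHEST)
direct = lax.conv_general_dilated(lhs, kernel, (1, 1), "VALID",
                                  precision=lax.Precision.HIGHEST)
np.testing.assert_allclose(via_patches, direct, rtol=1e-4, atol=1e-2)
```

Reshaping the kernel with kernel.reshape(4, -1) works here only because the default "OIHW" spec flattens each output filter in the same "c" + "HW" order as the patch dimension.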

Docstring below adapted from jax.lax.conv_general_dilated.

Parameters
  • lhs (Any) – a rank n+2 input array, where n is the number of spatial dimensions.

  • filter_shape (Sequence[int]) – a sequence of n integers, representing the receptive window spatial shape in the order specified by rhs_spec = dimension_numbers[1].

  • window_strides (Sequence[int]) – a sequence of n integers, representing the inter-window strides.

  • padding (Union[str, Sequence[Tuple[int, int]]]) – either the string ‘SAME’, the string ‘VALID’, or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension.

  • lhs_dilation (Optional[Sequence[int]]) – None, or a sequence of n integers, giving the dilation factor to apply in each spatial dimension of lhs. LHS dilation is also known as transposed convolution.

  • rhs_dilation (Optional[Sequence[int]]) – None, or a sequence of n integers, giving the dilation factor to apply in each spatial dimension of rhs. RHS dilation is also known as atrous convolution.

  • dimension_numbers (Union[None, ConvDimensionNumbers, Tuple[str, str, str]]) – either None, a ConvDimensionNumbers object, or a 3-tuple (lhs_spec, rhs_spec, out_spec), where each element is a string of length n+2. None defaults to (“NCHWD…”, “OIHWD…”, “NCHWD…”).

  • precision (Optional[Any]) – either None, which means the default precision for the backend, or a lax.Precision enum value (Precision.DEFAULT, Precision.HIGH, or Precision.HIGHEST).
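As a rough guide to how padding, window_strides, and rhs_dilation interact, the sketch below prints output shapes for an arbitrary 1x2x8x8 input in the default NCHW layout (the input and filter sizes are assumptions chosen for illustration). The “C” size stays 2 * 3 * 3 = 18 in every case; only the spatial extent changes, following the usual convolution output-size rules.

```python
import jax.numpy as jnp
from jax import lax

lhs = jnp.ones((1, 2, 8, 8), dtype=jnp.float32)  # NCHW: N=1, C=2, 8x8 spatial

# 'SAME' padding with stride 1 keeps the 8x8 spatial shape.
same = lax.conv_general_dilated_patches(lhs, (3, 3), (1, 1), "SAME")
print(same.shape)  # (1, 18, 8, 8)

# Stride 2 with 'VALID': (8 - 3) // 2 + 1 = 3 positions per spatial dimension.
strided = lax.conv_general_dilated_patches(lhs, (3, 3), (2, 2), "VALID")
print(strided.shape)  # (1, 18, 3, 3)

# rhs_dilation=(2, 2) spreads the 3x3 window over 5x5 pixels: 8 - 5 + 1 = 4.
dilated = lax.conv_general_dilated_patches(
    lhs, (3, 3), (1, 1), "VALID", rhs_dilation=(2, 2))
print(dilated.shape)  # (1, 18, 4, 4)

# Explicit (low, high) pairs: pad H by 1 on each side, leave W unpadded.
padded = lax.conv_general_dilated_patches(lhs, (3, 3), (1, 1), [(1, 1), (0, 0)])
print(padded.shape)  # (1, 18, 8, 6)
```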

Return type

Any

Returns

A rank n+2 array containing the flattened image patches in the output channel (“C”) dimension. For example, if dimension_numbers = ("NcHW", "OIwh", "CNHW"), the output has dimension numbers "CNHW" = "{cwh}NHW", with the size of dimension “C” equal to the size of each patch (np.prod(filter_shape) * lhs.shape[lhs_spec.index("C")]).
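The same rules apply to non-default layouts. This sketch assumes a channels-last configuration, dimension_numbers = ("NHWC", "HWIO", "NHWC"), with a small arange input chosen for illustration. Here rhs_spec "HWIO" gives the patch order "c" + "HW", and out_spec places the flattened “C” dimension last.

```python
import numpy as np
import jax.numpy as jnp
from jax import lax

# Channels-last input: N=1, H=4, W=4, C=2.
lhs = jnp.arange(1 * 4 * 4 * 2, dtype=jnp.float32).reshape(1, 4, 4, 2)

patches = lax.conv_general_dilated_patches(
    lhs, filter_shape=(2, 2), window_strides=(1, 1), padding="VALID",
    dimension_numbers=("NHWC", "HWIO", "NHWC"))
print(patches.shape)  # (1, 3, 3, 8): C = 2 * 2 * 2 = 8, placed last by out_spec

# Unflatten the last dimension as (c, h, w), following "c" + "HW".
p = np.asarray(patches).reshape(1, 3, 3, 2, 2, 2)  # (N, H_out, W_out, c, h, w)
lhs_np = np.asarray(lhs)
for i in range(2):
    for j in range(2):
        # Patch element (c, i, j) at output location (y, x) equals lhs[n, y + i, x + j, c].
        np.testing.assert_array_equal(p[..., i, j], lhs_np[:, i:i + 3, j:j + 3, :])
```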