jax.lax.conv_general_dilated_patches
- jax.lax.conv_general_dilated_patches(lhs, filter_shape, window_strides, padding, lhs_dilation=None, rhs_dilation=None, dimension_numbers=None, precision=None, preferred_element_type=None)
Extract patches subject to the receptive field of conv_general_dilated.
Runs the input through a convolution with the given parameters. The kernel of the convolution is constructed so that the output channel dimension "C" contains flattened image patches: a single "C" dimension represents, for example, the three dimensions "chw" collapsed into one. The order of these dimensions is "c" + ''.join(c for c in rhs_spec if c not in 'OI'), where rhs_spec == dimension_numbers[1]. The size of this "C" dimension is therefore the size of each patch, i.e. np.prod(filter_shape) * lhs.shape[lhs_spec.index('C')], where lhs_spec == dimension_numbers[0].
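The flattening described above can be checked on a small input. The following is a minimal sketch, assuming the default "NCHW" layout, "VALID" padding and unit strides; the array name x and its sizes are made up for illustration.

```python
import jax.numpy as jnp
import numpy as np
from jax import lax

# Toy input in the default "NCHW" layout: batch 2, 3 channels, 5x5 spatial.
x = jnp.arange(2 * 3 * 5 * 5, dtype=jnp.float32).reshape(2, 3, 5, 5)

patches = lax.conv_general_dilated_patches(
    x, filter_shape=(2, 2), window_strides=(1, 1), padding="VALID")

# The "C" dimension now holds c * h * w = 3 * 2 * 2 = 12 values per position.
print(patches.shape)  # (2, 12, 4, 4)

# Un-flatten the patch at output position (1, 2) of batch element 0 and
# compare it with a direct slice of the input.
patch = patches[0, :, 1, 2].reshape(3, 2, 2)  # order is (c, h, w)
expected = x[0, :, 1:3, 2:4]
print(np.allclose(patch, expected))
```

Reshaping the channel axis to (c, h, w) recovers the window because, in the default layout, the collapsed order is "c" followed by the spatial dimensions of rhs_spec.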
Docstring below adapted from jax.lax.conv_general_dilated.
- Parameters
lhs (Any) – a rank n+2 dimensional input array.
filter_shape (Sequence[int]) – a sequence of n integers, representing the receptive window spatial shape in the order as specified in rhs_spec = dimension_numbers[1].
window_strides (Sequence[int]) – a sequence of n integers, representing the inter-window strides.
padding (Union[str, Sequence[Tuple[int, int]]]) – either the string 'SAME', the string 'VALID', or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension.
lhs_dilation (Optional[Sequence[int]]) – None, or a sequence of n integers, giving the dilation factor to apply in each spatial dimension of lhs. LHS dilation is also known as transposed convolution.
rhs_dilation (Optional[Sequence[int]]) – None, or a sequence of n integers, giving the dilation factor to apply in each spatial dimension of rhs. RHS dilation is also known as atrous convolution.
dimension_numbers (Union[None, ConvDimensionNumbers, Tuple[str, str, str]]) – either None, or a 3-tuple (lhs_spec, rhs_spec, out_spec), where each element is a string of length n+2. None defaults to ("NCHWD…, OIHWD…, NCHWD…").
precision (Optional[Any]) – Optional. Either None, which means the default precision for the backend, or a Precision enum value (Precision.DEFAULT, Precision.HIGH or Precision.HIGHEST).
preferred_element_type (Optional[Any]) – Optional. Either None, which means the default accumulation type for the input types, or a datatype, indicating to accumulate results to and return a result with that datatype.
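As a rough illustration of how window_strides, padding and rhs_dilation affect the result, the sketch below only inspects output shapes. The input sizes and variable names are assumptions chosen for the example, not values from the documentation.

```python
import jax.numpy as jnp
from jax import lax

# Hypothetical input: batch 1, 2 channels, 8x8 spatial, default "NCHW" layout.
x = jnp.ones((1, 2, 8, 8), dtype=jnp.float32)

# 'SAME' padding with stride 2: each spatial size becomes ceil(8 / 2) = 4.
same_strided = lax.conv_general_dilated_patches(
    x, filter_shape=(3, 3), window_strides=(2, 2), padding="SAME")

# Explicit (low, high) padding of 1 on each side keeps the 8x8 spatial size
# for a 3x3 window with unit strides.
explicit_pad = lax.conv_general_dilated_patches(
    x, filter_shape=(3, 3), window_strides=(1, 1),
    padding=[(1, 1), (1, 1)])

# rhs_dilation=(2, 2) spreads the 3x3 window over a 5x5 receptive field,
# so 'VALID' padding leaves 8 - 5 + 1 = 4 positions per spatial dimension.
dilated = lax.conv_general_dilated_patches(
    x, filter_shape=(3, 3), window_strides=(1, 1), padding="VALID",
    rhs_dilation=(2, 2))

# In every case the channel dimension has size 2 * 3 * 3 = 18.
print(same_strided.shape)  # (1, 18, 4, 4)
print(explicit_pad.shape)  # (1, 18, 8, 8)
print(dilated.shape)       # (1, 18, 4, 4)
```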
- Returns
A rank n+2 array containing the flattened image patches in the output channel ("C") dimension. For example, if dimension_numbers = ("NcHW", "OIwh", "CNHW"), the output has dimension numbers "CNHW" = "{cwh}NHW", with the size of dimension "C" equal to the size of each patch (np.prod(filter_shape) * lhs.shape[lhs_spec.index('C')]).
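The non-default layout in the example above can be sketched as follows. This assumes uppercase spec strings ("NCHW", "OIWH", "CNHW"), since the lowercase letters in the text appear to be notation for which dimensions get collapsed rather than literal arguments; the input sizes are made up.

```python
import jax.numpy as jnp
import numpy as np
from jax import lax

# Hypothetical "NCHW" input: batch 2, 3 channels, height 5, width 7.
x = jnp.arange(2 * 3 * 5 * 7, dtype=jnp.float32).reshape(2, 3, 5, 7)

# filter_shape follows the spatial order of rhs_spec ("OIWH"): (width, height).
kw, kh = 3, 2
patches = lax.conv_general_dilated_patches(
    x, filter_shape=(kw, kh), window_strides=(1, 1), padding="VALID",
    dimension_numbers=("NCHW", "OIWH", "CNHW"))

# Output layout is "CNHW" with C = c * w * h = 3 * 3 * 2 = 18.
print(patches.shape)  # (18, 2, 4, 5)

# The collapsed channel order is (c, w, h), so the un-flattened patch is the
# input window with its two spatial axes swapped.
patch = patches[:, 1, 2, 3].reshape(3, kw, kh)
window = x[1, :, 2:2 + kh, 3:3 + kw]        # (c, h, w)
expected = jnp.transpose(window, (0, 2, 1))  # (c, w, h)
print(np.allclose(patch, expected))
```

Using different window extents for height and width makes the axis order visible: reshaping to (c, h, w) instead of (c, w, h) would not match the input window.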