jax.lax.conv_general_dilated

jax.lax.conv_general_dilated(lhs, rhs, window_strides, padding, lhs_dilation=None, rhs_dilation=None, dimension_numbers=None, feature_group_count=1, batch_group_count=1, precision=None, preferred_element_type=None)

General n-dimensional convolution operator, with optional dilation.

Wraps XLA's Conv operator.
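As a minimal sketch (the shapes and values are illustrative, not from the original text), a 1-D convolution with the default layout shows the basic call. The operator applies the kernel as a sliding dot product without reversing it (cross-correlation), which the comparison against numpy.correlate checks. precision=lax.Precision.HIGHEST is set only so the float comparison also holds on backends whose default precision is lower.

```python
import numpy as np
import jax.numpy as jnp
from jax import lax

# Input in the default (N, C, W) layout: 1 batch element, 1 channel, 8 positions.
x = jnp.arange(8, dtype=jnp.float32).reshape(1, 1, 8)
# Kernel in the default (O, I, W) layout: 1 output channel, 1 input channel, width 3.
w = jnp.array([1.0, 2.0, -1.0], dtype=jnp.float32).reshape(1, 1, 3)

y = lax.conv_general_dilated(
    x, w,
    window_strides=(1,),            # n = 1 spatial dimension
    padding='VALID',                # no padding: output width is 8 - 3 + 1 = 6
    precision=lax.Precision.HIGHEST,
)
print(y.shape)   # (1, 1, 6)
# Each output is x[k] + 2 * x[k+1] - x[k+2] = 2k, i.e. values 0, 2, 4, 6, 8, 10.

# The kernel is not reversed, so the result matches numpy's cross-correlation.
ref = np.correlate(np.asarray(x[0, 0]), np.asarray(w[0, 0]), mode='valid')
print(np.allclose(y[0, 0], ref))   # True
```

A true (flipped-kernel) convolution would instead give 2k + 4 at each position, which is an easy way to tell the two apart.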
Parameters

lhs (Any) – a rank n+2 dimensional input array.

rhs (Any) – a rank n+2 dimensional array of kernel weights.

window_strides (Sequence[int]) – a sequence of n integers, representing the inter-window strides.

padding (Union[str, Sequence[Tuple[int, int]]]) – either the string 'SAME', the string 'VALID', or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension.

lhs_dilation (Optional[Sequence[int]]) – None, or a sequence of n integers, giving the dilation factor to apply in each spatial dimension of lhs. LHS dilation is also known as transposed convolution.

rhs_dilation (Optional[Sequence[int]]) – None, or a sequence of n integers, giving the dilation factor to apply in each spatial dimension of rhs. RHS dilation is also known as atrous convolution.

dimension_numbers (Union[None, ConvDimensionNumbers, Tuple[str, str, str]]) – either None, a ConvDimensionNumbers object, or a 3-tuple (lhs_spec, rhs_spec, out_spec), where each element is a string of length n+2.

feature_group_count (int) – integer, default 1. Values greater than 1 perform a grouped convolution; see the XLA HLO docs.

batch_group_count (int) – integer, default 1. See the XLA HLO docs.

precision (Union[None, str, Any, Tuple[str, str], Tuple[Any, Any]]) – Optional. Either None, which means the default precision for the backend, a lax.Precision enum value (Precision.DEFAULT, Precision.HIGH or Precision.HIGHEST), a string (e.g. 'highest' or 'fastest', see the jax.default_matmul_precision context manager), or a tuple of two lax.Precision enums or strings indicating the precision of lhs and rhs.

preferred_element_type (Optional[Any]) – Optional. Either None, which means the default accumulation type for the input types, or a datatype, indicating to accumulate results to and return a result with that datatype.
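The parameters above that control geometry (window_strides, padding, rhs_dilation, lhs_dilation) are easiest to understand through the output shapes they produce. The sketch below uses a hypothetical 10x10 single-channel input and a 3x3 kernel in the default NCHW/OIHW layout; out_shape is a helper introduced only for this illustration.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.ones((1, 1, 10, 10), jnp.float32)   # NCHW: one 10x10 image, 1 channel
w = jnp.ones((4, 1, 3, 3), jnp.float32)     # OIHW: 4 output channels, 3x3 kernel

def out_shape(strides, padding, **kwargs):
    """Hypothetical helper: output shape of a 2-D convolution of x with w."""
    return lax.conv_general_dilated(x, w, strides, padding, **kwargs).shape

# 'VALID' applies no padding: 10 - 3 + 1 = 8.
valid = out_shape((1, 1), 'VALID')
# 'SAME' pads so that each output spatial size is ceil(10 / stride).
same = out_shape((1, 1), 'SAME')
same_s2 = out_shape((2, 2), 'SAME')
# Explicit (low, high) pairs: H is padded to 12 (12 - 3 + 1 = 10), W stays 10 (-> 8).
explicit = out_shape((1, 1), ((1, 1), (0, 0)))
# rhs_dilation=2 spreads the 3x3 kernel over a 5x5 window: 10 - 5 + 1 = 6.
atrous = out_shape((1, 1), 'VALID', rhs_dilation=(2, 2))
# lhs_dilation=2 inserts a zero between input elements: (10 - 1) * 2 + 1 = 19,
# and a 3x3 window over 19 positions gives 19 - 3 + 1 = 17.
upsampled = out_shape((1, 1), ((0, 0), (0, 0)), lhs_dilation=(2, 2))

print(valid)      # (1, 4, 8, 8)
print(same)       # (1, 4, 10, 10)
print(same_s2)    # (1, 4, 5, 5)
print(explicit)   # (1, 4, 10, 8)
print(atrous)     # (1, 4, 6, 6)
print(upsampled)  # (1, 4, 17, 17)
```

Explicit zero padding is used in the lhs_dilation case so the arithmetic depends only on the dilated input size.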
 Returns
An array containing the convolution result.
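The feature_group_count parameter is described only by reference to the XLA HLO docs. As an illustrative sketch with made-up shapes, setting feature_group_count equal to the number of input channels C gives a depthwise convolution: the kernel's I dimension becomes C / feature_group_count = 1, and output channel c sees only input channel c. The per-channel numpy.correlate loop used as a reference assumes the default (N, C, W) / (O, I, W) layout.

```python
import numpy as np
import jax.numpy as jnp
from jax import lax

N, C, W = 2, 3, 6
x = jnp.arange(N * C * W, dtype=jnp.float32).reshape(N, C, W)   # (N, C, W)
# With feature_group_count=C the kernel's input-feature dimension is C // C = 1.
w = jnp.array([[[1.0, 0.0, -1.0]],
               [[0.5, 0.5, 0.5]],
               [[2.0, 0.0, 0.0]]], dtype=jnp.float32)           # (O=C, I=1, width 3)

y = lax.conv_general_dilated(x, w, window_strides=(1,), padding='VALID',
                             feature_group_count=C,
                             precision=lax.Precision.HIGHEST)
print(y.shape)   # (2, 3, 4)

# Reference: output channel c correlates input channel c with kernel c only.
ref = np.stack([
    np.stack([np.correlate(np.asarray(x[n, c]), np.asarray(w[c, 0]), mode='valid')
              for c in range(C)])
    for n in range(N)])
print(np.allclose(y, ref))   # True
```

With feature_group_count=1 the same kernel would be rejected, because rhs would then need I = C = 3 input features.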
In the string case of dimension_numbers, each character identifies by position:

- the batch dimensions in lhs and the output with the character 'N',
- the feature dimensions in lhs and the output with the character 'C',
- the input and output feature dimensions in rhs with the characters 'I' and 'O' respectively, and
- spatial dimension correspondences between lhs, rhs, and the output using any distinct characters.

For example, to indicate dimension numbers consistent with the conv function (jax.lax.conv) with two spatial dimensions, one could use ('NCHW', 'OIHW', 'NCHW'). As another example, to indicate dimension numbers consistent with the TensorFlow Conv2D operation, one could use ('NHWC', 'HWIO', 'NHWC'). When using the string form of the specification, window strides are associated with spatial dimension character labels according to the order in which the labels appear in the rhs_spec string, so that window_strides[0] is matched with the dimension corresponding to the first character appearing in rhs_spec that is not 'I' or 'O'.

If dimension_numbers is None, the default is ('NCHW', 'OIHW', 'NCHW') (for a 2D convolution); in general, the batch and feature dimensions come first, followed by the spatial dimensions.
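To see the string form and the stride-ordering rule in practice, the sketch below (random data and sizes are illustrative) runs the same 2-D convolution twice: once in the default layout and once in the TensorFlow-style ('NHWC', 'HWIO', 'NHWC') layout, converted to a ConvDimensionNumbers object with lax.conv_dimension_numbers. Because 'H' precedes 'W' in 'HWIO', window_strides=(2, 1) strides by 2 along H in both calls, so the results agree after transposing.

```python
import numpy as np
from jax import lax

rng = np.random.default_rng(0)
x_nchw = rng.standard_normal((2, 3, 9, 7)).astype(np.float32)   # N, C, H, W
w_oihw = rng.standard_normal((5, 3, 3, 2)).astype(np.float32)   # O, I, H, W

# Default layout: dimension_numbers=None means ('NCHW', 'OIHW', 'NCHW').
y_nchw = lax.conv_general_dilated(
    x_nchw, w_oihw, window_strides=(2, 1), padding='VALID',
    precision=lax.Precision.HIGHEST)

# The same data rearranged into TensorFlow-style layouts.
x_nhwc = np.transpose(x_nchw, (0, 2, 3, 1))   # NCHW -> NHWC
w_hwio = np.transpose(w_oihw, (2, 3, 1, 0))   # OIHW -> HWIO
dn = lax.conv_dimension_numbers(x_nhwc.shape, w_hwio.shape,
                                ('NHWC', 'HWIO', 'NHWC'))
# window_strides[0] pairs with 'H', the first non-I/O character of 'HWIO'.
y_nhwc = lax.conv_general_dilated(
    x_nhwc, w_hwio, window_strides=(2, 1), padding='VALID',
    dimension_numbers=dn, precision=lax.Precision.HIGHEST)

print(y_nchw.shape)   # (2, 5, 4, 6)
print(y_nhwc.shape)   # (2, 4, 6, 5)
# NHWC -> NCHW before comparing.
same = np.allclose(np.transpose(np.asarray(y_nhwc), (0, 3, 1, 2)),
                   np.asarray(y_nchw), atol=1e-4)
print(same)           # True
```

Passing the string tuple ('NHWC', 'HWIO', 'NHWC') directly as dimension_numbers is equivalent; building the ConvDimensionNumbers object once can be convenient when the same layout is reused.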