# jax.lax.conv_general_dilated

jax.lax.conv_general_dilated(lhs, rhs, window_strides, padding, lhs_dilation=None, rhs_dilation=None, dimension_numbers=None, feature_group_count=1, batch_group_count=1, precision=None)

General n-dimensional convolution operator, with optional dilation.

Wraps XLA’s Conv operator.

Parameters
Return type

Any

Returns

An array containing the convolution result.
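As a minimal sketch of a call, assuming the default dimension layout and illustrative shapes chosen only for this example (none of the shapes below come from the reference text), a plain and a dilated 2D convolution might look like this:

```python
import jax.numpy as jnp
from jax import lax

# Illustrative input: batch of 1, 3 input features, 8x8 spatial extent.
lhs = jnp.ones((1, 3, 8, 8), dtype=jnp.float32)
# Illustrative kernel: 4 output features, 3 input features, 3x3 window.
rhs = jnp.ones((4, 3, 3, 3), dtype=jnp.float32)

# Plain convolution: unit strides, no padding, no dilation.
out = lax.conv_general_dilated(
    lhs, rhs,
    window_strides=(1, 1),
    padding="VALID",
)

# Same inputs with the kernel dilated by 2 in each spatial dimension,
# which widens the effective window from 3x3 to 5x5.
dilated = lax.conv_general_dilated(
    lhs, rhs,
    window_strides=(1, 1),
    padding="VALID",
    rhs_dilation=(2, 2),
)

print(out.shape)      # (1, 4, 6, 6)
print(dilated.shape)  # (1, 4, 4, 4)
```

With all-ones inputs, every element of `out` is the sum over a 3x3 window across 3 input features, which makes the result easy to check by hand.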

When dimension_numbers is given as a tuple of strings (lhs_spec, rhs_spec, out_spec), each character identifies by position:

• the batch dimensions in lhs and the output with the character ‘N’,

• the feature dimensions in lhs and the output with the character ‘C’,

• the input and output feature dimensions in rhs with the characters ‘I’ and ‘O’ respectively, and

• spatial dimension correspondences between lhs, rhs, and the output using any distinct characters.

For example, to indicate dimension numbers consistent with the jax.lax.conv function with two spatial dimensions, one could use (‘NCHW’, ‘OIHW’, ‘NCHW’). As another example, to indicate dimension numbers consistent with the TensorFlow Conv2D operation, one could use (‘NHWC’, ‘HWIO’, ‘NHWC’). When dimension_numbers is given in this string form, window strides are associated with spatial dimensions according to the order in which their character labels appear in the rhs_spec string: window_strides[0] is matched with the dimension corresponding to the first character in rhs_spec that is not ‘I’ or ‘O’, window_strides[1] with the second such character, and so on.
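The stride association described above can be illustrated with a small sketch. The shapes and stride values are assumptions picked for this example; the only thing that changes between the two calls is the order of the spatial labels in rhs_spec (and the matching kernel layout).

```python
import jax.numpy as jnp
from jax import lax

# Illustrative NHWC input: batch=2, H=10, W=12, 3 features.
lhs = jnp.ones((2, 10, 12, 3), dtype=jnp.float32)

# Kernel in HWIO layout: window height 3, window width 5, I=3, O=7.
rhs_hwio = jnp.ones((3, 5, 3, 7), dtype=jnp.float32)
out_hw = lax.conv_general_dilated(
    lhs, rhs_hwio,
    window_strides=(2, 1),  # rhs_spec is 'HWIO': stride 2 on H, stride 1 on W
    padding="VALID",
    dimension_numbers=("NHWC", "HWIO", "NHWC"),
)

# Same kernel with its spatial axes swapped, described as WHIO.
rhs_whio = jnp.transpose(rhs_hwio, (1, 0, 2, 3))
out_wh = lax.conv_general_dilated(
    lhs, rhs_whio,
    window_strides=(2, 1),  # rhs_spec is 'WHIO': stride 2 on W, stride 1 on H
    padding="VALID",
    dimension_numbers=("NHWC", "WHIO", "NHWC"),
)

print(out_hw.shape)  # (2, 4, 8, 7)
print(out_wh.shape)  # (2, 8, 4, 7)
```

Both outputs are laid out as NHWC, so the differing spatial extents show that the same window_strides tuple was applied to different dimensions.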

If dimension_numbers is None, the default is (‘NCHW’, ‘OIHW’, ‘NCHW’) (for a 2D convolution).
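A short sketch of the default behaviour, using shapes assumed only for illustration: omitting dimension_numbers for a 2D convolution should give the same result as passing the NCHW/OIHW/NCHW strings explicitly.

```python
import jax.numpy as jnp
from jax import lax

# Illustrative NCHW input and OIHW kernel.
lhs = jnp.arange(2 * 3 * 6 * 6, dtype=jnp.float32).reshape(2, 3, 6, 6)
rhs = jnp.ones((4, 3, 3, 3), dtype=jnp.float32)

# dimension_numbers left as None (the default).
default_out = lax.conv_general_dilated(lhs, rhs, (1, 1), "SAME")

# The same layout spelled out explicitly.
explicit_out = lax.conv_general_dilated(
    lhs, rhs, (1, 1), "SAME",
    dimension_numbers=("NCHW", "OIHW", "NCHW"),
)

matches = bool(jnp.allclose(default_out, explicit_out))
print(default_out.shape, matches)
```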