jax.lax.conv_with_general_padding

jax.lax.conv_with_general_padding(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, precision=None)

Convenience wrapper around conv_general_dilated that leaves dimension_numbers, feature_group_count, and batch_group_count at their defaults.
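The sketch below illustrates the wrapper relationship by calling both functions with the same arguments on a small 2-D example. The shapes and values are illustrative assumptions; the arrays use the default layout of conv_general_dilated (NCHW input, OIHW kernel), since the wrapper passes no dimension_numbers.

```python
import jax.numpy as jnp
from jax import lax

# NCHW input: batch=1, channels=1, 5x5 spatial grid holding 0..24
lhs = jnp.arange(25.0).reshape(1, 1, 5, 5)
# OIHW kernel: 1 output channel, 1 input channel, 3x3 window of ones
rhs = jnp.ones((1, 1, 3, 3))

# One row/column of zero padding on each side keeps the 5x5 spatial size.
out_wrapper = lax.conv_with_general_padding(
    lhs, rhs,
    window_strides=(1, 1),
    padding=((1, 1), (1, 1)),
    lhs_dilation=None,
    rhs_dilation=None,
)

# The same call expressed directly through conv_general_dilated.
out_general = lax.conv_general_dilated(
    lhs, rhs,
    window_strides=(1, 1),
    padding=((1, 1), (1, 1)),
    lhs_dilation=None,
    rhs_dilation=None,
)

print(out_wrapper.shape)                       # (1, 1, 5, 5)
print(bool(jnp.allclose(out_wrapper, out_general)))  # True
print(float(out_wrapper[0, 0, 2, 2]))          # 108.0: sum of the central 3x3 patch
```

Because the kernel is all ones, each output element is simply the sum of the (zero-padded) 3x3 neighbourhood around it, which makes the result easy to check by hand.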

Parameters
  • lhs (Any) – a rank n+2 dimensional input array in (batch, feature, spatial...) layout, i.e. NCHW for n=2.

  • rhs (Any) – a rank n+2 dimensional array of kernel weights in (output feature, input feature, spatial...) layout, i.e. OIHW for n=2.

  • window_strides (Sequence[int]) – a sequence of n integers, representing the inter-window strides.

  • padding (Union[str, Sequence[Tuple[int, int]]]) – either the string ‘SAME’, the string ‘VALID’, or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension.

  • lhs_dilation (Optional[Sequence[int]]) – None (no dilation), or a sequence of n integers giving the dilation factor to apply in each spatial dimension of lhs. LHS dilation is also known as transposed convolution.

  • rhs_dilation (Optional[Sequence[int]]) – None (no dilation), or a sequence of n integers giving the dilation factor to apply in each spatial dimension of rhs. RHS dilation is also known as atrous convolution.

  • precision (Optional[Any]) – either None, which selects the backend's default precision, or a lax.Precision enum value (Precision.DEFAULT, Precision.HIGH, or Precision.HIGHEST).

Return type

Any

Returns

An array containing the convolution result.
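For explicit (low, high) padding, the size of each output spatial dimension can be predicted with plain arithmetic. The sketch below assumes XLA's convolution semantics, where LHS dilation inserts zeros between input elements before padding is applied, and RHS dilation spreads the kernel taps apart. The helper expected_size and all shapes are hypothetical and chosen only to exercise every parameter at once.

```python
import jax.numpy as jnp
from jax import lax

def expected_size(in_size, k, stride, pad, lhs_d=1, rhs_d=1):
    """Hypothetical helper: output length along one spatial dimension."""
    dilated_in = (in_size - 1) * lhs_d + 1  # lhs_d - 1 zeros between input elements
    padded = dilated_in + pad[0] + pad[1]   # explicit (low, high) padding
    eff_k = (k - 1) * rhs_d + 1             # kernel footprint after rhs dilation
    return (padded - eff_k) // stride + 1

lhs = jnp.ones((2, 3, 10, 7))  # NCHW: batch 2, 3 input channels, 10x7 spatial
rhs = jnp.ones((4, 3, 3, 2))   # OIHW: 4 output channels, 3 input channels, 3x2 kernel
strides = (2, 1)
padding = ((1, 2), (0, 1))
lhs_dil = (1, 2)
rhs_dil = (2, 1)

out = lax.conv_with_general_padding(lhs, rhs, strides, padding, lhs_dil, rhs_dil)

expected = tuple(
    expected_size(lhs.shape[2 + i], rhs.shape[2 + i], strides[i], padding[i],
                  lhs_dil[i], rhs_dil[i])
    for i in range(2)
)
print(out.shape)  # (2, 4, 5, 13)
print(expected)   # (5, 13)
```

The leading output dimensions are the batch size of lhs and the output-feature count of rhs; only the spatial dimensions depend on strides, padding, and dilation.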