jax.lax.conv_dimension_numbers


jax.lax.conv_dimension_numbers(lhs_shape, rhs_shape, dimension_numbers)

Converts convolution dimension_numbers into a ConvDimensionNumbers object.
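As a minimal sketch, assuming a 2-D convolution with an NHWC input and an HWIO kernel (the shapes below are illustrative only), the conversion turns each layout string into a tuple of dimension indices: the batch (or output-feature) dimension first, then the feature (or input-feature) dimension, then the spatial dimensions.

```python
from jax import lax

# Illustrative shapes: one 8x8 image with 3 channels (NHWC) and a 3x3 kernel
# mapping 3 input channels to 16 output channels (HWIO).
lhs_shape = (1, 8, 8, 3)
rhs_shape = (3, 3, 3, 16)

dn = lax.conv_dimension_numbers(lhs_shape, rhs_shape, ("NHWC", "HWIO", "NHWC"))

# lhs_spec: (index of N, index of C, spatial indices) within the input shape
print(dn.lhs_spec)  # (0, 3, 1, 2)
# rhs_spec: (index of O, index of I, spatial indices) within the kernel shape
print(dn.rhs_spec)  # (3, 2, 0, 1)
# out_spec: (index of N, index of C, spatial indices) within the output shape
print(dn.out_spec)  # (0, 3, 1, 2)
```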

Parameters:
  • lhs_shape – tuple of nonnegative integers, shape of the convolution input.

  • rhs_shape – tuple of nonnegative integers, shape of the convolution kernel.

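  • dimension_numbers – None, a ConvDimensionNumbers object, or a tuple/list of three strings (the lhs, rhs, and output layouts, e.g. ('NHWC', 'HWIO', 'NHWC')) following the convolution dimension number specification format in xla_client.py. Each string must have one character per dimension of lhs_shape. None selects the default layout, in which every spec is the identity permutation (equivalent to ('NCHW', 'OIHW', 'NCHW') for a 2-D convolution); a ConvDimensionNumbers object is returned unchanged.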
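The three accepted forms of dimension_numbers can be compared directly. This sketch assumes a 2-D convolution in the default NCHW/OIHW layout; the shapes are illustrative, and it only checks that the three forms produce the same canonical result.

```python
from jax import lax

lhs_shape = (1, 3, 8, 8)   # illustrative NCHW input
rhs_shape = (16, 3, 3, 3)  # illustrative OIHW kernel

# None: every operand uses the default layout, so each spec is the identity.
dn_default = lax.conv_dimension_numbers(lhs_shape, rhs_shape, None)

# Strings that spell out the default layout give the same canonical form.
dn_strings = lax.conv_dimension_numbers(lhs_shape, rhs_shape, ("NCHW", "OIHW", "NCHW"))

# A ConvDimensionNumbers object is passed through unchanged.
dn_passthrough = lax.conv_dimension_numbers(lhs_shape, rhs_shape, dn_strings)

print(dn_default.lhs_spec)          # (0, 1, 2, 3)
print(dn_default == dn_strings)     # True
print(dn_passthrough == dn_strings) # True
```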

Return type:

ConvDimensionNumbers

Returns:

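A ConvDimensionNumbers object (with fields lhs_spec, rhs_spec, and out_spec) that represents dimension_numbers in the canonical form used by lax functions: each field is a tuple of dimension indices listing the batch (or output-feature) dimension, then the feature (or input-feature) dimension, then the spatial dimensions.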
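A typical consumer of the returned object is jax.lax.conv_general_dilated, which accepts either the layout strings or a ConvDimensionNumbers for its dimension_numbers argument. The sketch below, assuming illustrative all-ones NHWC/HWIO arrays, checks that both forms compute the same convolution.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.ones((1, 8, 8, 3), dtype=jnp.float32)   # illustrative NHWC input
w = jnp.ones((3, 3, 3, 16), dtype=jnp.float32)  # illustrative HWIO kernel
spec = ("NHWC", "HWIO", "NHWC")

# Convert the layout strings once into the canonical form.
dn = lax.conv_dimension_numbers(x.shape, w.shape, spec)

# Same convolution, once with the canonical object and once with raw strings.
y_canonical = lax.conv_general_dilated(
    x, w, window_strides=(1, 1), padding="SAME", dimension_numbers=dn)
y_strings = lax.conv_general_dilated(
    x, w, window_strides=(1, 1), padding="SAME", dimension_numbers=spec)

print(y_canonical.shape)                     # (1, 8, 8, 16)
print(bool(jnp.allclose(y_canonical, y_strings)))  # True
```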