jax.experimental.sparse.bcsr_dot_general


jax.experimental.sparse.bcsr_dot_general(lhs, rhs, *, dimension_numbers, precision=None, preferred_element_type=None)

A general contraction operation.

Parameters:
  • lhs (BCSR | Array) – An ndarray or BCSR-format sparse array.

  • rhs (Array) – An ndarray or BCSR-format sparse array.

  • dimension_numbers (DotDimensionNumbers) – A tuple of tuples of the form ((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims)), where each entry is a sequence of axis indices into the corresponding operand.

  • precision (None) – Unused.

  • preferred_element_type (None) – Unused.
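The dimension_numbers layout can be illustrated with an ordinary matrix product. The following is a minimal sketch, assuming the current jax.experimental.sparse API (in particular BCSR.fromdense, which is not described on this page); the matrices are arbitrary illustrative values. It contracts axis 1 of a BCSR lhs with axis 0 of a dense rhs and checks the result against jax.lax.dot_general applied to the dense equivalent.

```python
import jax.numpy as jnp
from jax import lax
from jax.experimental import sparse

# A small dense matrix with several zeros, converted to BCSR format.
M = jnp.array([[0., 1., 0., 2.],
               [3., 0., 0., 0.],
               [0., 0., 4., 0.]])
M_sp = sparse.BCSR.fromdense(M)

# Dense right-hand side of shape (4, 2).
B = jnp.arange(8.0).reshape(4, 2)

# Contract lhs axis 1 with rhs axis 0 and use no batch dimensions:
# this is the dimension_numbers layout of an ordinary matrix product.
dn = (((1,), (0,)), ((), ()))

out = sparse.bcsr_dot_general(M_sp, B, dimension_numbers=dn)
expected = lax.dot_general(M, B, dimension_numbers=dn)
# out == [[14., 17.], [0., 3.], [16., 20.]], matching the dense result
```

Batch dimensions would go in the second pair of dn; they are left empty here because M_sp was built without batch dimensions.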

Return type:

Array

Returns:

An ndarray or BCSR-format sparse array containing the result. If both inputs are sparse, the result will be sparse, of type BCSR. If either input is dense, the result will be dense, of type ndarray.
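The return-type rule can be checked directly. The sketch below is a minimal illustration, assuming BCSR.fromdense is available in jax.experimental.sparse (it is not documented on this page); the identity matrix and vector are arbitrary example values. Because the rhs is dense, the output should be a plain dense jax.Array rather than a BCSR object.

```python
import jax
import jax.numpy as jnp
from jax.experimental import sparse

# Sparse 3x3 identity matrix in BCSR format and a dense vector.
M_sp = sparse.BCSR.fromdense(jnp.eye(3))
v = jnp.array([1., 2., 3.])

# Matrix-vector product: contract lhs axis 1 with the only rhs axis.
y = sparse.bcsr_dot_general(M_sp, v, dimension_numbers=(((1,), (0,)), ((), ())))

# One input is dense, so y is an ordinary dense array (not a BCSR object).
is_dense = isinstance(y, jax.Array)
is_sparse = isinstance(y, sparse.BCSR)
```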