# Source code for jax.experimental.sparse.csr

# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""CSR (compressed sparse row) matrix object and associated primitives."""
from __future__ import annotations

from functools import partial
import operator
from typing import Optional
import warnings

import numpy as np

import jax
from jax.interpreters import mlir
from jax.experimental.sparse._base import JAXSparse
from jax.experimental.sparse.coo import _coo_matmat, _coo_matvec, _coo_todense, COOInfo
from jax.experimental.sparse.util import _csr_to_coo, _csr_extract, CuSparseEfficiencyWarning
from jax import lax
from jax import tree_util
from jax._src import core
from jax._src import dispatch
from jax._src.interpreters import ad
from jax._src.lax.lax import _const
from jax._src.lib import gpu_sparse
from jax._src.numpy.util import promote_dtypes
from jax._src.typing import Array, ArrayLike, DTypeLike
import jax.numpy as jnp


# Type alias for array shapes: a tuple of integer dimension sizes, of any length.
Shape = tuple[int, ...]


@tree_util.register_pytree_node_class
class CSR(JAXSparse):
  """Experimental CSR matrix implemented in JAX.

  Note: this class has minimal compatibility with JAX transforms such as
  grad and autodiff, and offers very little functionality. In general you
  should prefer :class:`jax.experimental.sparse.BCOO`.

  Additionally, there are known failures in the case that `nse` is larger
  than the true number of nonzeros in the represented matrix. This situation
  is better handled in BCOO.
  """
  data: jax.Array
  indices: jax.Array
  indptr: jax.Array
  shape: tuple[int, int]

  @property
  def nse(self):
    # Number of stored entries (including any explicit padding).
    return self.data.size

  @property
  def dtype(self):
    return self.data.dtype

  @property
  def _bufs(self):
    # The three raw buffers, in (data, indices, indptr) order.
    return (self.data, self.indices, self.indptr)

  def __init__(self, args, *, shape):
    data, indices, indptr = (jnp.asarray(buf) for buf in args)
    self.data = data
    self.indices = indices
    self.indptr = indptr
    super().__init__(args, shape=shape)

  @classmethod
  def fromdense(cls, mat, *, nse=None, index_dtype=np.int32):
    """Build a CSR matrix from a dense 2D array via ``csr_fromdense``."""
    num_specified = (mat != 0).sum() if nse is None else nse
    return csr_fromdense(mat, nse=num_specified, index_dtype=index_dtype)

  @classmethod
  def _empty(cls, shape, *, dtype=None, index_dtype='int32'):
    """Create an empty CSR instance. Public method is sparse.empty()."""
    shape = tuple(shape)
    if len(shape) != 2:
      raise ValueError(f"CSR must have ndim=2; got {shape=}")
    empty_data = jnp.empty(0, dtype)
    empty_indices = jnp.empty(0, index_dtype)
    # No stored entries, so every row pointer is zero.
    zero_indptr = jnp.zeros(shape[0] + 1, index_dtype)
    return cls((empty_data, empty_indices, zero_indptr), shape=shape)

  @classmethod
  def _eye(cls, N, M, k, *, dtype=None, index_dtype='int32'):
    """Build an (N, M) matrix with ones on diagonal ``k``."""
    diag_size = min(N, M - k) if k > 0 else min(N + k, M)
    if diag_size <= 0:
      # An out-of-range diagonal offset produces an all-zero matrix.
      return cls._empty((N, M), dtype=dtype, index_dtype=index_dtype)

    data = jnp.ones(diag_size, dtype=dtype)
    idx = jnp.arange(diag_size, dtype=index_dtype)
    zero = _const(idx, 0)
    offset = _const(idx, k)
    # A positive offset shifts columns right; a negative one shifts rows down.
    col_shift = lax.cond(offset <= 0, lambda: zero, lambda: offset)
    row_shift = lax.cond(offset >= 0, lambda: zero, lambda: offset)
    indices = lax.add(idx, col_shift).astype(index_dtype)
    # TODO(jakevdp): this can be done more efficiently.
    row = lax.sub(idx, row_shift)
    row_counts = jnp.bincount(row, length=N).astype(index_dtype)
    indptr = jnp.zeros(N + 1, dtype=index_dtype).at[1:].set(jnp.cumsum(row_counts))
    return cls((data, indices, indptr), shape=(N, M))

  def todense(self):
    return csr_todense(self)

  def transpose(self, axes=None):
    assert axes is None
    # The CSR buffers of A, read as CSC with the reversed shape, represent A^T.
    return CSC(self._bufs, shape=self.shape[::-1])

  def __matmul__(self, other):
    if isinstance(other, JAXSparse):
      raise NotImplementedError("matmul between two sparse objects.")
    other = jnp.asarray(other)
    data, other = promote_dtypes(self.data, other)
    if other.ndim not in (1, 2):
      raise NotImplementedError(f"matmul with object of shape {other.shape}")
    product = _csr_matvec if other.ndim == 1 else _csr_matmat
    return product(data, self.indices, self.indptr, other, shape=self.shape)

  def tree_flatten(self):
    return self._bufs, {"shape": self.shape}

  @classmethod
  def tree_unflatten(cls, aux_data, children):
    obj = object.__new__(cls)
    obj.data, obj.indices, obj.indptr = children
    if aux_data.keys() != {'shape'}:
      raise ValueError(f"CSR.tree_unflatten: invalid {aux_data=}")
    obj.__dict__.update(**aux_data)
    return obj
@tree_util.register_pytree_node_class
class CSC(JAXSparse):
  """Experimental CSC matrix implemented in JAX; API subject to change.

  The buffers are stored exactly like a :class:`CSR` matrix of the transposed
  shape: ``indptr`` has length ``shape[1] + 1`` (one pointer per column, plus
  one), and ``indices`` holds the row index of each stored entry. Most
  operations are implemented by reinterpreting the buffers as a CSR matrix of
  shape ``shape[::-1]``.
  """
  data: jax.Array
  indices: jax.Array
  indptr: jax.Array
  shape: tuple[int, int]
  nse = property(lambda self: self.data.size)
  dtype = property(lambda self: self.data.dtype)
  # Raw buffers in (data, indices, indptr) order, mirroring CSR._bufs.
  _bufs = property(lambda self: (self.data, self.indices, self.indptr))

  def __init__(self, args, *, shape):
    self.data, self.indices, self.indptr = map(jnp.asarray, args)
    super().__init__(args, shape=shape)

  @classmethod
  def fromdense(cls, mat, *, nse=None, index_dtype=np.int32):
    """Build a CSC matrix by converting ``mat.T`` to CSR and transposing."""
    if nse is None:
      nse = (mat != 0).sum()
    return csr_fromdense(mat.T, nse=nse, index_dtype=index_dtype).T

  @classmethod
  def _empty(cls, shape, *, dtype=None, index_dtype='int32'):
    """Create an empty CSC instance. Public method is sparse.empty()."""
    shape = tuple(shape)
    if len(shape) != 2:
      raise ValueError(f"CSC must have ndim=2; got {shape=}")
    data = jnp.empty(0, dtype)
    indices = jnp.empty(0, index_dtype)
    # One pointer per column (plus one); all zero because nothing is stored.
    indptr = jnp.zeros(shape[1] + 1, index_dtype)
    return cls((data, indices, indptr), shape=shape)

  @classmethod
  def _eye(cls, N, M, k, *, dtype=None, index_dtype='int32'):
    # The k-th diagonal of an (N, M) matrix is the (-k)-th diagonal of its
    # (M, N) transpose.
    return CSR._eye(M, N, -k, dtype=dtype, index_dtype=index_dtype).T

  def todense(self):
    return csr_todense(self.T).T

  def transpose(self, axes=None):
    assert axes is None
    return CSR((self.data, self.indices, self.indptr), shape=self.shape[::-1])

  def __matmul__(self, other):
    if isinstance(other, JAXSparse):
      raise NotImplementedError("matmul between two sparse objects.")
    other = jnp.asarray(other)
    data, other = promote_dtypes(self.data, other)
    # A CSC matrix A is the transpose of the CSR matrix with the same buffers,
    # so A @ x is computed as (CSR of shape[::-1])^T @ x.
    if other.ndim == 1:
      return _csr_matvec(data, self.indices, self.indptr, other,
                         shape=self.shape[::-1], transpose=True)
    elif other.ndim == 2:
      return _csr_matmat(data, self.indices, self.indptr, other,
                         shape=self.shape[::-1], transpose=True)
    else:
      raise NotImplementedError(f"matmul with object of shape {other.shape}")

  def tree_flatten(self):
    return (self.data, self.indices, self.indptr), {"shape": self.shape}

  @classmethod
  def tree_unflatten(cls, aux_data, children):
    obj = object.__new__(cls)
    obj.data, obj.indices, obj.indptr = children
    if aux_data.keys() != {'shape'}:
      raise ValueError(f"CSC.tree_unflatten: invalid {aux_data=}")
    obj.__dict__.update(**aux_data)
    return obj
# --------------------------------------------------------------------
# csr_todense: convert CSR buffers to a dense array.

csr_todense_p: core.Primitive = core.Primitive("csr_todense")
def csr_todense(mat: CSR) -> Array:
  """Materialize a CSR-format sparse matrix as a dense array.

  Args:
    mat : CSR matrix

  Returns:
    mat_dense: dense version of ``mat``, with shape ``mat.shape``.
  """
  data, indices, indptr = mat.data, mat.indices, mat.indptr
  return _csr_todense(data, indices, indptr, shape=mat.shape)
def _csr_todense(data: Array, indices: Array, indptr: Array, *, shape: Shape) -> Array:
  """Convert CSR-format sparse matrix to a dense matrix.

  Args:
    data : array of shape ``(nse,)``.
    indices : array of shape ``(nse,)``
    indptr : array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``
    shape : length-2 tuple representing the matrix shape

  Returns:
    mat : array with specified shape and dtype matching ``data``
  """
  return csr_todense_p.bind(data, indices, indptr, shape=shape)

def _csr_todense_impl(data, indices, indptr, *, shape):
  # Generic implementation: expand the CSR index buffers into COO (row, col)
  # coordinates and reuse the COO todense routine.
  return _coo_todense(data, *_csr_to_coo(indices, indptr), spinfo=COOInfo(shape=shape))

@csr_todense_p.def_abstract_eval
def _csr_todense_abstract_eval(data, indices, indptr, *, shape):
  # Shape rule: the output is a dense array of `shape` with data's dtype; the
  # asserts check that the three buffers are mutually consistent.
  assert data.ndim == indices.ndim == indptr.ndim == 1
  assert indices.dtype == indptr.dtype
  assert data.shape == indices.shape
  assert indptr.shape[0] == shape[0] + 1
  return core.ShapedArray(shape, data.dtype)

# Lowering used on every platform unless a GPU-specific one is registered below.
_csr_todense_lowering = mlir.lower_fun(
    _csr_todense_impl, multiple_results=False)

def _csr_todense_gpu_lowering(csr_todense_hlo, ctx, data, indices, indptr, *,
                              shape):
  data_aval, indices_aval, _ = ctx.avals_in
  dtype = data_aval.dtype
  # The cusparse/hipsparse kernel is only used for floating and complex dtypes;
  # anything else falls back to the generic lowering with a warning.
  if not (np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating)):
    warnings.warn(f"csr_todense cusparse/hipsparse lowering not available for {dtype=}. "
                  "Falling back to default implementation.", CuSparseEfficiencyWarning)
    return _csr_todense_lowering(ctx, data, indices, indptr, shape=shape)
  return [csr_todense_hlo(
      data, indices, indptr, shape=shape, data_dtype=dtype,
      index_dtype=indices_aval.dtype)]


def _csr_todense_jvp(data_dot, data, indices, indptr, *, shape):
  # todense is linear in `data`: the tangent is todense of data_dot with the
  # same sparsity pattern.
  return _csr_todense(data_dot, indices, indptr, shape=shape)

def _csr_todense_transpose(ct, data, indices, indptr, *, shape):
  # Note: we assume that transpose has the same sparsity pattern.
  # Can we check this?
  assert ad.is_undefined_primal(data)
  if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr):
    raise ValueError("Cannot transpose with respect to sparse indices")
  assert ct.shape == shape
  assert indices.aval.dtype == indptr.aval.dtype
  assert ct.dtype == data.aval.dtype
  # The data cotangent is read out of the dense cotangent via _csr_extract
  # (presumably a gather of ct at the stored (row, col) positions).
  # NOTE(review): the index slots return the primal buffers rather than None;
  # confirm JAX ignores cotangents written for non-linear operands here.
  return _csr_extract(indices, indptr, ct), indices, indptr

# Only `data` is differentiable; the index buffers get no JVP rule.
ad.defjvp(csr_todense_p, _csr_todense_jvp, None, None)
ad.primitive_transposes[csr_todense_p] = _csr_todense_transpose
mlir.register_lowering(csr_todense_p, _csr_todense_lowering)
dispatch.simple_impl(csr_todense_p)
# Platform-specific lowerings override the generic one when the sparse GPU
# library is available.
if gpu_sparse.cuda_is_supported:
  mlir.register_lowering(
      csr_todense_p,
      partial(_csr_todense_gpu_lowering, gpu_sparse.cuda_csr_todense),
      platform='cuda')
if gpu_sparse.rocm_is_supported:
  mlir.register_lowering(
      csr_todense_p,
      partial(_csr_todense_gpu_lowering, gpu_sparse.rocm_csr_todense),
      platform='rocm')


#--------------------------------------------------------------------
# csr_fromdense

csr_fromdense_p = core.Primitive('csr_fromdense')
# The primitive returns three outputs: (data, indices, indptr).
csr_fromdense_p.multiple_results = True
[docs] def csr_fromdense(mat: Array, *, nse: int | None = None, index_dtype: DTypeLike = np.int32) -> CSR: """Create a CSR-format sparse matrix from a dense matrix. Args: mat : array to be converted to CSR. nse : number of specified entries in ``mat``. If not specified, it will be computed from the input matrix. index_dtype : dtype of sparse indices Returns: mat_coo : CSR representation of the matrix. """ if nse is None: nse = int((mat != 0).sum()) nse_int = core.concrete_or_error(operator.index, nse, "coo_fromdense nse argument") return CSR(_csr_fromdense(mat, nse=nse_int, index_dtype=index_dtype), shape=mat.shape)
def _csr_fromdense(mat: Array, *, nse: int, index_dtype: DTypeLike = np.int32) -> tuple[Array, Array, Array]:
  """Create CSR-format sparse matrix from a dense matrix.

  Args:
    mat : array to be converted to CSR.
    nse : number of specified entries in ``mat``
    index_dtype : dtype of sparse indices

  Returns:
    data : array of shape ``(nse,)`` and dtype ``mat.dtype``.
    indices : array of shape ``(nse,)`` and dtype ``index_dtype``
    indptr : array of shape ``(mat.shape[0] + 1,)`` and dtype ``index_dtype``
  """
  mat = jnp.asarray(mat)
  # nse fixes the output shapes, so it must be a concrete Python integer.
  nse = core.concrete_or_error(operator.index, nse, "nse argument of csr_fromdense()")
  return csr_fromdense_p.bind(mat, nse=nse, index_dtype=np.dtype(index_dtype))

def _csr_fromdense_impl(mat, *, nse, index_dtype):
  mat = jnp.asarray(mat)
  assert mat.ndim == 2
  m = mat.shape[0]

  # Exactly `nse` (row, col) pairs come back; if the matrix has fewer true
  # nonzeros, the trailing slots are padding.
  row, col = jnp.nonzero(mat, size=nse)
  data = mat[row, col]

  # Mask out padding: padded entries get data 0 and an out-of-range row `m`,
  # which bincount(length=m) is expected to drop when building indptr.
  true_nonzeros = jnp.arange(nse) < (mat != 0).sum()
  data = jnp.where(true_nonzeros, data, 0)
  row = jnp.where(true_nonzeros, row, m)
  indices = col.astype(index_dtype)
  # indptr[i + 1] = number of stored entries in rows 0..i.
  indptr = jnp.zeros(m + 1, dtype=index_dtype).at[1:].set(
      jnp.cumsum(jnp.bincount(row, length=m).astype(index_dtype)))
  return data, indices, indptr

@csr_fromdense_p.def_abstract_eval
def _csr_fromdense_abstract_eval(mat, *, nse, index_dtype):
  data = core.ShapedArray((nse,), mat.dtype)
  indices = core.ShapedArray((nse,), index_dtype)
  indptr = core.ShapedArray((mat.shape[0] + 1,), index_dtype)
  return data, indices, indptr

# Lowering used on every platform unless a GPU-specific one is registered below.
_csr_fromdense_lowering = mlir.lower_fun(_csr_fromdense_impl,
                                         multiple_results=True)

def _csr_fromdense_gpu_lowering(csr_fromdense_hlo, ctx, mat, *, nse, index_dtype):
  dtype = ctx.avals_in[0].dtype
  # Only floating and complex dtypes use the cusparse/hipsparse kernel.
  if not (np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating)):
    warnings.warn(f"csr_fromdense cusparse/hipsparse lowering not available for {dtype=}. "
                  "Falling back to default implementation.", CuSparseEfficiencyWarning)
    return _csr_fromdense_lowering(ctx, mat, nse=nse, index_dtype=index_dtype)
  data, indices, indptr = csr_fromdense_hlo(
      mat, nnz=nse, index_dtype=np.dtype(index_dtype),
      data_dtype=dtype, index_type=mlir.dtype_to_ir_type(np.dtype(index_dtype)))
  return [data, indices, indptr]


def _csr_fromdense_jvp(primals, tangents, *, nse, index_dtype):
  M, = primals
  Mdot, = tangents

  primals_out = _csr_fromdense(M, nse=nse, index_dtype=index_dtype)
  data, indices, indptr = primals_out

  # Only `data` has a nonzero tangent: Mdot is read at the primal sparsity
  # pattern via _csr_extract. The integer index outputs get symbolic zeros.
  if type(Mdot) is ad.Zero:
    data_dot = ad.Zero.from_value(data)
  else:
    data_dot = _csr_extract(indices, indptr, Mdot)

  tangents_out = (data_dot, ad.Zero.from_value(indices), ad.Zero.from_value(indptr))

  return primals_out, tangents_out

def _csr_fromdense_transpose(ct, M, *, nse, index_dtype):
  # NOTE(review): the index "cotangents" are used as real index arrays here,
  # and a symbolic Zero for them is rejected. `.dtype` is also read before the
  # isinstance(ad.Zero) check below; confirm this ordering is intended.
  data, indices, indptr = ct
  assert len(data) == nse
  assert indices.dtype == indptr.dtype == index_dtype
  if isinstance(indices, ad.Zero) or isinstance(indptr, ad.Zero):
    raise ValueError("Cannot transpose with respect to sparse indices")
  assert ad.is_undefined_primal(M)
  return _csr_todense(data, indices, indptr, shape=M.aval.shape)

ad.primitive_jvps[csr_fromdense_p] = _csr_fromdense_jvp
ad.primitive_transposes[csr_fromdense_p] = _csr_fromdense_transpose
mlir.register_lowering(csr_fromdense_p, _csr_fromdense_lowering)
dispatch.simple_impl(csr_fromdense_p)
# Platform-specific lowerings override the generic one when available.
if gpu_sparse.cuda_is_supported:
  mlir.register_lowering(
      csr_fromdense_p,
      partial(_csr_fromdense_gpu_lowering, gpu_sparse.cuda_csr_fromdense),
      platform='cuda')
if gpu_sparse.rocm_is_supported:
  mlir.register_lowering(
      csr_fromdense_p,
      partial(_csr_fromdense_gpu_lowering, gpu_sparse.rocm_csr_fromdense),
      platform='rocm')

#--------------------------------------------------------------------
# csr_matvec

csr_matvec_p = core.Primitive('csr_matvec')
def csr_matvec(mat: CSR, v: Array, transpose: bool = False) -> Array:
  """Multiply a CSR sparse matrix by a dense vector.

  Args:
    mat : CSR matrix
    v : one-dimensional array of size ``(shape[0] if transpose else shape[1],)``
      and dtype ``mat.dtype``
    transpose : if True, multiply by the transpose of ``mat`` instead of
      ``mat`` itself.

  Returns:
    y : array of shape ``(mat.shape[1] if transpose else mat.shape[0],)``
      holding the matrix-vector product.
  """
  buf_data, buf_indices, buf_indptr = mat._bufs
  return _csr_matvec(
      buf_data, buf_indices, buf_indptr, v,
      shape=mat.shape,
      transpose=transpose,
  )
def _csr_matvec(data, indices, indptr, v, *, shape, transpose=False):
  """Product of CSR sparse matrix and a dense vector.

  Args:
    data : array of shape ``(nse,)``.
    indices : array of shape ``(nse,)``
    indptr : array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``
    v : array of shape ``(shape[0] if transpose else shape[1],)``
      and dtype ``data.dtype``
    shape : length-2 tuple representing the matrix shape
    transpose : boolean specifying whether to transpose the sparse matrix
      before computing.

  Returns:
    y : array of shape ``(shape[1] if transpose else shape[0],)`` representing
      the matrix vector product.
  """
  return csr_matvec_p.bind(data, indices, indptr, v, shape=shape, transpose=transpose)


def _csr_matvec_impl(data, indices, indptr, v, *, shape, transpose):
  # Generic implementation: expand to COO (row, col) and reuse the COO matvec.
  return _coo_matvec(data, *_csr_to_coo(indices, indptr), v,
                     spinfo=COOInfo(shape=shape), transpose=transpose)


@csr_matvec_p.def_abstract_eval
def _csr_matvec_abstract_eval(data, indices, indptr, v, *, shape, transpose):
  assert len(shape) == 2
  assert v.ndim == data.ndim == indices.ndim == indptr.ndim == 1
  assert data.shape == indices.shape
  assert data.dtype == v.dtype
  assert indices.dtype == indptr.dtype
  assert indptr.shape[0] == shape[0] + 1
  out_shape = shape[1] if transpose else shape[0]
  assert v.shape[0] == (shape[0] if transpose else shape[1])
  return core.ShapedArray((out_shape,), data.dtype)


# Lowering used on every platform unless a GPU-specific one is registered below.
_csr_matvec_lowering = mlir.lower_fun(_csr_matvec_impl, multiple_results=False)


def _csr_matvec_gpu_lowering(csr_matvec_hlo, ctx, data, indices, indptr, v, *,
                             shape, transpose):
  data_aval, indices_aval, _, v_aval = ctx.avals_in
  dtype = data_aval.dtype
  # The GPU kernel only handles these four dtypes; others fall back.
  if dtype not in [np.float32, np.float64, np.complex64, np.complex128]:
    warnings.warn(f"csr_matvec cusparse/hipsparse lowering not available for {dtype=}. "
                  "Falling back to default implementation.", CuSparseEfficiencyWarning)
    return _csr_matvec_lowering(ctx, data, indices, indptr, v, shape=shape,
                                transpose=transpose)
  return [csr_matvec_hlo(
      data, indices, indptr, v, shape=shape, transpose=transpose,
      data_dtype=dtype, index_dtype=indices_aval.dtype, x_dtype=v_aval.dtype)]


def _csr_matvec_jvp_mat(data_dot, data, indices, indptr, v, *, shape, transpose):
  # The product is linear in `data`.
  return _csr_matvec(data_dot, indices, indptr, v, shape=shape, transpose=transpose)


def _csr_matvec_jvp_vec(v_dot, data, indices, indptr, v, *, shape, transpose):
  # The product is linear in `v`.
  return _csr_matvec(data, indices, indptr, v_dot, shape=shape, transpose=transpose)


def _csr_matvec_transpose(ct, data, indices, indptr, v, *, shape, transpose):
  """Transpose rule for csr_matvec_p with respect to either ``data`` or ``v``.

  Non-linear operands receive ``None`` as their cotangent.
  """
  assert not ad.is_undefined_primal(indices)
  assert not ad.is_undefined_primal(indptr)

  if ad.is_undefined_primal(v):
    # y = op(A) @ v is linear in v; its transpose applies op(A)^T to ct.
    return None, None, None, _csr_matvec(data, indices, indptr, ct, shape=shape,
                                         transpose=not transpose)

  v = jnp.asarray(v)
  # Equivalent to _csr_extract(indices, indptr, outer(...)), but computed only
  # at the stored positions.
  row, col = _csr_to_coo(indices, indptr)
  if transpose:
    # y = A^T @ v: stored entry (row, col) adds data * v[row] into y[col].
    data_ct = ct[col] * v[row]
  else:
    # y = A @ v: stored entry (row, col) adds data * v[col] into y[row].
    data_ct = ct[row] * v[col]
  return data_ct, None, None, None


ad.defjvp(csr_matvec_p, _csr_matvec_jvp_mat, None, None, _csr_matvec_jvp_vec)
ad.primitive_transposes[csr_matvec_p] = _csr_matvec_transpose
mlir.register_lowering(csr_matvec_p, _csr_matvec_lowering)
dispatch.simple_impl(csr_matvec_p)

if gpu_sparse.cuda_is_supported:
  mlir.register_lowering(
      csr_matvec_p,
      partial(_csr_matvec_gpu_lowering, gpu_sparse.cuda_csr_matvec),
      platform='cuda')
if gpu_sparse.rocm_is_supported:
  mlir.register_lowering(
      csr_matvec_p,
      partial(_csr_matvec_gpu_lowering, gpu_sparse.rocm_csr_matvec),
      platform='rocm')


#--------------------------------------------------------------------
# csr_matmat

csr_matmat_p = core.Primitive('csr_matmat')
def csr_matmat(mat: CSR, B: Array, *, transpose: bool = False) -> Array:
  """Product of CSR sparse matrix and a dense matrix.

  Args:
    mat : CSR matrix
    B : array of shape ``(mat.shape[0] if transpose else mat.shape[1], cols)`` and
      dtype ``mat.dtype``
    transpose : boolean specifying whether to transpose the sparse matrix
      before computing.

  Returns:
    C : array of shape ``(mat.shape[1] if transpose else mat.shape[0], cols)``
      representing the matrix-matrix product.
  """
  data, indices, indptr = mat._bufs
  return _csr_matmat(data, indices, indptr, B, shape=mat.shape, transpose=transpose)
def _csr_matmat(data: Array, indices: Array, indptr: Array, B: Array,
                *, shape: Shape, transpose: bool = False) -> Array:
  """Product of CSR sparse matrix and a dense matrix.

  Args:
    data : array of shape ``(nse,)``.
    indices : array of shape ``(nse,)``
    indptr : array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``
    B : array of shape ``(shape[0] if transpose else shape[1], cols)`` and
      dtype ``data.dtype``
    shape : length-2 tuple representing the matrix shape
    transpose : boolean specifying whether to transpose the sparse matrix
      before computing.

  Returns:
    C : array of shape ``(shape[1] if transpose else shape[0], cols)``
      representing the matrix-matrix product.
  """
  return csr_matmat_p.bind(data, indices, indptr, B, shape=shape, transpose=transpose)


def _csr_matmat_impl(data, indices, indptr, B, *, shape, transpose):
  # Generic implementation: expand to COO (row, col) and reuse the COO matmat.
  return _coo_matmat(data, *_csr_to_coo(indices, indptr), B,
                     spinfo=COOInfo(shape=shape), transpose=transpose)


@csr_matmat_p.def_abstract_eval
def _csr_matmat_abstract_eval(data, indices, indptr, B, *, shape, transpose):
  assert len(shape) == 2
  assert data.ndim == indices.ndim == indptr.ndim == 1
  assert B.ndim == 2
  assert data.shape == indices.shape
  assert data.dtype == B.dtype
  assert indices.dtype == indptr.dtype
  assert indptr.shape[0] == shape[0] + 1
  out_shape = shape[1] if transpose else shape[0]
  assert B.shape[0] == (shape[0] if transpose else shape[1])
  return core.ShapedArray((out_shape, B.shape[1]), data.dtype)


# Lowering used on every platform unless a GPU-specific one is registered below.
_csr_matmat_lowering = mlir.lower_fun(_csr_matmat_impl, multiple_results=False)


def _csr_matmat_gpu_lowering(csr_matmat_hlo, ctx, data, indices, indptr, B, *,
                             shape, transpose):
  data_aval, indices_aval, _, B_aval = ctx.avals_in
  dtype = data_aval.dtype
  # The GPU kernel only handles these four dtypes; others fall back.
  if dtype not in [np.float32, np.float64, np.complex64, np.complex128]:
    warnings.warn(f"csr_matmat cusparse/hipsparse lowering not available for {dtype=}. "
                  "Falling back to default implementation.", CuSparseEfficiencyWarning)
    return _csr_matmat_lowering(ctx, data, indices, indptr, B, shape=shape,
                                transpose=transpose)
  return [csr_matmat_hlo(
      data, indices, indptr, B, shape=shape, transpose=transpose,
      index_dtype=indices_aval.dtype, data_dtype=data_aval.dtype,
      B_dtype=B_aval.dtype)]


def _csr_matmat_jvp_left(data_dot, data, indices, indptr, B, *, shape, transpose):
  # The product is linear in `data`.
  return _csr_matmat(data_dot, indices, indptr, B, shape=shape, transpose=transpose)


def _csr_matmat_jvp_right(B_dot, data, indices, indptr, B, *, shape, transpose):
  # The product is linear in `B`.
  return _csr_matmat(data, indices, indptr, B_dot, shape=shape, transpose=transpose)


def _csr_matmat_transpose(ct, data, indices, indptr, B, *, shape, transpose):
  """Transpose rule for csr_matmat_p with respect to either ``data`` or ``B``.

  Non-linear operands receive ``None`` as their cotangent.
  """
  assert not ad.is_undefined_primal(indices)
  assert not ad.is_undefined_primal(indptr)

  if ad.is_undefined_primal(B):
    # C = op(A) @ B is linear in B; its transpose applies op(A)^T to ct.
    return None, None, None, _csr_matmat(data, indices, indptr, ct, shape=shape,
                                         transpose=not transpose)

  B = jnp.asarray(B)
  row, col = _csr_to_coo(indices, indptr)
  if transpose:
    # C = A^T @ B: stored entry (row, col) adds data * B[row] into C[col].
    data_ct = (ct[col] * B[row]).sum(1)
  else:
    # C = A @ B: stored entry (row, col) adds data * B[col] into C[row].
    data_ct = (ct[row] * B[col]).sum(1)
  return data_ct, None, None, None


ad.defjvp(csr_matmat_p, _csr_matmat_jvp_left, None, None, _csr_matmat_jvp_right)
ad.primitive_transposes[csr_matmat_p] = _csr_matmat_transpose
mlir.register_lowering(csr_matmat_p, _csr_matmat_lowering)
dispatch.simple_impl(csr_matmat_p)

if gpu_sparse.cuda_is_supported:
  mlir.register_lowering(
      csr_matmat_p,
      partial(_csr_matmat_gpu_lowering, gpu_sparse.cuda_csr_matmat),
      platform='cuda')
if gpu_sparse.rocm_is_supported:
  mlir.register_lowering(
      csr_matmat_p,
      partial(_csr_matmat_gpu_lowering, gpu_sparse.rocm_csr_matmat),
      platform='rocm')