jax.experimental.pallas.mosaic_gpu module
Experimental GPU backend for Pallas, targeting NVIDIA H100 (Hopper) GPUs.
These APIs are highly unstable and can change weekly. Use at your own risk.
Classes
Mosaic GPU compiler parameters.
An enumeration.
An enumeration.
Represents a tiling transformation for memory refs.
Transpose a tiled memref.
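The tiling and tiled-transpose entries above can be made concrete with a minimal sketch in plain jax.numpy. It assumes the common convention in which an (M, N) array tiled by (tm, tn) is viewed as an (M // tm, N // tn, tm, tn) array of tiles. The helpers `tile`, `untile`, and `transpose_tiled` are hypothetical and are not part of this module. The library's transforms act on memory refs inside a kernel, not on arrays, so this only illustrates the index mapping.

```python
import jax.numpy as jnp


def tile(x, tile_shape):
    """Hypothetical helper: view an (M, N) array as (M//tm, N//tn, tm, tn) tiles.

    out[i, j] is the (i, j)-th tile, i.e. out[i, j, a, b] == x[i*tm + a, j*tn + b].
    """
    m, n = x.shape
    tm, tn = tile_shape
    assert m % tm == 0 and n % tn == 0, "tile shape must divide the array shape"
    # Split each axis into (tile index, offset within tile).
    x = x.reshape(m // tm, tm, n // tn, tn)
    # Move both tile-index axes to the front.
    return x.transpose(0, 2, 1, 3)


def untile(t):
    """Hypothetical helper: inverse of `tile`, (M//tm, N//tn, tm, tn) -> (M, N)."""
    gm, gn, tm, tn = t.shape
    return t.transpose(0, 2, 1, 3).reshape(gm * tm, gn * tn)


def transpose_tiled(t):
    """Hypothetical helper: transpose a tiled 2D array without untiling it.

    Swapping the tile-grid axes and the within-tile axes yields the tiled
    form of x.T (with the transposed tile shape).
    """
    return t.transpose(1, 0, 3, 2)


x = jnp.arange(8 * 6).reshape(8, 6)
t = tile(x, (4, 3))
print(t.shape)  # (2, 2, 4, 3)
```

Because `transpose_tiled` permutes both levels of axes, its result equals `tile(x.T, (3, 4))`, so the data never has to be untiled first.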
Functions
Arrives at the given barrier.
Waits on the given barrier.
Commits all writes to SMEM, making them visible to loads, TMA and WGMMA.
Asynchronously copies a GMEM reference to an SMEM reference.
Asynchronously copies an SMEM reference to a GMEM reference.
Creates a function to emit a manual pipeline within a Pallas kernel.
Casts the layout of the given array.
Sets the maximum number of registers owned by a warp.
Waits until there are no more than
Performs an asynchronous warp group matmul-accumulate on the given references.
Waits until there is no more than
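The warp group matmul-accumulate entry above updates an accumulator as acc += a @ b. The sketch below models only that arithmetic in plain jax.numpy. The contraction (K) dimension is split into blocks, the way a kernel typically feeds one K-block at a time into the accumulator. The helpers `matmul_accumulate` and `blocked_matmul` are hypothetical. The choice of bfloat16 inputs with a float32 accumulator is an illustrative assumption, and nothing here issues WGMMA instructions.

```python
import jax.numpy as jnp


def matmul_accumulate(acc, a, b):
    """Hypothetical helper: return acc + a @ b, with the product computed in float32."""
    return acc + jnp.dot(a, b, preferred_element_type=jnp.float32)


def blocked_matmul(a, b, block_k):
    """Hypothetical helper: compute a @ b by accumulating over K-blocks."""
    m, k = a.shape
    k_b, n = b.shape
    assert k == k_b, "contraction dimensions must match"
    assert k % block_k == 0, "block_k must divide K"
    acc = jnp.zeros((m, n), dtype=jnp.float32)  # Accumulator starts at zero.
    for start in range(0, k, block_k):
        # Each step adds one (m, block_k) x (block_k, n) partial product.
        a_blk = a[:, start:start + block_k]
        b_blk = b[start:start + block_k, :]
        acc = matmul_accumulate(acc, a_blk, b_blk)
    return acc


# Small integer-valued inputs, exactly representable in bfloat16.
a = (jnp.arange(8 * 64) % 5).reshape(8, 64).astype(jnp.bfloat16)
b = (jnp.arange(64 * 16) % 3).reshape(64, 16).astype(jnp.bfloat16)
out = blocked_matmul(a, b, block_k=16)
print(out.shape, out.dtype)  # (8, 16) float32
```

Keeping the accumulator in a wider type than the inputs is what makes summing many partial products over K numerically safe.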
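The entry that creates a function to emit a manual pipeline refers to software pipelining. In this technique, copying the next block of input overlaps with computation on the current one. As a rough model only, the following pure-Python schedule with `num_slots` buffers records the order in which fetches and computations would be issued. It runs sequentially and does not use this module's pipelining API. The function `run_pipeline` and its parameters are hypothetical.

```python
from collections import deque

import jax.numpy as jnp


def run_pipeline(blocks, compute, num_slots=2):
    """Hypothetical model of a pipeline with `num_slots` buffers (2 = double buffering).

    Returns (results, events), where events lists ("fetch", i) and
    ("compute", i) tuples in issue order.
    """
    assert num_slots >= 1
    pending = iter(enumerate(blocks))
    in_flight = deque()  # Blocks fetched into a slot but not yet computed on.
    results, events = [], []

    def fetch_next():
        item = next(pending, None)
        if item is not None:
            events.append(("fetch", item[0]))
            in_flight.append(item)

    # Prologue: fill every slot before the first computation.
    for _ in range(num_slots):
        fetch_next()

    # Steady state: compute on the oldest slot, then reuse that slot for the next fetch.
    while in_flight:
        i, block = in_flight.popleft()
        events.append(("compute", i))
        results.append(compute(block))
        fetch_next()
    return results, events


x = jnp.arange(16.0).reshape(4, 4)
results, events = run_pipeline(list(x), jnp.sum)
print(events)
# [('fetch', 0), ('fetch', 1), ('compute', 0), ('fetch', 2),
#  ('compute', 1), ('fetch', 3), ('compute', 2), ('compute', 3)]
```

A slot is refilled only after the computation reading it has finished, so at most `num_slots` blocks are buffered at any time. In a real kernel the fetches are asynchronous and overlap with the computation.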
Aliases
Alias of
Alias of
Alias of