jax.experimental.pallas.mosaic_gpu module

Experimental GPU backend for Pallas targeting H100.

These APIs are highly unstable and can change weekly. Use at your own risk.

Classes

Barrier(num_arrivals[, num_barriers])

GPUBlockSpec([block_shape, index_map, ...])

GPUCompilerParams(*[, approx_math, ...])

Mosaic GPU compiler parameters.

GPUMemorySpace(value)

An enumeration of GPU memory spaces, including GMEM and SMEM.

Layout(value)

An enumeration.

SwizzleTransform(swizzle)

TilingTransform(tiling)

Represents a tiling transformation for memory refs.

TransposeTransform(permutation)

Transpose a tiled memref.

WGMMAAccumulatorRef(shape, dtype, _init)
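As an illustration of the tiling that TilingTransform represents, the following plain-Python sketch computes the shape and index mapping of a tiled array. It is a conceptual model only and does not call the Pallas API. It assumes that a tiling of (X, Y) applied to the trailing dimensions of an (M, N) array yields the shape (M // X, N // Y, X, Y); the helper names `tiled_shape` and `tiled_index` are hypothetical.

```python
def tiled_shape(shape, tiling):
    """Return the shape after splitting the trailing dims of `shape` into tiles.

    Assumption: tiling applies to the last len(tiling) dimensions, and each
    of those dimensions must be divisible by its tile size.
    """
    if len(tiling) > len(shape):
        raise ValueError("tiling has more dimensions than the array")
    split = len(shape) - len(tiling)
    lead, tail = tuple(shape[:split]), tuple(shape[split:])
    for dim, tile in zip(tail, tiling):
        if dim % tile != 0:
            raise ValueError(f"dimension {dim} is not divisible by tile {tile}")
    # Outer dims count the tiles; inner dims index within one tile.
    outer = tuple(dim // tile for dim, tile in zip(tail, tiling))
    return lead + outer + tuple(tiling)


def tiled_index(index, tiling):
    """Map an index into the untiled array to an index into the tiled array."""
    split = len(index) - len(tiling)
    lead, tail = tuple(index[:split]), tuple(index[split:])
    outer = tuple(i // tile for i, tile in zip(tail, tiling))  # which tile
    inner = tuple(i % tile for i, tile in zip(tail, tiling))   # offset in tile
    return lead + outer + inner


print(tiled_shape((256, 256), (64, 32)))  # (4, 8, 64, 32)
print(tiled_index((70, 33), (64, 32)))    # (1, 1, 6, 1)
```

Under these assumptions, element (70, 33) of a (256, 256) block lands in tile (1, 1) at offset (6, 1) within that tile.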

Functions

barrier_arrive(barrier)

Arrives at the given barrier.

barrier_wait(barrier)

Waits on the given barrier.

commit_smem()

Commits all writes to SMEM, making them visible to loads, TMA and WGMMA.

copy_gmem_to_smem(src, dst, barrier)

Asynchronously copies a GMEM reference to a SMEM reference.

copy_smem_to_gmem(src, dst[, predicate])

Asynchronously copies a SMEM reference to a GMEM reference.

emit_pipeline(body, *, grid[, in_specs, ...])

Creates a function to emit a manual pipeline within a Pallas kernel.

layout_cast(x, new_layout)

Casts the layout of the given array.

set_max_registers(n, *, action)

Sets the maximum number of registers owned by a warp.

wait_smem_to_gmem(n[, wait_read_only])

Waits until there are no more than n SMEM->GMEM copies in flight.

wgmma(acc, a, b)

Performs an asynchronous warp group matmul-accumulate on the given references.

wgmma_wait(n)

Waits until there are no more than n WGMMA operations in flight.
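The "no more than n in flight" wording used by wgmma_wait and wait_smem_to_gmem can be pictured with a small plain-Python model. This is a sketch under stated assumptions, not the Pallas API: it assumes asynchronous operations retire in issue order, and it defers the work until the wait only to keep the model simple (real hardware runs the operations concurrently). The names `InFlightQueue` and `matmul_accumulate` are hypothetical.

```python
from collections import deque


class InFlightQueue:
    """Toy model of asynchronous operations that retire in issue order."""

    def __init__(self):
        self._pending = deque()

    def issue(self, op):
        """Record `op` (a zero-argument callable) as in flight without running it."""
        self._pending.append(op)

    def wait(self, n):
        """Retire the oldest operations until at most `n` remain in flight."""
        if n < 0:
            raise ValueError("n must be non-negative")
        while len(self._pending) > n:
            self._pending.popleft()()

    @property
    def in_flight(self):
        return len(self._pending)


def matmul_accumulate(acc, a, b):
    """Compute acc += a @ b in place on nested lists."""
    for i in range(len(a)):
        for j in range(len(b[0])):
            acc[i][j] += sum(a[i][k] * b[k][j] for k in range(len(b)))


acc = [[0, 0], [0, 0]]
a = [[1, 2], [3, 4]]
identity = [[1, 0], [0, 1]]

queue = InFlightQueue()
queue.issue(lambda: matmul_accumulate(acc, a, identity))
queue.issue(lambda: matmul_accumulate(acc, a, identity))

queue.wait(1)  # at most one operation may remain, so the oldest retires
print(acc)     # [[1, 2], [3, 4]]
queue.wait(0)  # drain everything before reading the final accumulator
print(acc)     # [[2, 4], [6, 8]]
```

In this model, as in the summaries above, a wait with n = 0 is the point after which the accumulator (or the copied data) is safe to read; a larger n leaves the most recently issued operations outstanding so that later work can overlap with them.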

Aliases

ACC

Alias of WGMMAAccumulatorRef.

GMEM

Alias of jax.experimental.pallas.mosaic_gpu.GPUMemorySpace.GMEM.

SMEM

Alias of jax.experimental.pallas.mosaic_gpu.GPUMemorySpace.SMEM.