Pipelining#
In this guide we’ll cover how memory spaces on TPUs work and how to write pipelines in Pallas that overlap memory I/O with compute.
#@title Imports
import jax
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
import jax.numpy as jnp
import numpy as np
TPU and its memory spaces#
A TPU and its TensorCore consist of memory spaces (where arrays can reside),
registers (which temporarily store scalar and array values) and compute units
(that do computation with values in registers).
Below is a diagram of a TPU in which x
and y
are arrays that live in
high-bandwidth memory (HBM):
Let’s talk about the components of this diagram in more detail:
Memory spaces: A TPU has high-bandwidth memory (HBM) which is what we often think of as “device memory”. There is also vector memory (VMEM), a cache meant for storing vector and array values, and scalar memory (SMEM), a cache designed to store scalar values.
Registers: A TensorCore has two main types of registers: vector registers (VREGs) store array values, and scalar registers (SREGs) store scalar values. Values can be loaded into memory from their respective caches (VMEM for VREGs and SMEM for SREGs).
Compute units: A TensorCore has a scalar unit, vector unit (VPU) and matrix unit (MXU) that can do numerical computation. Compute units operate on values that live in SREGs and VREGs and output values into those registers as well.
In order to do a vectorized computation on our values x
and y
that live
in HBM, we need to:
1. Copy the values x and y into VMEM.
2. Load the values from VMEM into VREGs.
3. Execute the computation using the VPU or MXU, storing the output in VREGs.
4. Store the values in the output VREGs into VMEM.
5. Copy the output values in VMEM back to HBM.
Let’s implement a Pallas function that does just that!
def add_matrices_kernel(x_vmem_ref, y_vmem_ref, z_vmem_ref):
  # Load x and y from VMEM into VREGs
  x_vregs = x_vmem_ref[:, :]
  y_vregs = y_vmem_ref[:, :]
  # Execute a vectorized add
  z_vregs = x_vregs + y_vregs
  # Store the output values in VREGs back into VMEM
  z_vmem_ref[:, :] = z_vregs


def add_matrices(x: jax.Array, y: jax.Array) -> jax.Array:
  # pallas_call will first allocate scratch buffers for `x` and `y` in VMEM.
  # It will then copy `x` and `y` from HBM into VMEM.
  z = pl.pallas_call(
      add_matrices_kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)
  )(x, y)
  # pallas_call will also copy the output from VMEM back into HBM.
  return z
x, y = jnp.ones((512, 512)), jnp.ones((512, 512))
add_matrices(x, y)
Array([[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.],
...,
[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.]], dtype=float32)
We’ve written two functions: add_matrices_kernel and add_matrices.

add_matrices_kernel operates using Refs that live in VMEM. Loading from a VMEM Ref produces a value that lives in VREGs. Values in VREGs behave like jax.Arrays in that we can use jnp and jax.lax operations on them to produce new values that live in VREGs. When we produce the values we’d like to return, we store them in the output VMEM Ref.

The add_matrices function acts on jax.Arrays and returns a jax.Array. Inside it, we pass x and y into pallas_call. pallas_call is responsible for copying x and y into VMEM and for allocating the VMEM buffers that the kernel operates on (including allocating z_vmem_ref, the output VMEM buffer). After the kernel function is finished running, pallas_call will also copy the value in z_vmem_ref to HBM, resulting in an output jax.Array.
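As a quick check (this assertion is ours, mirroring the np.testing checks used later in this guide), the result matches a plain JAX add:

np.testing.assert_array_equal(add_matrices(x, y), x + y)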
Constraints of using VMEM/SMEM#
Pallas exposes access to lower-level memory spaces like VMEM and SMEM, but writing kernels that utilize them adds some considerations.
1. Memory capacity. VMEM and SMEM are small! VMEM on v4 TPUs is only 16MiB and SMEM ranges in the tens to hundreds of KiB. If our arrays are too big, we won’t even be able to fit them into VMEM at all. For reference, an f32[2048, 2048] array is 16MiB, so our above kernel won’t scale beyond moderately sized arrays (see the quick calculation after this list).
2. Memory bandwidth. Copying to/from HBM and VMEM takes a long time, at least compared to most compute instructions. The add_matrices function above will likely spend more time copying between HBM and VMEM than actually performing the addition itself.
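As a rough sanity check on the capacity constraint (a back-of-the-envelope sketch, not an exact accounting of what the compiler allocates), we can count the bytes the kernel above keeps resident in VMEM:

bytes_per_f32 = 4
vmem_footprint = 3 * 512 * 512 * bytes_per_f32  # x, y and z all live in VMEM at once
print(vmem_footprint / 2**20)  # 3.0 MiB -- fits; f32[2048, 2048] inputs would need 3 * 16 MiB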
With these two constraints in mind, we’ll have to rethink our strategy for getting performance out of our TPUs.
Primer: Pipelining#
Pipelining our computation offers a way of dealing with both the memory capacity and bandwidth constraints in one fell swoop. What do we mean by pipelining?
The goal is to copy to/from HBM and VMEM in parallel while utilizing our compute units.
Naively this is difficult because in our program above we copy all of x
and y
before we start doing any compute with them, creating a dependence
between the copy and the compute.
However, if we can chunk up our computation into several subcomputations (e.g. when we add two matrices, we can express that as addition of “blocks” of the original matrices together), we can now overlap the copies of one of those subcomputations with the compute of the other. Let’s walk through a simple example:
Let’s say we split our arrays x and y into x1, x2 and y1, y2 (for example, splitting along the leading axis, resulting in two (256, 512) arrays for each input).
We can now execute the following pipelined computation.
1. Copy x1 and y1 into VMEM.
2. Start copying x2 and y2 into VMEM.
3. Load x1, y1 from VMEM into VREGs.
4. Execute z1 = x1 + y1 using the compute units.
5. Store z1 into VMEM.
6. Start copying z1 from VMEM back into HBM.
7. Wait until x2, y2 have been copied into VMEM.
8. Load x2, y2 from VMEM into VREGs.
9. Execute z2 = x2 + y2 using the compute units.
10. Store z2 into VMEM.
11. Wait until z1 is copied into HBM.
12. Start copying z2 from VMEM back into HBM.
13. Wait until z2 is copied into HBM.
Any time we are doing compute here, we are asynchronously copying something. This means that some of the time spent copying is not wasted.
The two most important numbers for determining how efficient a pipelined computation is are a) how many floating point operations (FLOPs) we need to execute and b) how many bytes we need to copy to execute that computation. The ratio of these two (FLOPs per byte copied) is called the arithmetic intensity of an operation and determines whether our pipeline will be compute bound or memory bound.
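As a back-of-the-envelope example (our numbers, not the guide’s): for the (512, 512) add above, the arithmetic intensity is very low, so the pipeline is firmly memory bound.

flops = 512 * 512  # one add per output element
bytes_moved = 3 * 512 * 512 * 4  # read x, read y, write z; 4 bytes per f32
print(flops / bytes_moved)  # ~0.08 FLOPs/byte -- memory bound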
Pipelining in Pallas#
How do we implement a pipeline like the one above in Pallas?
It seems like a complex sequence of asynchronous data operations and
executing kernels that would be a pain to implement manually.
Fear not! Pallas offers an API for expressing pipelines without too much boilerplate, namely through grids and BlockSpecs.
See how in the above pipelined example, we are executing the same logic
multiple times: steps 3-5 and 8-10 both execute the same operations,
only on different inputs.
The jax.experimental.pallas.pallas_call() function provides a way to execute a kernel multiple times, by using the grid argument.
See grid, a.k.a. kernels in a loop.
We also use jax.experimental.pallas.BlockSpec
to specify
how to construct the input of each kernel invocation.
See BlockSpec, a.k.a. how to chunk up inputs.
In the pipelining example above, we had (512, 512)-shaped arrays and split them along the leading dimension into two (256, 512)-shaped arrays. In this pipeline, our BlockSpec.block_shape would be (256, 512). On the first iteration we’d like to select x1 and on the second iteration we’d like to use x2. This can be expressed with the following index_map:
def x_index_map(i):
  return (i, 0)
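Note that the index_map returns block indices, not element offsets; Pallas multiplies them by the block shape to find the corresponding elements. A small illustration (this loop is purely expository and is not part of the pipeline):

block_shape = (256, 512)
for i in range(2):
  block_index = x_index_map(i)
  element_offset = tuple(b * s for b, s in zip(block_index, block_shape))
  print(f"iteration {i}: block index {block_index} starts at element {element_offset}")
# iteration 0: block index (0, 0) starts at element (0, 0)    -> x1
# iteration 1: block index (1, 0) starts at element (256, 0)  -> x2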
We’d then construct the BlockSpec:
block_spec = pl.BlockSpec((256, 512), x_index_map)
The BlockSpecs for y and z will be the same as the one for x.
Putting it together#
We provide these arguments to pallas_call via grid, in_specs and out_specs (in_specs corresponds to the tuple of positional arguments, and out_specs corresponds to the output).
def add_matrices_pipelined(x: jax.Array, y: jax.Array) -> jax.Array:
  block_spec = pl.BlockSpec((256, 512), lambda i: (i, 0))
  return pl.pallas_call(
      add_matrices_kernel,
      out_shape=x,
      in_specs=[block_spec, block_spec],
      out_specs=block_spec,
      grid=(2,),
  )(x, y)
add_matrices_pipelined(x, y)
Array([[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.],
...,
[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.]], dtype=float32)
We’ve only added a little bit of code to our original function to add automatic pipelining, but the BlockSpecs and grid do a lot of heavy lifting!

How does it work? Well, the BlockSpecs provide enough information to start prefetching blocks of our input from HBM into VMEM. For example, if we are starting iteration i of our grid, we can pass i + 1 into the index_map functions to obtain the blocks needed for the next iteration. We can then start an asynchronous copy for those blocks. Similarly for outputs, we can wait for the outputs of the previous iteration to be copied before starting the copy for the current iteration’s outputs.
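Schematically, the pipeline that pallas_call builds for a 1D grid of size N looks roughly like the following (an illustration of the schedule, not the code Pallas actually emits):

# start_copy_in(block 0)
# for i in range(N):
#   if i + 1 < N:
#     start_copy_in(block i + 1)   # prefetch the next iteration's inputs
#   wait_copy_in(block i)
#   run_kernel(block i)            # compute while the prefetch is in flight
#   if i > 0:
#     wait_copy_out(block i - 1)   # previous output has landed in HBM
#   start_copy_out(block i)
# wait_copy_out(block N - 1)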
Parameterizing a pipeline#
It’s common to parameterize the block shapes in our kernel. Block sizes are perhaps the most important parameter to tune when optimizing the performance of Pallas kernels! They give us control over the pipeline (for example, picking smaller blocks adds more iterations to our pipelined loop where each iteration has less work to do).
Furthermore, we could also carve up the inputs and outputs along the 2nd dimension (we are only splitting along the first right now). Let’s write a more general kernel that handles both of these features.
def add_matrices_pipelined_2d(
    x: jax.Array, y: jax.Array, *, bm: int = 256, bn: int = 256
) -> jax.Array:
  m, n = x.shape
  block_spec = pl.BlockSpec((bm, bn), lambda i, j: (i, j))
  return pl.pallas_call(
      add_matrices_kernel,
      out_shape=x,
      in_specs=[block_spec, block_spec],
      out_specs=block_spec,
      grid=(m // bm, n // bn),
  )(x, y)
np.testing.assert_array_equal(
    add_matrices_pipelined_2d(x, y, bm=256, bn=256), x + y
)
np.testing.assert_array_equal(
    add_matrices_pipelined_2d(x, y, bm=128, bn=128), x + y
)
np.testing.assert_array_equal(
    add_matrices_pipelined_2d(x, y, bm=512, bn=512), x + y
)
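A rough way to compare block sizes (a sketch; the benchmark helper below is ours, and real tuning should use a profiler since numbers vary by TPU generation):

import timeit
from functools import partial

def benchmark(f, *args, number=100):
  f(*args).block_until_ready()  # compile and warm up first
  return timeit.timeit(lambda: f(*args).block_until_ready(), number=number) / number

for bm, bn in [(512, 512), (256, 256), (128, 128)]:
  run = jax.jit(partial(add_matrices_pipelined_2d, bm=bm, bn=bn))
  print(f"bm={bm}, bn={bn}: {benchmark(run, x, y) * 1e6:.1f} us per call")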
Handling reductions#
How would you implement something like jnp.sum using pallas_call? Specifically, we’d like to pipeline across the reduction dimension.

Take the example of reducing a (8, 512, 512)-shaped array to a (512, 512)-shaped one.
x = jnp.ones((8, 512, 512))
jnp.sum(x, axis=0)
Array([[8., 8., 8., ..., 8., 8., 8.],
[8., 8., 8., ..., 8., 8., 8.],
[8., 8., 8., ..., 8., 8., 8.],
...,
[8., 8., 8., ..., 8., 8., 8.],
[8., 8., 8., ..., 8., 8., 8.],
[8., 8., 8., ..., 8., 8., 8.]], dtype=float32)
To do this using pallas_call, we could use a grid of size (8,) and in each iteration i load x[i] into VMEM. Then we could add x[i] to an output VMEM buffer. Let’s implement this naively first.
# Warning: this implementation is incorrect!
def naive_sum_kernel(x_ref, o_ref):
  o_ref[...] += x_ref[...]

def naive_sum(x: jax.Array) -> jax.Array:
  grid, *out_shape = x.shape
  return pl.pallas_call(
      naive_sum_kernel,
      grid=grid,
      # None in `block_shape` means we pick a size of 1 and squeeze it away
      in_specs=[pl.BlockSpec((None, *out_shape), lambda i: (i, 0, 0))],
      out_specs=pl.BlockSpec(out_shape, lambda i: (0, 0)),
      out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype),
  )(x)
naive_sum(x)
Array([[9., 9., 9., ..., 9., 9., 9.],
[9., 9., 9., ..., 9., 9., 9.],
[9., 9., 9., ..., 9., 9., 9.],
...,
[9., 9., 9., ..., 9., 9., 9.],
[9., 9., 9., ..., 9., 9., 9.],
[9., 9., 9., ..., 9., 9., 9.]], dtype=float32)
Notice how we’ve set up the BlockSpecs: we’re loading the entirety of the (512, 512) dimensions into VMEM (no pipelining there) but selecting the i-th slice along the first dimension of x on each iteration via the index_map. We are using a None for that dimension in the block shape, which indicates that we are selecting a singleton dimension from x that we would like to squeeze away in the kernel. Therefore, x_ref is (512, 512)-shaped in VMEM as well.

out_spec uses lambda i: (0, 0) as its index_map, indicating that o_ref is unchanged over the course of the pipeline. This means that we can update its value each iteration by reading from and writing to it. Or can we? Actually there is one catch: o_ref is initially garbage, meaning we’ll be accumulating into garbage. This will result in the overall function outputting the incorrect value!

Therefore, whenever we do a reduction in a kernel, we need to make sure to initialize the Ref that is storing the reduced value. We can accomplish this by conditionally writing a value to o_ref when we’re on iteration 0. We can do this with the helper function pl.when, a convenience wrapper around jax.lax.cond, and pl.program_id, which queries which iteration along a grid axis we are on.
def sum_kernel(x_ref, o_ref):
  @pl.when(pl.program_id(axis=0) == 0)
  def _():
    o_ref[...] = jnp.zeros_like(o_ref)

  o_ref[...] += x_ref[...]

def sum(x: jax.Array) -> jax.Array:
  grid, *out_shape = x.shape
  return pl.pallas_call(
      sum_kernel,
      grid=grid,
      # None in `block_shape` means we pick a size of 1 and squeeze it away
      in_specs=[pl.BlockSpec((None, *out_shape), lambda i: (i, 0, 0))],
      out_specs=pl.BlockSpec(out_shape, lambda i: (0, 0)),
      out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype),
  )(x)
sum(x)
Array([[8., 8., 8., ..., 8., 8., 8.],
[8., 8., 8., ..., 8., 8., 8.],
[8., 8., 8., ..., 8., 8., 8.],
...,
[8., 8., 8., ..., 8., 8., 8.],
[8., 8., 8., ..., 8., 8., 8.],
[8., 8., 8., ..., 8., 8., 8.]], dtype=float32)
This sum
function now outputs the correct values!
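As an extra check (ours, in the same spirit as the assertions earlier), the result agrees with jnp.sum:

np.testing.assert_allclose(sum(x), jnp.sum(x, axis=0))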
One last thing to note about reductions in Pallas is that they must be done in the minormost (rightmost) dimensions of our grid (our grid is 1-dimensional in the above example, so we are reducing over its minormost dimension). This is because the pipeline that Pallas generates from the BlockSpecs, grid and kernel function does not read outputs back from HBM. Once you’ve written an output value back to HBM you cannot revisit it. Therefore, you cannot accumulate across a grid dimension that involves revisiting an output block, and so all reductions need to happen over the rightmost grid dimensions.
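To make this concrete, here is a sketch (the names sum_kernel_2d and sum_pipelined_rows are ours, not part of the guide) of a sum over axis 0 that also pipelines over row blocks; the reduction axis is placed as the rightmost grid dimension, so each output block is fully accumulated before it is written back to HBM:

def sum_kernel_2d(x_ref, o_ref):
  @pl.when(pl.program_id(axis=1) == 0)
  def _():
    o_ref[...] = jnp.zeros_like(o_ref)

  o_ref[...] += x_ref[...]

def sum_pipelined_rows(x: jax.Array, *, bm: int = 256) -> jax.Array:
  k, m, n = x.shape
  return pl.pallas_call(
      sum_kernel_2d,
      # The reduction dimension (size k) is the rightmost grid dimension.
      grid=(m // bm, k),
      in_specs=[pl.BlockSpec((None, bm, n), lambda i, j: (j, i, 0))],
      out_specs=pl.BlockSpec((bm, n), lambda i, j: (i, 0)),
      out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
  )(x)

np.testing.assert_allclose(sum_pipelined_rows(x), jnp.sum(x, axis=0))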
TPUs in Megacore configuration#
Some TPU chips have two TensorCores but appear as one device to JAX users. This is called “megacore”. The separate TensorCores have their own separate VMEM, VREGs, SMEM, SREGs and compute units but share HBM.
Conceptually, TPUs in Megacore behave like very simple GPUs, i.e. they have only two threads. How do we modify our kernels to utilize both TensorCores simultaneously?
The basic idea is that if we have embarrassingly parallel dimensions in our
computation, we can split up those dimensions across the TensorCores.
We can indicate which dimensions are parallelizable by providing an
annotation to pallas_call
called dimension_semantics
.
def add_matrices_pipelined_megacore(x: jax.Array, y: jax.Array) -> jax.Array:
  block_spec = pl.BlockSpec((256, 512), lambda i: (i, 0))
  return pl.pallas_call(
      add_matrices_kernel,
      out_shape=x,
      in_specs=[block_spec, block_spec],
      out_specs=block_spec,
      grid=(2,),
      compiler_params=pltpu.TPUCompilerParams(
          dimension_semantics=("parallel",)
      ),
  )(x, y)
x, y = jnp.ones((512, 512)), jnp.ones((512, 512))
add_matrices_pipelined_megacore(x, y)
Array([[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.],
...,
[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.]], dtype=float32)
dimension_semantics should be a tuple of the same length as grid where each entry is either "parallel" or "arbitrary". "parallel" indicates to Pallas that the iterations of the loop corresponding to that dimension can be executed independently without affecting the correctness of the program. "arbitrary" indicates to Pallas that no assumptions can be made about this grid dimension and that it therefore cannot be parallelized.
By specifying dimension_semantics
, we now execute the kernel
simultaneously on each TensorCore. Pallas will handle splitting up the grid
automatically.
Note that Megacore is currently only available on TPU v4 and TPU v5p. Supplying dimension_semantics annotations is a no-op on other platforms, but not specifying it will result in only one TensorCore being used (even if there is more than one available).
Conclusion#
In this guide we covered how to express TPU pipelines using pallas_call, grid and BlockSpecs. We covered how to express nested loops via a multi-dimensional grid and how to handle reductions by initializing our accumulators at the beginning of the reduction. We also learned how to handle Megacore by adding annotations to the kernel.
Exercises left to the reader:
- Try implementing a sum kernel that pipelines the other dimensions as well.
- Add Megacore support to the add kernel and the sum kernel as well.