Foreign function interface (FFI)#

This tutorial requires JAX v0.4.31 or newer.

While a wide range of numerical operations can be easily and efficiently implemented using JAX’s built-in jax.numpy and jax.lax interfaces, it can sometimes be useful to explicitly call out to external compiled libraries via a “foreign function interface” (FFI). This is particularly useful when an operation has already been implemented in an optimized C or CUDA library and reimplementing it directly in JAX would be non-trivial, but it can also be useful for optimizing the runtime or memory performance of JAX programs. That being said, the FFI should typically be considered a last resort, because the XLA compiler in JAX’s backend, and the Pallas kernel language, which provides lower-level control, can typically produce performant code with lower development and maintenance costs.

One point that should be taken into account when considering use of the FFI is that JAX doesn’t automatically know how to differentiate through foreign functions. This means that if you want to use JAX’s autodifferentiation capabilities alongside a foreign function, you’ll also need to provide an implementation of the relevant differentiation rules. We will discuss some possible approaches below, but it is important to call this limitation out right from the start!

JAX’s FFI support is provided in two parts:

  1. A header-only C++ library from XLA, which is packaged as part of JAX as of v0.4.29, or available from the openxla/xla project, and

  2. A Python front end, available in the jax.extend.ffi submodule.

In this tutorial we demonstrate the use of both of these components using a simple example, and then go on to discuss some lower-level extensions for more complicated use cases. We start by presenting the FFI on CPU, and discuss generalizations to GPU or multi-device environments below.

This tutorial comes with two supplementary files: rms_norm.cc, which contains the C++ backend code, and CMakeLists.txt, which tells CMake how to build it.

A simple example#

To demonstrate the use of the FFI interface, we will implement a simple “root-mean-square (RMS)” normalization function. RMS normalization takes an array \(x\) with shape \((N,)\) and returns

\[ y_n = \frac{x_n}{\sqrt{\frac{1}{N}\sum_{n=1}^N {x_n}^2 + \epsilon}} \]

where \(\epsilon\) is a tuning parameter used for numerical stability.

This is a somewhat silly example, because it can be easily implemented using JAX as follows:

import jax
import jax.numpy as jnp


def rms_norm_ref(x, eps=1e-5):
  scale = jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + eps)
  return x / scale

But, it’s just non-trivial enough to be useful for demonstrating some key details of the FFI, while still being straightforward to understand. We will use this reference implementation to test our FFI version below.

Backend code#

To begin with, we need an implementation of RMS normalization in C++ that we will expose using the FFI. This isn’t meant to be particularly performant, but you could imagine that if you had some new, better implementation of RMS normalization in a C++ library, it might have an interface like the following:

#include <cmath>
#include <cstdint>

float ComputeRmsNorm(float eps, int64_t size, const float *x, float *y) {
  float sm = 0.0f;
  for (int64_t n = 0; n < size; ++n) {
    sm += x[n] * x[n];
  }
  float scale = 1.0f / std::sqrt(sm / float(size) + eps);
  for (int64_t n = 0; n < size; ++n) {
    y[n] = x[n] * scale;
  }
  return scale;
}

and, for our example, this is the function that we want to expose to JAX via the FFI.

C++ interface#

To expose our library function to JAX and XLA, we need to write a thin wrapper using the APIs provided by the header-only library in the xla/ffi/api directory of the XLA project. For more information about this interface, take a look at the XLA custom call documentation. The full source listing can be downloaded here, but the key implementation details are reproduced below:

#include <functional>
#include <numeric>
#include <utility>

#include "xla/ffi/api/c_api.h"
#include "xla/ffi/api/ffi.h"

namespace ffi = xla::ffi;

// A helper function for extracting the relevant dimensions from `ffi::Buffer`s.
// In this example, we treat all leading dimensions as batch dimensions, so this
// function returns the total number of elements in the buffer, and the size of
// the last dimension.
template <ffi::DataType T>
std::pair<int64_t, int64_t> GetDims(const ffi::Buffer<T> &buffer) {
  auto dims = buffer.dimensions();
  if (dims.size() == 0) {
    return std::make_pair(0, 0);
  }
  return std::make_pair(buffer.element_count(), dims.back());
}

// A wrapper function providing the interface between the XLA FFI call and our
// library function `ComputeRmsNorm` above. This function handles the batch
// dimensions by calling `ComputeRmsNorm` within a loop.
ffi::Error RmsNormImpl(float eps, ffi::Buffer<ffi::DataType::F32> x,
                       ffi::Result<ffi::Buffer<ffi::DataType::F32>> y) {
  auto [totalSize, lastDim] = GetDims(x);
  if (lastDim == 0) {
    return ffi::Error(ffi::ErrorCode::kInvalidArgument,
                      "RmsNorm input must be an array");
  }
  for (int64_t n = 0; n < totalSize; n += lastDim) {
    ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]), &(y->typed_data()[n]));
  }
  return ffi::Error::Success();
}

// Wrap `RmsNormImpl` and specify the interface to XLA. If you need to declare
// this handler in a header, you can use the `XLA_FFI_DECLARE_HANDLER_SYMBOL`
// macro: `XLA_FFI_DECLARE_HANDLER_SYMBOL(RmsNorm)`.
XLA_FFI_DEFINE_HANDLER_SYMBOL(
    RmsNorm, RmsNormImpl,
    ffi::Ffi::Bind()
        .Attr<float>("eps")
        .Arg<ffi::Buffer<ffi::DataType::F32>>()  // x
        .Ret<ffi::Buffer<ffi::DataType::F32>>()  // y
);

Starting at the bottom, we’re using the XLA-provided macro XLA_FFI_DEFINE_HANDLER_SYMBOL to generate some boilerplate which will expand into a function called RmsNorm with the appropriate signature. But, the important stuff here is all in the call to ffi::Ffi::Bind(), where we define the input and output types, and the types of any parameters.

Then, in RmsNormImpl, we accept ffi::Buffer arguments which include information about the buffer shape, and pointers to the underlying data. In this implementation, we treat all leading dimensions of the buffer as batch dimensions, and perform RMS normalization over the last axis. GetDims is a helper function providing support for this batching behavior. We discuss this batching behavior in more detail below, but the general idea is that it can be useful to transparently handle batching in the left-most dimensions of the input arguments. In this case, we treat all but the last axis as batch dimensions, but other foreign functions may require a different number of non-batch dimensions.

Building and registering an FFI handler#

Now that we have our minimal FFI wrapper implemented, we need to expose this function (RmsNorm) to Python. In this tutorial, we compile RmsNorm into a shared library and load it using ctypes, but another common pattern is to use nanobind or pybind11 as discussed below.

To compile the shared library, we’re using CMake here, but you should be able to use your favorite build system without too much trouble. The full CMakeLists.txt can be downloaded here.

!cmake -DCMAKE_BUILD_TYPE=Release -B ffi/_build ffi
!cmake --build ffi/_build
!cmake --install ffi/_build
-- The CXX compiler identification is GNU 11.4.0
-- Detecting CXX compiler ABI info
-- Detecting CXX compiler ABI info - done
-- Check for working CXX compiler: /usr/bin/c++ - skipped
-- Detecting CXX compile features
-- Detecting CXX compile features - done
-- Found Python: /home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/bin/python3.10 (found suitable version "3.10.14", minimum required is "3.8") found components: Interpreter Development.Module
-- XLA include directory: /home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jaxlib/include
-- Configuring done (1.0s)
-- Generating done (0.0s)
-- Build files have been written to: /home/docs/checkouts/readthedocs.org/user_builds/jax/checkouts/latest/docs/ffi/_build
[ 50%] Building CXX object CMakeFiles/rms_norm.dir/rms_norm.cc.o
In file included from /home/docs/checkouts/readthedocs.org/user_builds/jax/checkouts/latest/docs/ffi/rms_norm.cc:24:
/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jaxlib/include/xla/ffi/api/ffi.h:483:68: warning: ‘always_inline’ function might not be inlinable [-Wattributes]
  483 | _ATTRIBUTE_ALWAYS_INLINE std::optional<Buffer<dtype, rank>> DecodeBuffer(
      |                                                             ^~~~~~~~~~~~
In file included from /home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jaxlib/include/xla/ffi/api/ffi.h:45,
                 from /home/docs/checkouts/readthedocs.org/user_builds/jax/checkouts/latest/docs/ffi/rms_norm.cc:24:
/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jaxlib/include/xla/ffi/api/api.h: In function ‘std::ostream& operator<<(std::ostream&, XLA_FFI_ExecutionStage)’:
/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jaxlib/include/xla/ffi/api/api.h:170:1: warning: control reaches end of non-void function [-Wreturn-type]
  170 | }
      | ^
/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jaxlib/include/xla/ffi/api/api.h: In function ‘std::ostream& operator<<(std::ostream&, XLA_FFI_AttrType)’:
/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jaxlib/include/xla/ffi/api/api.h:156:1: warning: control reaches end of non-void function [-Wreturn-type]
  156 | }
      | ^
/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jaxlib/include/xla/ffi/api/api.h: In function ‘std::ostream& operator<<(std::ostream&, XLA_FFI_DataType)’:
/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jaxlib/include/xla/ffi/api/api.h:143:1: warning: control reaches end of non-void function [-Wreturn-type]
  143 | }
      | ^
In file included from /home/docs/checkouts/readthedocs.org/user_builds/jax/checkouts/latest/docs/ffi/rms_norm.cc:24:
/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jaxlib/include/xla/ffi/api/ffi.h: In function ‘std::ostream& xla::ffi::operator<<(std::ostream&, XLA_FFI_ArgType)’:
/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jaxlib/include/xla/ffi/api/ffi.h:551:1: warning: control reaches end of non-void function [-Wreturn-type]
  551 | }
      | ^
/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jaxlib/include/xla/ffi/api/ffi.h: In function ‘std::ostream& xla::ffi::operator<<(std::ostream&, XLA_FFI_RetType)’:
/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jaxlib/include/xla/ffi/api/ffi.h:626:1: warning: control reaches end of non-void function [-Wreturn-type]
  626 | }
      | ^
[100%] Linking CXX shared library librms_norm.so
[100%] Built target rms_norm
-- Install configuration: "Release"
-- Installing: /home/docs/checkouts/readthedocs.org/user_builds/jax/checkouts/latest/docs/ffi/librms_norm.so

With this compiled library in hand, we now need to register this handler with XLA via the register_ffi_target() function. This function expects our handler (a function pointer to the C++ function RmsNorm) to be wrapped in a PyCapsule. JAX provides a helper function pycapsule() to help with this:

import ctypes
from pathlib import Path
import jax.extend as jex

path = next(Path("ffi").glob("librms_norm*"))
rms_norm_lib = ctypes.cdll.LoadLibrary(path)
jex.ffi.register_ffi_target(
    "rms_norm", jex.ffi.pycapsule(rms_norm_lib.RmsNorm), platform="cpu")

Tip

If you’re familiar with the legacy “custom call” API, it’s worth noting that you can also use register_ffi_target() to register a custom call target by manually specifying the keyword argument api_version=0. The default api_version for register_ffi_target() is 1, the new “typed” FFI API that we’re using here.
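
For example, registering a legacy (api_version=0) custom call target might look something like the following sketch. Everything here is hypothetical: the library name liblegacy_calls.so and the symbol MyLegacyCall are placeholders for your own old-style handler.

# Hypothetical example: registering a legacy (api_version=0) custom call
# target named `MyLegacyCall` from an imagined shared library.
legacy_lib = ctypes.cdll.LoadLibrary("liblegacy_calls.so")
jex.ffi.register_ffi_target(
  "my_legacy_call",
  jex.ffi.pycapsule(legacy_lib.MyLegacyCall),
  platform="cpu",
  api_version=0,
)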

An alternative approach: Another common pattern for exposing handlers to Python is to use nanobind or pybind11 to define a tiny Python extension which can be imported. For our example here, the nanobind code would be:

#include <type_traits>

#include "nanobind/nanobind.h"
#include "xla/ffi/api/c_api.h"

namespace nb = nanobind;

template <typename T>
nb::capsule EncapsulateFfiCall(T *fn) {
  // This check is optional, but it can be helpful for avoiding invalid handlers.
  static_assert(std::is_invocable_r_v<XLA_FFI_Error *, T, XLA_FFI_CallFrame *>,
                "Encapsulated function must be and XLA FFI handler");
  return nb::capsule(reinterpret_cast<void *>(fn));
}

NB_MODULE(rms_norm, m) {
  m.def("rms_norm", []() { return EncapsulateFfiCall(RmsNorm); });
}

Then, in Python we can register this handler using:

# Assuming that we compiled a nanobind extension called `rms_norm`:
import rms_norm as rms_norm_lib

jex.ffi.register_ffi_target("rms_norm", rms_norm_lib.rms_norm(), platform="cpu")

Frontend code#

Now that we have registered our FFI handler, it is straightforward to call our C++ library from JAX using the ffi_call() function:

import numpy as np


def rms_norm(x, eps=1e-5):
  # We only implemented the `float32` version of this function, so we start by
  # checking the dtype. This check isn't strictly necessary because type
  # checking is also performed by the FFI when decoding input and output
  # buffers, but it can be useful to check types in Python to raise more
  # informative errors.
  if x.dtype != jnp.float32:
    raise ValueError("Only the float32 dtype is implemented by rms_norm")

  # In this case, the output of our FFI function is just a single array with the
  # same shape and dtype as the input. We discuss a case with a more interesting
  # output type below.
  out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)

  return jex.ffi.ffi_call(
    # The target name must be the same string as we used to register the target
    # above with `register_ffi_target`
    "rms_norm",
    out_type,
    x,
    # Note that here we're using `numpy` (not `jax.numpy`) to specify a dtype for
    # the attribute `eps`. Our FFI function expects this to have the C++ `float`
    # type (which corresponds to numpy's `float32` type), and it must be a
    # static parameter (i.e. not a JAX array).
    eps=np.float32(eps),
    # The `vmap_method` parameter controls this function's behavior under `vmap`
    # as discussed below.
    vmap_method="broadcast_fullrank",
  )


# Test that this gives the same result as our reference implementation
x = jnp.linspace(-0.5, 0.5, 15).reshape((3, 5))
np.testing.assert_allclose(rms_norm(x), rms_norm_ref(x), rtol=1e-5)

This code cell includes a lot of inline comments which should explain most of what is happening here, but there are a few points that are worth explicitly highlighting. Most of the heavy lifting here is done by the ffi_call() function, which tells JAX how to call the foreign function for a particular set of inputs. It’s important to note that the first argument to ffi_call() must be a string that matches the target name that we used when calling register_ffi_target() above.

Any attributes (defined using Attr in the C++ wrapper above) should be passed as keyword arguments to ffi_call(). Note that we explicitly cast eps to np.float32 because our FFI library expects a C float, and we can’t use jax.numpy here, because these parameters must be static arguments.

The vmap_method argument to ffi_call() defines how this FFI call interacts with vmap() as described next.

Tip

If you are familiar with the earlier “custom call” interface, you might be surprised that we’re not passing the problem dimensions (batch size, etc.) as parameters to ffi_call(). In this earlier API, the backend had no mechanism for receiving metadata about the input arrays, but since the FFI includes dimension information with the Buffer objects, we no longer need to compute this using Python when lowering. One major perk of this change is that ffi_call() can support some simple vmap() semantics out of the box, as discussed below.

Batching with vmap#

ffi_call() supports some simple vmap() semantics out of the box using the vmap_method parameter. The docs for pure_callback() provide more details about the vmap_method parameter, and the same behavior applies to ffi_call().

The simplest vmap_method is "sequential". In this case, when vmapped, an ffi_call will be rewritten as a scan() with the ffi_call in the body. This implementation is general purpose, but it doesn’t parallelize very well. Many FFI calls provide more efficient batching behavior and, in some simple cases, the "broadcast" or "broadcast_fullrank" methods can be used to expose a better implementation.

In this case, since we only have one input argument, "broadcast" and "broadcast_fullrank" actually have the same behavior. The specific assumption required to use these methods is that the foreign function knows how to handle batch dimensions. Another way of saying this is that the result of calling ffi_call on the batched inputs is assumed to be equal to stacking the repeated application of ffi_call to each element in the batched input, roughly:

ffi_call(xs) == jnp.stack([ffi_call(x) for x in xs])
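
We can check this equivalence numerically for our rms_norm example, using only the definitions from above:

# Sanity check: vmapping the FFI call matches stacking per-element calls
xs = jnp.linspace(-0.5, 0.5, 30).reshape((2, 3, 5))
np.testing.assert_allclose(
  jax.vmap(rms_norm)(xs),
  jnp.stack([rms_norm(x) for x in xs]),
  rtol=1e-5,
)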

Tip

Note that things get a bit more complicated when we have multiple input arguments. For simplicity, we will use the "broadcast_fullrank" method throughout this tutorial, which guarantees that all inputs will be broadcast to have the same batch dimensions, but it would also be possible to implement a foreign function to handle the "broadcast" method. The documentation for pure_callback() includes some examples of this.

Our implementation of rms_norm has the appropriate semantics, and it supports vmap with vmap_method="broadcast_fullrank" out of the box:

np.testing.assert_allclose(jax.vmap(rms_norm)(x), jax.vmap(rms_norm_ref)(x), rtol=1e-5)

We can inspect the jaxpr of the vmap() of rms_norm to confirm that it isn’t being rewritten using scan():

jax.make_jaxpr(jax.vmap(rms_norm))(x)
{ lambda ; a:f32[3,5]. let
    b:f32[3,5] = ffi_call[
      eps=1e-05
      has_side_effect=False
      result_avals=(ShapedArray(float32[3,5]),)
      target_name=rms_norm
      vectorized=Deprecated
      vmap_method=broadcast_fullrank
    ] a
  in (b,) }

Using vmap_method="sequential", vmapping an ffi_call will fall back on a jax.lax.scan() with the ffi_call in the body:

def rms_norm_sequential(x, eps=1e-5):
  return jex.ffi.ffi_call(
    "rms_norm",
    jax.ShapeDtypeStruct(x.shape, x.dtype),
    x,
    eps=np.float32(eps),
    vmap_method="sequential",
  )


jax.make_jaxpr(jax.vmap(rms_norm_sequential))(x)
{ lambda ; a:f32[3,5]. let
    b:f32[3,5] = scan[
      _split_transpose=False
      jaxpr={ lambda ; c:f32[5]. let
          d:f32[5] = ffi_call[
            eps=1e-05
            has_side_effect=False
            result_avals=(ShapedArray(float32[5]),)
            target_name=rms_norm
            vectorized=Deprecated
            vmap_method=sequential
          ] c
        in (d,) }
      length=3
      linear=(False,)
      num_carry=0
      num_consts=0
      reverse=False
      unroll=1
    ] a
  in (b,) }

If your foreign function provides an efficient batching rule that isn’t supported by this simple vmap_method parameter, it might also be possible to define more flexible custom vmap rules using the experimental custom_vmap interface, but it’s also worth opening an issue describing your use case on the JAX issue tracker.
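
As a rough sketch of what that might look like, using the experimental jax.custom_batching.custom_vmap interface: here the rule simply forwards the batched array to our FFI call (which already treats leading dimensions as batch dimensions), but a real rule could dispatch to a dedicated batched kernel instead.

from jax.custom_batching import custom_vmap


@custom_vmap
def rms_norm_custom_batched(x):
  return rms_norm(x)


@rms_norm_custom_batched.def_vmap
def _rms_norm_vmap_rule(axis_size, in_batched, x):
  del axis_size
  x_batched, = in_batched
  # Forward the batched input directly, and mark the output as batched
  # whenever the input was batched.
  return rms_norm(x), x_batched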

Differentiation#

Unlike with batching, ffi_call() doesn’t provide any default support for automatic differentiation (AD) of foreign functions. As far as JAX is concerned, the foreign function is a black box that can’t be inspected to determine the appropriate behavior when differentiated. Therefore, it is the ffi_call() user’s responsibility to define a custom derivative rule.

More details about custom derivative rules can be found in the custom derivatives tutorial, but the most common pattern used for implementing differentiation for foreign functions is to define a custom_vjp() which itself calls a foreign function. In this case, we actually define two new FFI calls:

  1. rms_norm_fwd returns two outputs: (a) the “primal” result, and (b) the “residuals” which are used in the backwards pass.

  2. rms_norm_bwd takes the residuals and the output co-tangents, and returns the input co-tangents.

We won’t get into the details of the RMS normalization backwards pass, but take a look at the C++ source code to see how these functions are implemented on the back end. The main point to emphasize here is that the “residual” computed has a different shape than the primal output; therefore, in the ffi_call() to rms_norm_fwd, the output type has two elements with different shapes.

This custom derivative rule can be wired in as follows:

jex.ffi.register_ffi_target(
  "rms_norm_fwd", jex.ffi.pycapsule(rms_norm_lib.RmsNormFwd), platform="cpu"
)
jex.ffi.register_ffi_target(
  "rms_norm_bwd", jex.ffi.pycapsule(rms_norm_lib.RmsNormBwd), platform="cpu"
)


def rms_norm_fwd(x, eps=1e-5):
  y, res = jex.ffi.ffi_call(
    "rms_norm_fwd",
    (
      jax.ShapeDtypeStruct(x.shape, x.dtype),
      jax.ShapeDtypeStruct(x.shape[:-1], x.dtype),
    ),
    x,
    eps=np.float32(eps),
    vmap_method="broadcast_fullrank",
  )
  return y, (res, x)


def rms_norm_bwd(eps, res, ct):
  del eps
  res, x = res
  assert res.shape == ct.shape[:-1]
  assert x.shape == ct.shape
  return (
    jex.ffi.ffi_call(
      "rms_norm_bwd",
      jax.ShapeDtypeStruct(ct.shape, ct.dtype),
      res,
      x,
      ct,
    vmap_method="broadcast_fullrank",
    ),
  )


rms_norm = jax.custom_vjp(rms_norm, nondiff_argnums=(1,))
rms_norm.defvjp(rms_norm_fwd, rms_norm_bwd)

# Check that this gives the right answer when compared to the reference version
ct_y = jnp.ones_like(x)
np.testing.assert_allclose(
  jax.vjp(rms_norm, x)[1](ct_y), jax.vjp(rms_norm_ref, x)[1](ct_y), rtol=1e-5
)

At this point, we can use our new rms_norm function transparently for many JAX applications, and it will transform appropriately under the standard JAX function transformations like vmap() and grad(). One thing that this example doesn’t support is forward-mode AD (jax.jvp(), for example) since custom_vjp() is restricted to reverse-mode. JAX doesn’t currently expose a public API for simultaneously customizing both forward-mode and reverse-mode AD, but such an API is on the roadmap, so please open an issue describing your use case if you hit this limitation in practice.

One other JAX feature that this example doesn’t support is higher-order AD. It would be possible to work around this by wrapping the rms_norm_bwd function above in a jax.custom_jvp() or jax.custom_vjp() decorator, but we won’t go into the details of that advanced use case here.

FFI calls on a GPU#

So far, we have been interfacing only with foreign functions running on the CPU, but JAX’s FFI also supports calls to GPU code. Since this documentation page is automatically generated on a machine without access to a GPU, we can’t execute any GPU-specific examples here, but we will go over the key points.

When defining our FFI wrapper for CPU, the function signature that we used was:

ffi::Error RmsNormImpl(float eps, ffi::Buffer<ffi::DataType::F32> x,
                       ffi::Result<ffi::Buffer<ffi::DataType::F32>> y)

To update this to interface with a CUDA kernel, this signature becomes:

ffi::Error RmsNormImpl(cudaStream_t stream, float eps,
                       ffi::Buffer<ffi::DataType::F32> x,
                       ffi::Result<ffi::Buffer<ffi::DataType::F32>> y)

And the handler definition is updated to include a Ctx in its binding:

XLA_FFI_DEFINE_HANDLER_SYMBOL(
    RmsNorm, RmsNormImpl,
    ffi::Ffi::Bind()
        .Ctx<ffi::PlatformStream<cudaStream_t>>()
        .Attr<float>("eps")
        .Arg<ffi::Buffer<ffi::DataType::F32>>()  // x
        .Ret<ffi::Buffer<ffi::DataType::F32>>()  // y
);

Then, the RmsNormImpl can use the CUDA stream to launch CUDA kernels.

On the front end, the registration code would be updated to specify the appropriate platform:

jex.ffi.register_ffi_target(
  "rms_norm_cuda", rms_norm_lib_cuda.rms_norm(), platform="CUDA"
)

Supporting multiple platforms#

To support running our rms_norm function on both GPU and CPU, we can combine our implementation above with the jax.lax.platform_dependent() function:

def rms_norm_cross_platform(x, eps=1e-5):
  assert x.dtype == jnp.float32
  out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)

  def impl(target_name):
    return lambda x: jex.ffi.ffi_call(
      target_name,
      out_type,
      x,
      eps=np.float32(eps),
      vmap_method="broadcast_fullrank",
    )

  return jax.lax.platform_dependent(x, cpu=impl("rms_norm"), cuda=impl("rms_norm_cuda"))


np.testing.assert_allclose(rms_norm_cross_platform(x), rms_norm_ref(x), rtol=1e-5)

This version of the function will call the appropriate FFI target depending on the runtime platform.

As an aside, it may be interesting to note that while the jaxpr and the lowered HLO both contain references to both FFI targets:

jax.make_jaxpr(rms_norm_cross_platform)(x)
{ lambda ; a:f32[3,5]. let
    b:i32[] = platform_index[has_default=False platforms=(('cpu',), ('cuda',))] 
    c:i32[] = clamp 0 b 1
    d:f32[3,5] = cond[
      branches=(
        { lambda ; e:f32[3,5]. let
            f:f32[3,5] = ffi_call[
              eps=1e-05
              has_side_effect=False
              result_avals=(ShapedArray(float32[3,5]),)
              target_name=rms_norm
              vectorized=Deprecated
              vmap_method=broadcast_fullrank
            ] e
          in (f,) }
        { lambda ; g:f32[3,5]. let
            h:f32[3,5] = ffi_call[
              eps=1e-05
              has_side_effect=False
              result_avals=(ShapedArray(float32[3,5]),)
              target_name=rms_norm_cuda
              vectorized=Deprecated
              vmap_method=broadcast_fullrank
            ] g
          in (h,) }
      )
    ] c a
  in (d,) }
print(jax.jit(rms_norm_cross_platform).lower(x).as_text().strip())
module @jit_rms_norm_cross_platform attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<3x5xf32> {mhlo.layout_mode = "default"}) -> (tensor<3x5xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %c = stablehlo.constant dense<0> : tensor<i32>
    %c_0 = stablehlo.constant dense<0> : tensor<i32>
    %c_1 = stablehlo.constant dense<1> : tensor<i32>
    %0 = stablehlo.clamp %c_0, %c, %c_1 : tensor<i32>
    %1 = "stablehlo.case"(%0) ({
      %2 = stablehlo.custom_call @rms_norm(%arg0) {backend_config = "", mhlo.backend_config = {eps = 9.99999974E-6 : f32}, operand_layouts = [dense<[1, 0]> : tensor<2xindex>], result_layouts = [dense<[1, 0]> : tensor<2xindex>]} : (tensor<3x5xf32>) -> tensor<3x5xf32>
      stablehlo.return %2 : tensor<3x5xf32>
    }, {
      %2 = stablehlo.custom_call @rms_norm_cuda(%arg0) {backend_config = "", mhlo.backend_config = {eps = 9.99999974E-6 : f32}, operand_layouts = [dense<[1, 0]> : tensor<2xindex>], result_layouts = [dense<[1, 0]> : tensor<2xindex>]} : (tensor<3x5xf32>) -> tensor<3x5xf32>
      stablehlo.return %2 : tensor<3x5xf32>
    }) : (tensor<i32>) -> tensor<3x5xf32>
    return %1 : tensor<3x5xf32>
  }
}

by the time the function is compiled, the appropriate FFI target has been selected:

print(jax.jit(rms_norm_cross_platform).lower(x).as_text(dialect="hlo").strip())
HloModule jit_rms_norm_cross_platform, entry_computation_layout={(f32[3,5]{1,0})->f32[3,5]{1,0}}, frontend_attributes={xla.sdy.meshes={}}

ENTRY main.3 {
  Arg_0.1 = f32[3,5]{1,0} parameter(0)
  ROOT custom-call.2 = f32[3,5]{1,0} custom-call(Arg_0.1), custom_call_target="rms_norm", operand_layout_constraints={f32[3,5]{1,0}}, api_version=API_VERSION_TYPED_FFI
}

and there will be no runtime overhead to using jax.lax.platform_dependent(), and the compiled program won’t include any references to unavailable FFI targets.

Advanced topics#

This tutorial covers most of the basic steps that are required to get up and running with JAX’s FFI, but advanced use cases may require more features. We will leave these topics to future tutorials, but here are some possibly useful references:

  • Supporting multiple dtypes: In this tutorial’s example, we restricted our implementation to support only float32 inputs and outputs, but many use cases require supporting multiple different input types. One option to handle this is to register different FFI targets for all supported input types and then use Python to select the appropriate target for jax.extend.ffi.ffi_call() depending on the input types (see the sketch after this list). But, this approach could quickly get unwieldy depending on the combinatorics of the supported cases. It is also possible to define the C++ handler to accept ffi::AnyBuffer instead of ffi::Buffer<Dtype>. Then, the input buffer will include an element_type() method which can be used to define the appropriate dtype dispatching logic in the backend.

  • Sharding: When using JAX’s automatic data-dependent parallelism within jit(), FFI calls implemented using ffi_call() don’t have sufficient information to shard appropriately, so they result in the inputs being copied to all devices, with the FFI call executed on the full array on each device. To get around this limitation, you can use shard_map() or custom_partitioning().

  • Stateful foreign functions: It is also possible to use the FFI to wrap functions with associated state. There is a low-level example included in the XLA test suite, and a future tutorial will include more details.
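
As a rough sketch of the first option above (dtype-based dispatch in Python), and assuming hypothetical rms_norm_f32 and rms_norm_f64 targets have been registered, the frontend might look like:

def rms_norm_multi_dtype(x, eps=1e-5):
  # Hypothetical targets, one registered per supported input dtype
  if x.dtype == jnp.float32:
    target_name = "rms_norm_f32"
  elif x.dtype == jnp.float64:
    target_name = "rms_norm_f64"
  else:
    raise ValueError(f"Unsupported dtype {x.dtype}")
  return jex.ffi.ffi_call(
    target_name,
    jax.ShapeDtypeStruct(x.shape, x.dtype),
    x,
    eps=np.float32(eps),
    vmap_method="broadcast_fullrank",
  )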