外部函数接口 (FFI)#
本教程需要 JAX v0.4.31 或更高版本。
虽然使用 JAX 内置的 jax.numpy
和 jax.lax
接口可以轻松高效地实现各种数值操作,但有时通过“外部函数接口”(FFI) 显式调用外部编译库会很有用。当特定操作已在优化的 C 或 CUDA 库中实现,并且直接使用 JAX 重新实现这些计算并非易事时,这尤其有用;同时,它也可能有助于优化 JAX 程序的运行时或内存性能。尽管如此,FFI 通常应被视为最后的选择,因为作为后端存在的 XLA 编译器或提供更低级别控制的 Pallas 内核语言,通常能以更低的开发和维护成本生成高性能代码。
在考虑使用 FFI 时,需要注意的一点是,JAX 不会自动知道如何对外部函数进行微分。这意味着,如果想在外部函数旁边使用 JAX 的自动微分功能,还需要提供相关微分规则的实现。我们将在下面讨论一些可能的方法,但从一开始就指出这个限制非常重要!
JAX 的 FFI 支持分为两部分提供:
一个来自 XLA 的仅头文件的 C++ 库,自 v0.4.29 起作为 JAX 的一部分打包,或可从 openxla/xla 项目获取,以及
一个 Python 前端,可在
jax.ffi
子模块中使用。
在本教程中,我们将通过一个简单示例演示这两个组件的使用,然后讨论针对更复杂用例的一些低级扩展。我们首先介绍 CPU 上的 FFI,然后在下面讨论如何推广到 GPU 或多设备环境。
本示例以及其他一些更高级用例的端到端代码可在 GitHub 上的 JAX FFI 示例项目中找到,路径为 JAX 仓库中的 examples/ffi
。
由于本教程末尾将演示 FFI 调用如何分片,因此我们首先将环境设置为 JAX 视作拥有多个 CPU。
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"
一个简单示例#
为了演示 FFI 接口的使用,我们将实现一个简单的“均方根 (RMS)”归一化函数。RMS 归一化接收一个形状为 \((N,)\) 的数组 \(x\) 并返回
其中 \(\epsilon\) 是一个用于数值稳定性的调优参数。
这是一个有点简单的例子,因为它可以使用 JAX 轻松实现,如下所示:
import jax
import jax.numpy as jnp
def rms_norm_ref(x, eps=1e-5):
scale = jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + eps)
return x / scale
但是,它又足够重要,可以用于演示 FFI 的一些关键细节,同时又易于理解。我们将在下面使用此参考实现来测试我们的 FFI 版本。
后端代码#
首先,我们需要一个 C++ 版本的 RMS 归一化实现,我们将通过 FFI 暴露它。这并非旨在实现特别高的性能,但你可以想象,如果你在 C++ 库中有一个新的更好的 RMS 归一化实现,它可能会有一个类似于以下的接口。因此,这里是一个简单的 C++ 版 RMS 归一化实现:
#include <cmath>
#include <cstdint>
float ComputeRmsNorm(float eps, int64_t size, const float *x, float *y) {
float sm = 0.0f;
for (int64_t n = 0; n < size; ++n) {
sm += x[n] * x[n];
}
float scale = 1.0f / std::sqrt(sm / float(size) + eps);
for (int64_t n = 0; n < size; ++n) {
y[n] = x[n] * scale;
}
return scale;
}
对于我们的示例,这就是我们希望通过 FFI 暴露给 JAX 的函数。
C++ 接口#
为了将我们的库函数暴露给 JAX 和 XLA,我们需要使用 XLA 项目的 xla/ffi/api
目录中仅头文件的库提供的 API 编写一个薄封装器。有关此接口的更多信息,请参阅 XLA 自定义调用文档。完整的源代码列表可从此处下载,但关键实现细节在此处重现:
#include <functional>
#include <numeric>
#include <utility>
#include "xla/ffi/api/c_api.h"
#include "xla/ffi/api/ffi.h"
namespace ffi = xla::ffi;
// A helper function for extracting the relevant dimensions from `ffi::Buffer`s.
// In this example, we treat all leading dimensions as batch dimensions, so this
// function returns the total number of elements in the buffer, and the size of
// the last dimension.
template <ffi::DataType T>
std::pair<int64_t, int64_t> GetDims(const ffi::Buffer<T> &buffer) {
auto dims = buffer.dimensions();
if (dims.size() == 0) {
return std::make_pair(0, 0);
}
return std::make_pair(buffer.element_count(), dims.back());
}
// A wrapper function providing the interface between the XLA FFI call and our
// library function `ComputeRmsNorm` above. This function handles the batch
// dimensions by calling `ComputeRmsNorm` within a loop.
ffi::Error RmsNormImpl(float eps, ffi::Buffer<ffi::F32> x,
ffi::ResultBuffer<ffi::F32> y) {
auto [totalSize, lastDim] = GetDims(x);
if (lastDim == 0) {
return ffi::Error::InvalidArgument("RmsNorm input must be an array");
}
for (int64_t n = 0; n < totalSize; n += lastDim) {
ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]), &(y->typed_data()[n]));
}
return ffi::Error::Success();
}
// Wrap `RmsNormImpl` and specify the interface to XLA. If you need to declare
// this handler in a header, you can use the `XLA_FFI_DECLARE_HANDLER_SYMBOL`
// macro: `XLA_FFI_DECLARE_HANDLER_SYMBOL(RmsNorm)`.
XLA_FFI_DEFINE_HANDLER_SYMBOL(
RmsNorm, RmsNormImpl,
ffi::Ffi::Bind()
.Attr<float>("eps")
.Arg<ffi::Buffer<ffi::F32>>() // x
.Ret<ffi::Buffer<ffi::F32>>() // y
);
从底部开始,我们使用 XLA 提供的宏 XLA_FFI_DEFINE_HANDLER_SYMBOL
来生成一些样板代码,这些代码将扩展为一个名为 RmsNorm
且具有适当签名的函数。但是,这里的重点都在于对 ffi::Ffi::Bind()
的调用,我们在其中定义了输入和输出类型以及任何参数的类型。
然后,在 RmsNormImpl
中,我们接受 ffi::Buffer
参数,这些参数包含缓冲区形状的信息以及指向底层数据的指针。在此实现中,我们将缓冲区的所有前导维度视为批处理维度,并对最后一个轴执行 RMS 归一化。GetDims
是一个辅助函数,提供对此批处理行为的支持。我们将在下面更详细地讨论此批处理行为,但其总体思想是,透明地处理输入参数最左侧维度中的批处理会很有用。在本例中,除了最后一个轴,我们将其余所有轴都视为批处理维度,但其他外部函数可能需要不同数量的非批处理维度。
构建并注册 FFI 处理器#
现在我们已经实现了最小的 FFI 封装器,我们需要将此函数 (RmsNorm
) 暴露给 Python。在本教程中,我们将 RmsNorm
编译为共享库并使用 ctypes 加载它,但另一种常见模式是使用 nanobind 或 pybind11,如下所述。
为了编译共享库,我们在这里使用 CMake,但您应该能够使用您喜欢的构建系统,而不会遇到太多麻烦。
!cmake -DCMAKE_BUILD_TYPE=Release -B ffi/_build ffi
!cmake --build ffi/_build
!cmake --install ffi/_build
有了这个编译好的库,我们现在需要通过 register_ffi_target()
函数向 XLA 注册此处理器。此函数期望我们的处理器(指向 C++ 函数 RmsNorm
的函数指针)被包装在一个 PyCapsule
中。JAX 提供了一个辅助函数 pycapsule()
来帮助实现此操作:
import ctypes
from pathlib import Path
path = next(Path("ffi").glob("librms_norm*"))
rms_norm_lib = ctypes.cdll.LoadLibrary(path)
jax.ffi.register_ffi_target(
"rms_norm", jax.ffi.pycapsule(rms_norm_lib.RmsNorm), platform="cpu")
提示
如果您熟悉旧的“自定义调用”API,值得注意的是,您也可以通过手动指定关键字参数 api_version=0
来使用 register_ffi_target()
注册自定义调用目标。 register_ffi_target()
的默认 api_version
是 1
,即我们在此处使用的新“类型化”FFI API。
另一种方法:将处理器暴露给 Python 的常见替代模式是使用 nanobind 或 pybind11 定义一个可以导入的小型 Python 扩展。对于我们这里的示例,nanobind 代码将是:
#include <type_traits>
#include "nanobind/nanobind.h"
#include "xla/ffi/api/c_api.h"
namespace nb = nanobind;
template <typename T>
nb::capsule EncapsulateFfiCall(T *fn) {
// This check is optional, but it can be helpful for avoiding invalid handlers.
static_assert(std::is_invocable_r_v<XLA_FFI_Error *, T, XLA_FFI_CallFrame *>,
"Encapsulated function must be and XLA FFI handler");
return nb::capsule(reinterpret_cast<void *>(fn));
}
NB_MODULE(rms_norm, m) {
m.def("rms_norm", []() { return EncapsulateFfiCall(RmsNorm); });
}
然后,在 Python 中我们可以使用以下方式注册此处理器:
# Assuming that we compiled a nanobind extension called `rms_norm`:
import rms_norm as rms_norm_lib
jax.ffi.register_ffi_target("rms_norm", rms_norm_lib.rms_norm(), platform="cpu")
前端代码#
现在我们已经注册了 FFI 处理器,使用 ffi_call()
函数从 JAX 调用我们的 C++ 库就变得简单了。
import numpy as np
def rms_norm(x, eps=1e-5):
# We only implemented the `float32` version of this function, so we start by
# checking the dtype. This check isn't strictly necessary because type
# checking is also performed by the FFI when decoding input and output
# buffers, but it can be useful to check types in Python to raise more
# informative errors.
if x.dtype != jnp.float32:
raise ValueError("Only the float32 dtype is implemented by rms_norm")
call = jax.ffi.ffi_call(
# The target name must be the same string as we used to register the target
# above in `register_custom_call_target`
"rms_norm",
# In this case, the output of our FFI function is just a single array with
# the same shape and dtype as the input. We discuss a case with a more
# interesting output type below.
jax.ShapeDtypeStruct(x.shape, x.dtype),
# The `vmap_method` parameter controls this function's behavior under `vmap`
# as discussed below.
vmap_method="broadcast_all",
)
# Note that here we're use `numpy` (not `jax.numpy`) to specify a dtype for
# the attribute `eps`. Our FFI function expects this to have the C++ `float`
# type (which corresponds to numpy's `float32` type), and it must be a
# static parameter (i.e. not a JAX array).
return call(x, eps=np.float32(eps))
# Test that this gives the same result as our reference implementation
x = jnp.linspace(-0.5, 0.5, 32).reshape((8, 4))
np.testing.assert_allclose(rms_norm(x), rms_norm_ref(x), rtol=1e-5)
此代码单元包含许多行内注释,它们应该解释了此处发生的大部分情况,但有几点值得明确强调。这里的大部分繁重工作都是由 ffi_call()
函数完成的,它告诉 JAX 如何为特定输入集调用外部函数。重要的是要注意,ffi_call()
的第一个参数必须是一个字符串,它与我们上面调用 register_custom_call_target
时使用的目标名称匹配。
任何属性(在上面的 C++ 封装器中使用 Attr
定义的)都应作为关键字参数传递给 ffi_call()
。请注意,我们显式地将 eps
转换为 np.float32
,因为我们的 FFI 库期望一个 C float
,并且我们不能在这里使用 jax.numpy
,因为这些参数必须是静态参数。
传递给 ffi_call()
的 vmap_method
参数定义了此 FFI 调用如何与 vmap()
交互,如下所述。
提示
如果您熟悉早期的“自定义调用”接口,您可能会惊讶于我们没有将问题维度作为参数(批大小等)传递给 ffi_call()
。在早期 API 中,后端没有接收有关输入数组元数据的机制,但由于 FFI 包含 Buffer
对象中的维度信息,因此在降级时我们不再需要使用 Python 计算此信息。此更改的一个主要优点是 ffi_call()
可以开箱即用地支持一些简单的 vmap()
语义,如下所述。
使用 vmap
进行批处理#
ffi_call()
通过 vmap_method
参数开箱即用地支持一些简单的 vmap()
语义。pure_callback()
的文档提供了关于 vmap_method
参数的更多细节,并且相同的行为也适用于 ffi_call()
。
最简单的 vmap_method
是 "sequential"
。在这种情况下,当使用 vmap
时,一个 ffi_call
将被重写为一个以 ffi_call
作为主体的 scan()
。此实现是通用的,但并行化效果不佳。许多 FFI 调用提供了更高效的批处理行为,在一些简单的情况下,可以使用 "expand_dims"
或 "broadcast_all"
方法来暴露更好的实现。
在这种情况下,由于我们只有一个输入参数,"expand_dims"
和 "broadcast_all"
实际上行为相同。使用这些方法所需的具体假设是,外部函数知道如何处理批处理维度。另一种说法是,对批处理输入调用 ffi_call
的结果被假定等于将 ffi_call
对批处理输入中每个元素的重复应用进行堆叠,大致如下:
ffi_call(xs) == jnp.stack([ffi_call(x) for x in xs])
提示
请注意,当有多个输入参数时,情况会变得稍微复杂。为简单起见,本教程中我们将全程使用 "broadcast_all"
,它保证所有输入都将被广播以具有相同的批处理维度,但也可以实现一个外部函数来处理 "expand_dims"
方法。pure_callback()
的文档中包含一些这方面的示例:
我们实现的 rms_norm
具有适当的语义,并且开箱即用地支持 vmap
,使用 vmap_method="broadcast_all"
。
np.testing.assert_allclose(jax.vmap(rms_norm)(x), jax.vmap(rms_norm_ref)(x), rtol=1e-5)
我们可以检查 rms_norm
的 vmap()
的 jaxpr,以确认它没有使用 scan()
重写。
jax.make_jaxpr(jax.vmap(rms_norm))(x)
{ lambda ; a:f32[8,4]. let
b:f32[8,4] = ffi_call[
attributes=(('eps', np.float32(1e-05)),)
custom_call_api_version=4
has_side_effect=False
input_layouts=((1, 0),)
input_output_aliases=()
legacy_backend_config=None
output_layouts=((1, 0),)
result_avals=(ShapedArray(float32[8,4]),)
target_name=rms_norm
vmap_method=broadcast_all
] a
in (b,) }
使用 vmap_method="sequential"
时,对 ffi_call
进行 vmap
操作将回退到使用 ffi_call
作为其主体的 jax.lax.scan()
。
def rms_norm_sequential(x, eps=1e-5):
return jax.ffi.ffi_call(
"rms_norm",
jax.ShapeDtypeStruct(x.shape, x.dtype),
vmap_method="sequential",
)(x, eps=np.float32(eps))
jax.make_jaxpr(jax.vmap(rms_norm_sequential))(x)
{ lambda ; a:f32[8,4]. let
b:f32[8,4] = scan[
_split_transpose=False
jaxpr={ lambda ; c:f32[4]. let
d:f32[4] = ffi_call[
attributes=(('eps', np.float32(1e-05)),)
custom_call_api_version=4
has_side_effect=False
input_layouts=((0,),)
input_output_aliases=()
legacy_backend_config=None
output_layouts=((0,),)
result_avals=(ShapedArray(float32[4]),)
target_name=rms_norm
vmap_method=sequential
] c
in (d,) }
length=8
linear=(False,)
num_carry=0
num_consts=0
reverse=False
unroll=1
] a
in (b,) }
如果您的外部函数提供了一种此简单 vmap_method
参数不支持的高效批处理规则,那么也可以使用实验性的 custom_vmap
接口定义更灵活的自定义 vmap
规则,但同时值得在 JAX 问题跟踪器上提出一个问题来描述您的用例。
微分#
与批处理不同,ffi_call()
不为外部函数的自动微分 (AD) 提供任何默认支持。就 JAX 而言,外部函数是一个黑盒,无法检查以确定微分时的适当行为。因此,定义自定义导数规则是 ffi_call()
用户的责任。
有关自定义导数规则的更多详细信息,请参阅自定义导数教程,但实现外部函数微分最常用的模式是定义一个 custom_vjp()
,它本身会调用一个外部函数。在这种情况下,我们实际上定义了两个新的 FFI 调用:
rms_norm_fwd
返回两个输出:(a) “原始”结果,以及 (b) 在反向传播中使用的“残差”。rms_norm_bwd
接收残差和输出协切线,并返回输入协切线。
我们不会深入探讨 RMS 归一化反向传播的细节,但请查看 C++ 源代码,了解这些函数在后端是如何实现的。这里要强调的重点是,计算出的“残差”与原始输出的形状不同,因此,在对 res_norm_fwd
的 ffi_call()
调用中,输出类型包含两个形状不同的元素。
此自定义导数规则可以按如下方式引入:
jax.ffi.register_ffi_target(
"rms_norm_fwd", jax.ffi.pycapsule(rms_norm_lib.RmsNormFwd), platform="cpu"
)
jax.ffi.register_ffi_target(
"rms_norm_bwd", jax.ffi.pycapsule(rms_norm_lib.RmsNormBwd), platform="cpu"
)
def rms_norm_fwd(x, eps=1e-5):
y, res = jax.ffi.ffi_call(
"rms_norm_fwd",
(
jax.ShapeDtypeStruct(x.shape, x.dtype),
jax.ShapeDtypeStruct(x.shape[:-1], x.dtype),
),
vmap_method="broadcast_all",
)(x, eps=np.float32(eps))
return y, (res, x)
def rms_norm_bwd(eps, res, ct):
del eps
res, x = res
assert res.shape == ct.shape[:-1]
assert x.shape == ct.shape
return (
jax.ffi.ffi_call(
"rms_norm_bwd",
jax.ShapeDtypeStruct(ct.shape, ct.dtype),
vmap_method="broadcast_all",
)(res, x, ct),
)
rms_norm = jax.custom_vjp(rms_norm, nondiff_argnums=(1,))
rms_norm.defvjp(rms_norm_fwd, rms_norm_bwd)
# Check that this gives the right answer when compared to the reference version
ct_y = jnp.ones_like(x)
np.testing.assert_allclose(
jax.vjp(rms_norm, x)[1](ct_y), jax.vjp(rms_norm_ref, x)[1](ct_y), rtol=1e-5
)
此时,我们可以透明地将我们新的 rms_norm
函数用于许多 JAX 应用程序,并且它将在标准 JAX 函数变换(如 vmap()
和 grad()
)下进行适当的变换。本示例不支持的一个功能是前向模式 AD(例如 jax.jvp()
),因为 custom_vjp()
仅限于反向模式。JAX 目前没有公开用于同时自定义前向模式和反向模式 AD 的公共 API,但此类 API 已在规划中,因此如果您在实践中遇到此限制,请在 JAX 问题跟踪器上提出问题,描述您的用例。
本示例不支持的另一个 JAX 功能是高阶 AD。可以通过将上面的 res_norm_bwd
函数包装在 jax.custom_jvp()
或 jax.custom_vjp()
装饰器中来解决此问题,但我们在此处不讨论该高级用例的细节。
在 GPU 上进行 FFI 调用#
到目前为止,我们只与在 CPU 上运行的外部函数进行接口,但 JAX 的 FFI 也支持调用 GPU 代码。由于此文档页面是在无法访问 GPU 的机器上自动生成的,因此我们无法在此处执行任何 GPU 特定示例,但我们将概述关键点。
在为 CPU 定义 FFI 封装器时,我们使用的函数签名是:
ffi::Error RmsNormImpl(float eps, ffi::Buffer<ffi::F32> x,
ffi::ResultBuffer<ffi::F32> y)
要将其更新为与 CUDA 内核接口,此签名变为:
ffi::Error RmsNormImpl(cudaStream_t stream, float eps,
ffi::Buffer<ffi::F32> x,
ffi::ResultBuffer<ffi::F32> y)
并且处理器定义已更新,在其绑定中包含一个 Ctx
。
XLA_FFI_DEFINE_HANDLER(
RmsNorm, RmsNormImpl,
ffi::Ffi::Bind()
.Ctx<ffi::PlatformStream<cudaStream_t>>()
.Attr<float>("eps")
.Arg<ffi::Buffer<ffi::F32>>() // x
.Ret<ffi::Buffer<ffi::F32>>() // y
);
然后,RmsNormImpl
可以使用 CUDA 流来启动 CUDA 内核。
在前端,注册代码将更新以指定适当的平台:
jax.ffi.register_ffi_target(
"rms_norm_cuda", rms_norm_lib_cuda.rms_norm(), platform="CUDA"
)
支持多平台#
为了支持在 GPU 和 CPU 上运行我们的 rms_norm
函数,我们可以将上述实现与 jax.lax.platform_dependent()
函数结合起来。
def rms_norm_cross_platform(x, eps=1e-5):
assert x.dtype == jnp.float32
out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)
def impl(target_name):
return lambda x: jax.ffi.ffi_call(
target_name,
out_type,
vmap_method="broadcast_all",
)(x, eps=np.float32(eps))
return jax.lax.platform_dependent(x, cpu=impl("rms_norm"), cuda=impl("rms_norm_cuda"))
np.testing.assert_allclose(rms_norm_cross_platform(x), rms_norm_ref(x), rtol=1e-5)
此版本的函数将根据运行时平台调用相应的 FFI 目标。
顺便提一下,值得注意的是,虽然 jaxpr 和降级后的 HLO 都包含对两个 FFI 目标的引用
jax.make_jaxpr(rms_norm_cross_platform)(x)
{ lambda ; a:f32[8,4]. let
b:i32[] = platform_index[platforms=(('cpu',), ('cuda',))]
c:f32[8,4] = cond[
branches=(
{ lambda ; d:f32[8,4]. let
e:f32[8,4] = ffi_call[
attributes=(('eps', np.float32(1e-05)),)
custom_call_api_version=4
has_side_effect=False
input_layouts=((1, 0),)
input_output_aliases=()
legacy_backend_config=None
output_layouts=((1, 0),)
result_avals=(ShapedArray(float32[8,4]),)
target_name=rms_norm
vmap_method=broadcast_all
] d
in (e,) }
{ lambda ; f:f32[8,4]. let
g:f32[8,4] = ffi_call[
attributes=(('eps', np.float32(1e-05)),)
custom_call_api_version=4
has_side_effect=False
input_layouts=((1, 0),)
input_output_aliases=()
legacy_backend_config=None
output_layouts=((1, 0),)
result_avals=(ShapedArray(float32[8,4]),)
target_name=rms_norm_cuda
vmap_method=broadcast_all
] f
in (g,) }
)
branches_platforms=(('cpu',), ('cuda',))
] b a
in (c,) }
print(jax.jit(rms_norm_cross_platform).lower(x).as_text().strip())
module @jit_rms_norm_cross_platform attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<8x4xf32>) -> (tensor<8x4xf32> {jax.result_info = "result"}) {
%c = stablehlo.constant dense<0> : tensor<i32>
%0 = "stablehlo.case"(%c) ({
%c_0 = stablehlo.constant dense<0> : tensor<i32>
stablehlo.return %c_0 : tensor<i32>
}, {
%c_0 = stablehlo.constant dense<0> : tensor<i32>
stablehlo.return %c_0 : tensor<i32>
}) : (tensor<i32>) -> tensor<i32>
%1 = "stablehlo.case"(%0) ({
%2 = stablehlo.custom_call @rms_norm(%arg0) {backend_config = "", mhlo.backend_config = {eps = 9.99999974E-6 : f32}, operand_layouts = [dense<[1, 0]> : tensor<2xindex>], result_layouts = [dense<[1, 0]> : tensor<2xindex>]} : (tensor<8x4xf32>) -> tensor<8x4xf32>
stablehlo.return %2 : tensor<8x4xf32>
}) : (tensor<i32>) -> tensor<8x4xf32>
return %1 : tensor<8x4xf32>
}
}
但在函数编译时,已选择了适当的 FFI
print(jax.jit(rms_norm_cross_platform).lower(x).as_text(dialect="hlo").strip())
HloModule jit_rms_norm_cross_platform, entry_computation_layout={(f32[8,4]{1,0})->f32[8,4]{1,0}}
ENTRY main.3 {
Arg_0.1 = f32[8,4]{1,0} parameter(0)
ROOT custom-call.2 = f32[8,4]{1,0} custom-call(Arg_0.1), custom_call_target="rms_norm", operand_layout_constraints={f32[8,4]{1,0}}, api_version=API_VERSION_TYPED_FFI
}
并且使用 jax.lax.platform_dependent()
不会有运行时开销,编译后的程序也不会包含任何对不可用 FFI 目标的引用。
高级主题#
本教程涵盖了开始使用 JAX FFI 所需的大部分基本步骤,但高级用例可能需要更多功能。我们将把这些主题留待未来的教程中讨论,但这里有一些可能有用的参考资料:
支持多种数据类型:在本教程的示例中,我们仅支持
float32
输入和输出,但许多用例需要支持多种不同的输入类型。处理此问题的一种选择是为所有支持的输入类型注册不同的 FFI 目标,然后根据输入类型使用 Python 为jax.ffi.ffi_call()
选择适当的目标。但是,根据支持情况的组合,这种方法可能会很快变得难以管理。因此,也可以定义 C++ 处理器以接受ffi::AnyBuffer
而不是ffi::Buffer<Dtype>
。然后,输入缓冲区将包含一个element_type()
方法,该方法可用于在后端定义适当的数据类型分派逻辑。有状态外部函数:还可以使用 FFI 封装带有相关状态的函数。XLA 测试套件中包含一个低级示例,未来的教程将包含更多详细信息。