外部函数接口 (FFI)#
本教程需要 JAX v0.4.31 或更高版本。
虽然使用 JAX 内置的 jax.numpy
和 jax.lax
接口可以轻松高效地实现各种数值运算,但有时通过“外部函数接口”(FFI) 显式调用外部编译库可能会很有用。当某些运算先前已在优化的 C 或 CUDA 库中实现,并且直接使用 JAX 重新实现这些计算并不容易时,这尤其有用,但对于优化 JAX 程序的运行时或内存性能也很有用。话虽如此,FFI 通常应被视为最后的选择,因为后端使用的 XLA 编译器,或者提供更底层控制的 Pallas 内核语言,通常能以更低的开发和维护成本生成高性能代码。
在考虑使用 FFI 时,需要注意的一点是,JAX 不会自动知道如何对外部函数进行微分。这意味着,如果您想在 JAX 自动微分功能旁边使用外部函数,您还需要提供相关微分规则的实现。我们将在下面讨论一些可能的方法,但从一开始就指出这一限制非常重要!
JAX 的 FFI 支持分两个部分提供
来自 XLA 的一个仅头文件 C++ 库,该库自 v0.4.29 起已包含在 JAX 中,或者可以从 openxla/xla 项目获取,以及
一个 Python 前端,可在
jax.ffi
子模块中找到。
在本教程中,我们将通过一个简单的示例演示这两个组件的用法,然后讨论一些用于更复杂用例的底层扩展。我们从在 CPU 上演示 FFI 开始,并在下面讨论向 GPU 或多设备环境的泛化。
本示例以及其他一些更高级用例的端到端代码可在 GitHub 上的 JAX FFI 示例项目 examples/ffi
中找到。
由于我们将在本教程的末尾演示 FFI 调用如何分片,因此让我们先设置我们的环境,使其被 JAX 视为拥有多个 CPU。
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"
一个简单的例子#
为了演示 FFI 接口的用法,我们将实现一个简单的“均方根 (RMS)”归一化函数。RMS 归一化函数接收一个形状为 \((N,)\) 的数组 \(x\),并返回
其中 \(\epsilon\) 是用于数值稳定性的调整参数。
这是一个有点愚蠢的例子,因为使用 JAX 可以轻松地实现它,如下所示
import jax
import jax.numpy as jnp
def rms_norm_ref(x, eps=1e-5):
scale = jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + eps)
return x / scale
但是,它足够不平凡,可以用于演示 FFI 的一些关键细节,同时仍然易于理解。我们将在下面使用此参考实现来测试我们的 FFI 版本。
后端代码#
首先,我们需要一个 C++ 中的 RMS 归一化实现,我们将通过 FFI 暴露它。这并不是为了特别追求高性能,但您可以想象,如果您在 C++ 库中有一个新的更好的 RMS 归一化实现,它可能会有如下的接口。所以,这是一个简单的 C++ RMS 归一化实现
#include <cmath>
#include <cstdint>
float ComputeRmsNorm(float eps, int64_t size, const float *x, float *y) {
float sm = 0.0f;
for (int64_t n = 0; n < size; ++n) {
sm += x[n] * x[n];
}
float scale = 1.0f / std::sqrt(sm / float(size) + eps);
for (int64_t n = 0; n < size; ++n) {
y[n] = x[n] * scale;
}
return scale;
}
对于我们的示例,这就是我们想通过 FFI 暴露给 JAX 的函数。
C++ 接口#
为了将我们的库函数暴露给 JAX 和 XLA,我们需要使用 xla/ffi/api
目录中 XLA 项目提供的仅头文件库编写一个薄包装器。有关此接口的更多信息,请查看 XLA custom call 文档。完整的源代码可以 在此处 下载,但关键实现细节在此处重现
#include <functional>
#include <numeric>
#include <utility>
#include "xla/ffi/api/c_api.h"
#include "xla/ffi/api/ffi.h"
namespace ffi = xla::ffi;
// A helper function for extracting the relevant dimensions from `ffi::Buffer`s.
// In this example, we treat all leading dimensions as batch dimensions, so this
// function returns the total number of elements in the buffer, and the size of
// the last dimension.
template <ffi::DataType T>
std::pair<int64_t, int64_t> GetDims(const ffi::Buffer<T> &buffer) {
auto dims = buffer.dimensions();
if (dims.size() == 0) {
return std::make_pair(0, 0);
}
return std::make_pair(buffer.element_count(), dims.back());
}
// A wrapper function providing the interface between the XLA FFI call and our
// library function `ComputeRmsNorm` above. This function handles the batch
// dimensions by calling `ComputeRmsNorm` within a loop.
ffi::Error RmsNormImpl(float eps, ffi::Buffer<ffi::F32> x,
ffi::ResultBuffer<ffi::F32> y) {
auto [totalSize, lastDim] = GetDims(x);
if (lastDim == 0) {
return ffi::Error::InvalidArgument("RmsNorm input must be an array");
}
for (int64_t n = 0; n < totalSize; n += lastDim) {
ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]), &(y->typed_data()[n]));
}
return ffi::Error::Success();
}
// Wrap `RmsNormImpl` and specify the interface to XLA. If you need to declare
// this handler in a header, you can use the `XLA_FFI_DECLARE_HANDLER_SYMBOL`
// macro: `XLA_FFI_DECLARE_HANDLER_SYMBOL(RmsNorm)`.
XLA_FFI_DEFINE_HANDLER_SYMBOL(
RmsNorm, RmsNormImpl,
ffi::Ffi::Bind()
.Attr<float>("eps")
.Arg<ffi::Buffer<ffi::F32>>() // x
.Ret<ffi::Buffer<ffi::F32>>() // y
);
从底部开始,我们使用 XLA 提供的宏 XLA_FFI_DEFINE_HANDLER_SYMBOL
来生成一些样板代码,这些代码会展开成一个名为 RmsNorm
的函数,具有适当的签名。但是,这里重要的部分都在 ffi::Ffi::Bind()
调用中,我们在其中定义输入和输出类型,以及任何参数的类型。
然后,在 RmsNormImpl
中,我们接受 ffi::Buffer
参数,其中包含有关缓冲区形状和底层数据指针的信息。在此实现中,我们将缓冲区的所有前导维度视为批次维度,并在最后一个轴上执行 RMS 归一化。 GetDims
是一个提供此批处理行为支持的辅助函数。我们将在 下面 更详细地讨论此批处理行为,但总体的想法是,透明地处理输入参数最左侧批次维度可能很有用。在这种情况下,我们将除最后一个轴之外的所有轴都视为批次维度,但其他外部函数可能需要不同数量的非批次维度。
构建和注册 FFI 处理程序#
现在我们已经实现了最小的 FFI 包装器,我们需要将其(RmsNorm
)函数暴露给 Python。在本教程中,我们将 RmsNorm
编译为共享库并使用 ctypes 加载它,但另一种常见模式是使用 nanobind 或 pybind11,如下所述。
要编译共享库,我们在这里使用 CMake,但您应该能够轻松使用您喜欢的构建系统。
!cmake -DCMAKE_BUILD_TYPE=Release -B ffi/_build ffi
!cmake --build ffi/_build
!cmake --install ffi/_build
有了这个编译好的库,我们现在需要通过 register_ffi_target()
函数将此处理程序注册到 XLA。此函数期望我们的处理程序(C++ 函数 RmsNorm
的函数指针)被包装在 PyCapsule
中。JAX 提供了一个辅助函数 pycapsule()
来帮助实现这一点。
import ctypes
from pathlib import Path
path = next(Path("ffi").glob("librms_norm*"))
rms_norm_lib = ctypes.cdll.LoadLibrary(path)
jax.ffi.register_ffi_target(
"rms_norm", jax.ffi.pycapsule(rms_norm_lib.RmsNorm), platform="cpu")
提示
如果您熟悉旧的“自定义调用”API,值得注意的是,您还可以使用 register_ffi_target()
通过手动指定关键字参数 api_version=0
来注册自定义调用目标。 register_ffi_target()
的默认 api_version
是 1
,即我们在此使用的新的“类型化”FFI API。
替代方法:一种常见的将处理程序暴露给 Python 的替代模式是使用 nanobind 或 pybind11 定义一个可以导入的微型 Python 扩展。对于我们的示例,nanobind 代码将是
#include <type_traits>
#include "nanobind/nanobind.h"
#include "xla/ffi/api/c_api.h"
namespace nb = nanobind;
template <typename T>
nb::capsule EncapsulateFfiCall(T *fn) {
// This check is optional, but it can be helpful for avoiding invalid handlers.
static_assert(std::is_invocable_r_v<XLA_FFI_Error *, T, XLA_FFI_CallFrame *>,
"Encapsulated function must be and XLA FFI handler");
return nb::capsule(reinterpret_cast<void *>(fn));
}
NB_MODULE(rms_norm, m) {
m.def("rms_norm", []() { return EncapsulateFfiCall(RmsNorm); });
}
然后,在 Python 中,我们可以使用以下方法注册此处理程序
# Assuming that we compiled a nanobind extension called `rms_norm`:
import rms_norm as rms_norm_lib
jax.ffi.register_ffi_target("rms_norm", rms_norm_lib.rms_norm(), platform="cpu")
前端代码#
现在我们已经注册了 FFI 处理程序,使用 ffi_call()
函数从 JAX 调用我们的 C++ 库非常简单
import numpy as np
def rms_norm(x, eps=1e-5):
# We only implemented the `float32` version of this function, so we start by
# checking the dtype. This check isn't strictly necessary because type
# checking is also performed by the FFI when decoding input and output
# buffers, but it can be useful to check types in Python to raise more
# informative errors.
if x.dtype != jnp.float32:
raise ValueError("Only the float32 dtype is implemented by rms_norm")
call = jax.ffi.ffi_call(
# The target name must be the same string as we used to register the target
# above in `register_custom_call_target`
"rms_norm",
# In this case, the output of our FFI function is just a single array with
# the same shape and dtype as the input. We discuss a case with a more
# interesting output type below.
jax.ShapeDtypeStruct(x.shape, x.dtype),
# The `vmap_method` parameter controls this function's behavior under `vmap`
# as discussed below.
vmap_method="broadcast_all",
)
# Note that here we're use `numpy` (not `jax.numpy`) to specify a dtype for
# the attribute `eps`. Our FFI function expects this to have the C++ `float`
# type (which corresponds to numpy's `float32` type), and it must be a
# static parameter (i.e. not a JAX array).
return call(x, eps=np.float32(eps))
# Test that this gives the same result as our reference implementation
x = jnp.linspace(-0.5, 0.5, 32).reshape((8, 4))
np.testing.assert_allclose(rms_norm(x), rms_norm_ref(x), rtol=1e-5)
此代码单元格包含许多内联注释,应该可以解释大部分内容,但有几点值得特别强调。大部分繁重的工作由 ffi_call()
函数完成,它告诉 JAX 如何为特定的输入集调用外部函数。需要注意的是,ffi_call()
的第一个参数必须是一个字符串,该字符串与我们上面调用 register_custom_call_target
时使用的目标名称匹配。
任何属性(在上面的 C++ 包装器中使用 Attr
定义)都应作为关键字参数传递给 ffi_call()
。请注意,我们将 eps
显式转换为 np.float32
,因为我们的 FFI 库期望一个 C float
,并且我们不能在此处使用 jax.numpy
,因为这些参数必须是静态参数。
传递给 ffi_call()
的 vmap_method
参数定义了这个 FFI 调用如何与 vmap()
交互,如下一节所述。
提示
如果您熟悉早期的“自定义调用”接口,您可能会惊讶于我们没有将问题维度作为参数(批次大小等)传递给 ffi_call()
。在早期 API 中,后端没有接收有关输入数组元数据的机制,但由于 FFI 将维度信息包含在 Buffer
对象中,因此我们不再需要在降低时使用 Python 计算这些信息。这是一个主要优点:ffi_call()
可以开箱即用地支持一些简单的 vmap()
语义,如下文所述。
使用 vmap
进行批处理#
ffi_call()
使用 vmap_method
参数支持一些简单的 vmap()
语义。 pure_callback()
的文档提供了有关 vmap_method
参数的更多详细信息,ffi_call()
的行为也相同。
最简单的 vmap_method
是 "sequential"
。在这种情况下,当 vmap
时,ffi_call
将被重写为带有 ffi_call
作为主体的 scan()
。此实现是通用的,但并行化效果不佳。许多 FFI 调用提供更有效的批处理行为,在某些简单情况下,可以使用 "expand_dims"
或 "broadcast_all"
方法来公开更好的实现。
在这种情况下,由于我们只有一个输入参数,"expand_dims"
和 "broadcast_all"
的行为实际上是相同的。使用这些方法的具体假设是,外部函数知道如何处理批次维度。换句话说,对批次输入调用 ffi_call
的结果被假定为等于将 ffi_call
应用于批次输入中的每个元素的堆叠,大致如下:
ffi_call(xs) == jnp.stack([ffi_call(x) for x in xs])
提示
请注意,当有多个输入参数时,情况会变得更加复杂。为简单起见,本教程将始终使用 "broadcast_all"
,它保证所有输入都将被广播以具有相同的批次维度,但也可以实现一个外部函数来处理 "expand_dims"
方法。 pure_callback()
的文档包含了一些这方面的示例。
我们的 rms_norm
实现具有适当的语义,并且开箱即用地支持 vmap
,其 vmap_method="broadcast_all"
。
np.testing.assert_allclose(jax.vmap(rms_norm)(x), jax.vmap(rms_norm_ref)(x), rtol=1e-5)
我们可以检查 vmap()
的 jaxpr 来确认它没有使用 scan()
进行重写。
jax.make_jaxpr(jax.vmap(rms_norm))(x)
{ lambda ; a:f32[8,4]. let
b:f32[8,4] = ffi_call[
attributes=(('eps', np.float32(1e-05)),)
custom_call_api_version=4
has_side_effect=False
input_layouts=((1, 0),)
input_output_aliases=()
legacy_backend_config=None
output_layouts=((1, 0),)
result_avals=(ShapedArray(float32[8,4]),)
target_name=rms_norm
vmap_method=broadcast_all
] a
in (b,) }
使用 vmap_method="sequential"
,对 ffi_call
进行 vmap
会回退到主体中带有 ffi_call
的 jax.lax.scan()
。
def rms_norm_sequential(x, eps=1e-5):
return jax.ffi.ffi_call(
"rms_norm",
jax.ShapeDtypeStruct(x.shape, x.dtype),
vmap_method="sequential",
)(x, eps=np.float32(eps))
jax.make_jaxpr(jax.vmap(rms_norm_sequential))(x)
{ lambda ; a:f32[8,4]. let
b:f32[8,4] = scan[
_split_transpose=False
jaxpr={ lambda ; c:f32[4]. let
d:f32[4] = ffi_call[
attributes=(('eps', np.float32(1e-05)),)
custom_call_api_version=4
has_side_effect=False
input_layouts=((0,),)
input_output_aliases=()
legacy_backend_config=None
output_layouts=((0,),)
result_avals=(ShapedArray(float32[4]),)
target_name=rms_norm
vmap_method=sequential
] c
in (d,) }
length=8
linear=(False,)
num_carry=0
num_consts=0
reverse=False
unroll=1
] a
in (b,) }
如果您的外部函数提供了这种简单的 vmap_method
参数不支持的高效批处理规则,那么也可以使用实验性的 custom_vmap
接口定义更灵活的自定义 vmap
规则,但最好也在 JAX 问题跟踪器 上就您的用例开放一个问题。
微分#
与批处理不同,ffi_call()
不提供任何默认的自动微分 (AD) 支持。对 JAX 而言,外部函数是一个黑盒,无法检查以确定微分时的适当行为。因此,定义自定义导数规则是 ffi_call()
用户的责任。
有关自定义导数规则的更多详细信息,请参阅 自定义导数教程,但用于实现外部函数微分的最常见模式是定义一个 custom_vjp()
,该函数本身调用一个外部函数。在这种情况下,我们实际上定义了两个新的 FFI 调用:
rms_norm_fwd
返回两个输出:(a)“原始”结果,以及(b)在反向传播中使用“残差”。rms_norm_bwd
接收残差和输出共切线,并返回输入共切线。
我们不深入探讨 RMS 归一化反向传播的细节,但请查看 C++ 源代码 以了解这些函数是如何在后端实现的。要在此强调的主要一点是,“残差”计算出的形状与原始输出不同,因此,在对 res_norm_fwd
的 ffi_call()
中,输出类型有两个形状不同的元素。
可以按如下方式连接此自定义导数规则:
jax.ffi.register_ffi_target(
"rms_norm_fwd", jax.ffi.pycapsule(rms_norm_lib.RmsNormFwd), platform="cpu"
)
jax.ffi.register_ffi_target(
"rms_norm_bwd", jax.ffi.pycapsule(rms_norm_lib.RmsNormBwd), platform="cpu"
)
def rms_norm_fwd(x, eps=1e-5):
y, res = jax.ffi.ffi_call(
"rms_norm_fwd",
(
jax.ShapeDtypeStruct(x.shape, x.dtype),
jax.ShapeDtypeStruct(x.shape[:-1], x.dtype),
),
vmap_method="broadcast_all",
)(x, eps=np.float32(eps))
return y, (res, x)
def rms_norm_bwd(eps, res, ct):
del eps
res, x = res
assert res.shape == ct.shape[:-1]
assert x.shape == ct.shape
return (
jax.ffi.ffi_call(
"rms_norm_bwd",
jax.ShapeDtypeStruct(ct.shape, ct.dtype),
vmap_method="broadcast_all",
)(res, x, ct),
)
rms_norm = jax.custom_vjp(rms_norm, nondiff_argnums=(1,))
rms_norm.defvjp(rms_norm_fwd, rms_norm_bwd)
# Check that this gives the right answer when compared to the reference version
ct_y = jnp.ones_like(x)
np.testing.assert_allclose(
jax.vjp(rms_norm, x)[1](ct_y), jax.vjp(rms_norm_ref, x)[1](ct_y), rtol=1e-5
)
此时,我们可以透明地将新的 rms_norm
函数用于许多 JAX 应用程序,并且它将在 vmap()
和 grad()
等标准 JAX 函数变换下正确转换。此示例不支持的一种情况是前向模式 AD(例如 jax.jvp()
),因为 custom_vjp()
仅限于反向模式。JAX 目前没有公开 API 来同时自定义前向模式和反向模式 AD,但该 API 正在规划中,因此如果您在实践中遇到此限制,请 开放一个问题 并描述您的用例。
此示例不支持的另一个 JAX 功能是高阶 AD。可以通过将上面的 res_norm_bwd
函数包装在 jax.custom_jvp()
或 jax.custom_vjp()
装饰器中来解决此问题,但我们在此不详细介绍该高级用例。
GPU 上的 FFI 调用#
到目前为止,我们只与在 CPU 上运行的外部函数进行了交互,但 JAX 的 FFI 也支持调用 GPU 代码。由于此文档页面是在没有 GPU 访问的机器上自动生成的,因此我们无法在此处执行任何特定于 GPU 的示例,但我们将概述要点。
在定义 CPU 的 FFI 包装器时,我们使用的函数签名是
ffi::Error RmsNormImpl(float eps, ffi::Buffer<ffi::F32> x,
ffi::ResultBuffer<ffi::F32> y)
要更新此签名以与 CUDA 内核交互,签名如下:
ffi::Error RmsNormImpl(cudaStream_t stream, float eps,
ffi::Buffer<ffi::F32> x,
ffi::ResultBuffer<ffi::F32> y)
然后,处理程序定义将更新为在其绑定中包含一个 Ctx
。
XLA_FFI_DEFINE_HANDLER(
RmsNorm, RmsNormImpl,
ffi::Ffi::Bind()
.Ctx<ffi::PlatformStream<cudaStream_t>>()
.Attr<float>("eps")
.Arg<ffi::Buffer<ffi::F32>>() // x
.Ret<ffi::Buffer<ffi::F32>>() // y
);
然后,RmsNormImpl
可以使用 CUDA 流来启动 CUDA 内核。
在前端,注册代码将更新为指定适当的平台:
jax.ffi.register_ffi_target(
"rms_norm_cuda", rms_norm_lib_cuda.rms_norm(), platform="CUDA"
)
支持多平台#
为了支持在 GPU 和 CPU 上运行我们的 rms_norm
函数,我们可以将上面的实现与 jax.lax.platform_dependent()
函数结合起来。
def rms_norm_cross_platform(x, eps=1e-5):
assert x.dtype == jnp.float32
out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)
def impl(target_name):
return lambda x: jax.ffi.ffi_call(
target_name,
out_type,
vmap_method="broadcast_all",
)(x, eps=np.float32(eps))
return jax.lax.platform_dependent(x, cpu=impl("rms_norm"), cuda=impl("rms_norm_cuda"))
np.testing.assert_allclose(rms_norm_cross_platform(x), rms_norm_ref(x), rtol=1e-5)
此函数版本将根据运行时平台调用适当的 FFI 目标。
顺便说一句,值得注意的是,虽然 jaxpr 和降低后的 HLO 都包含对这两个 FFI 目标的引用,
jax.make_jaxpr(rms_norm_cross_platform)(x)
{ lambda ; a:f32[8,4]. let
b:i32[] = platform_index[platforms=(('cpu',), ('cuda',))]
c:f32[8,4] = cond[
branches=(
{ lambda ; d:f32[8,4]. let
e:f32[8,4] = ffi_call[
attributes=(('eps', np.float32(1e-05)),)
custom_call_api_version=4
has_side_effect=False
input_layouts=((1, 0),)
input_output_aliases=()
legacy_backend_config=None
output_layouts=((1, 0),)
result_avals=(ShapedArray(float32[8,4]),)
target_name=rms_norm
vmap_method=broadcast_all
] d
in (e,) }
{ lambda ; f:f32[8,4]. let
g:f32[8,4] = ffi_call[
attributes=(('eps', np.float32(1e-05)),)
custom_call_api_version=4
has_side_effect=False
input_layouts=((1, 0),)
input_output_aliases=()
legacy_backend_config=None
output_layouts=((1, 0),)
result_avals=(ShapedArray(float32[8,4]),)
target_name=rms_norm_cuda
vmap_method=broadcast_all
] f
in (g,) }
)
branches_platforms=(('cpu',), ('cuda',))
] b a
in (c,) }
print(jax.jit(rms_norm_cross_platform).lower(x).as_text().strip())
module @jit_rms_norm_cross_platform attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<8x4xf32>) -> (tensor<8x4xf32> {jax.result_info = "result"}) {
%c = stablehlo.constant dense<0> : tensor<i32>
%0 = "stablehlo.case"(%c) ({
%c_0 = stablehlo.constant dense<0> : tensor<i32>
stablehlo.return %c_0 : tensor<i32>
}, {
%c_0 = stablehlo.constant dense<0> : tensor<i32>
stablehlo.return %c_0 : tensor<i32>
}) : (tensor<i32>) -> tensor<i32>
%1 = "stablehlo.case"(%0) ({
%2 = stablehlo.custom_call @rms_norm(%arg0) {backend_config = "", mhlo.backend_config = {eps = 9.99999974E-6 : f32}, operand_layouts = [dense<[1, 0]> : tensor<2xindex>], result_layouts = [dense<[1, 0]> : tensor<2xindex>]} : (tensor<8x4xf32>) -> tensor<8x4xf32>
stablehlo.return %2 : tensor<8x4xf32>
}) : (tensor<i32>) -> tensor<8x4xf32>
return %1 : tensor<8x4xf32>
}
}
但在函数编译后,将选择适当的 FFI,
print(jax.jit(rms_norm_cross_platform).lower(x).as_text(dialect="hlo").strip())
HloModule jit_rms_norm_cross_platform, entry_computation_layout={(f32[8,4]{1,0})->f32[8,4]{1,0}}
ENTRY main.1 {
x.1 = f32[8,4]{1,0} parameter(0)
ROOT ffi_call.1 = f32[8,4]{1,0} custom-call(x.1), custom_call_target="rms_norm", operand_layout_constraints={f32[8,4]{1,0}}, api_version=API_VERSION_TYPED_FFI
}
并且使用 jax.lax.platform_dependent()
将不会产生运行时开销,并且编译后的程序将不包含对不可用 FFI 目标的任何引用。
高级主题#
本教程涵盖了开始使用 JAX FFI 的大多数基本步骤,但高级用例可能需要更多功能。我们将把这些主题留给未来的教程,但这里有一些可能很有用的参考资料:
支持多种 dtype:在本教程的示例中,我们仅限于支持
float32
输入和输出,但许多用例需要支持多种不同的输入类型。处理此问题的一种方法是为所有支持的输入类型注册不同的 FFI 目标,然后使用 Python 根据输入类型选择适当的目标用于jax.ffi.ffi_call()
。但是,这种方法可能会根据支持的用例的组合而变得非常繁琐。因此,也可以定义 C++ 处理程序以接受ffi::AnyBuffer
而不是ffi::Buffer<Dtype>
。然后,输入缓冲区将包含一个element_type()
方法,该方法可用于在后端定义适当的 dtype 分派逻辑。有状态外部函数:也可以使用 FFI 来包装具有关联状态的函数。XLA 测试套件中包含一个 低级示例,并且未来的教程将包含更多详细信息。