jax.experimental.pallas.mosaic_gpu.as_torch_kernel#
- jax.experimental.pallas.mosaic_gpu.as_torch_kernel(fn)[源代码]#
使 Mosaic GPU 内核能够使用 PyTorch 张量进行调用。
- 参数:
fn – 一个调用 Mosaic GPU 内核的 JAX 函数。请注意,当前实现仅支持包含单个 Mosaic GPU 内核调用的函数,而不支持其他 JAX API 调用,例如来自
jax.numpy
的调用。- 返回:
一个包装函数,它接受 PyTorch 张量作为输入,并返回 PyTorch 张量作为输出。输出张量在与输入张量相同的设备上分配。
示例
@functools.partial( pl.pallas_call, out_shape=jax.ShapeDtypeStruct([128], jnp.int32) ) def add_kernel(x_ref, y_ref, o_ref): o_ref[...] = x_ref[...] + y_ref[...] x = torch.arange(128, dtype=torch.int32, device="cuda") y = x * x out = plgpu.as_torch_kernel(add_kernel)(x, y)