Pallas:一种 JAX 内核语言#
Pallas 是 JAX 的一个扩展,用于为 GPU 和 TPU 编写自定义内核。它旨在提供对生成代码的细粒度控制,同时结合 JAX 追踪和 jax.numpy API 的高级易用性。
本节包含使用 Pallas 的教程、指南和示例。另请参阅 jax.experimental.pallas 模块 API 文档。
警告
Pallas 目前处于实验阶段,变动频繁。请参阅 Pallas 更新日志 以了解近期变更。
您可以预见到会遇到错误和未实现的情况,例如,当高层级 JAX 概念的降级(lowering)需要模拟时,或者仅仅因为 Pallas 仍处于开发阶段。
指南
TPU 后端指南
Mosaic GPU 后端指南
其他