Pallas:JAX 内核语言#

Pallas 是 JAX 的一个扩展,可用于为 GPU 和 TPU 编写自定义内核。它旨在提供对生成代码的精细控制,并结合 JAX 追踪和 jax.numpy API 的高级人体工程学。

本节包含 Pallas 的教程、指南和示例。另请参阅 jax.experimental.pallas 模块 API 文档。

警告

Pallas 尚处于实验阶段,并且经常变动。请参阅 Pallas 更新日志 以了解近期更改。

您可能会遇到错误和未实现的情况,例如,当高级 JAX 概念的降低需要仿真时,或者仅仅因为 Pallas 仍在开发中。