Pallas:一种 JAX 内核语言

Pallas:一种 JAX 内核语言#

Pallas 是 JAX 的一个扩展,用于为 GPU 和 TPU 编写自定义内核。它旨在提供对生成代码的细粒度控制,同时结合 JAX 追踪和 jax.numpy API 的高级易用性。

本节包含使用 Pallas 的教程、指南和示例。另请参阅 jax.experimental.pallas 模块 API 文档。

警告

Pallas 目前处于实验阶段,变动频繁。请参阅 Pallas 更新日志 以了解近期变更。

您可以预见到会遇到错误和未实现的情况,例如,当高层级 JAX 概念的降级(lowering)需要模拟时,或者仅仅因为 Pallas 仍处于开发阶段。