Pallas 发行说明#
这是 jax.experimental.pallas 特有的更改列表。有关 JAX 的整体发行说明,请参见 此处。
未发布#
移除
移除了先前已弃用的
jax.experimental.pallas.tpu.TPUCompilerParams、jax.experimental.pallas.tpu.TPUMemorySpace、jax.experimental.pallas.tpu.TritonCompilerParams。
随 jax 0.7.1 发布#
新功能
pltpu.make_async_remote_copy和pltpu.semaphore_signal的device_id参数现在允许用户传入一个仅指定通信轴上设备索引的字典,而不是完整的坐标。它还支持 TPU 核心 ID 索引。jax.debug.print现在可在 Pallas 内核中使用,并且是推荐的打印方式。
弃用
pl.atomic_*API 已移至jax.experimental.pallas.triton。通过jax.experimental.pallas访问它们已被弃用。pl.load和pl.store已弃用。请改用索引或特定后端加载/存储 API。
随 jax 0.7.0 发布#
新功能
添加了一个新的装饰器
jax.experimental.pallas.loop(),允许将无状态循环写成函数。为
jax.experimental.pallas.tpu.emit_pipeline()添加了新的多缓冲和前瞻功能。输入缓冲区现在可以进行多缓冲,拥有超过 2 个缓冲区,并支持前瞻选项来获取比立即下一个迭代任意数量网格迭代的块。此外,管道状态现在可以保存在寄存器中以减少标量内存使用。
弃用
jax.experimental.pallas.triton.TritonCompilerParams已重命名为jax.experimental.pallas.triton.CompilerParams。旧名称已弃用,将在未来版本中移除。jax.experimental.pallas.tpu.TPUCompilerParams和jax.experimental.pallas.tpu.TPUMemorySpace已重命名为jax.experimental.pallas.tpu.CompilerParams和jax.experimental.pallas.tpu.MemorySpace。旧名称已弃用,将在未来版本中移除。
随 jax 0.6.1 发布#
移除
移除了先前已弃用的
jax.experimental.pallas.gpu。要使用 Triton 后端,请导入jax.experimental.pallas.triton。
更改
jax.experimental.pallas.BlockSpec()现在除了整数/None 外,还接受block_shape中的特殊类型。indexing_mode已移除。要实现“非块化”,请为每个需要非块化索引的条目将pl.Element(size)传递给block_shape。jax.experimental.pallas.pallas_call()现在要求compiler_params是一个特定于后端的 dataclass,而不是参数到值的映射。jax.experimental.pallas.debug_check()现在同时支持 TPU 和 Mosaic GPU。以前,此功能仅在 TPU 上受支持,并且需要使用jax.experimental.checkify中的 API。请注意,除非设置了jax.experimental.pallas.enable_debug_checks,否则不会执行调试检查。
随 jax 0.5.0 发布#
新功能
为 TPU 上的
jax.experimental.pallas.debug_print()添加了向量支持。
随 jax 0.4.37 发布#
新功能
为 Triton 后端上的
dot降低添加了DotAlgorithmPreset精度参数。
随 jax 0.4.36 发布(2024 年 12 月 6 日)#
随 jax 0.4.35 发布(2024 年 10 月 22 日)#
移除
移除了先前已弃用的别名
jax.experimental.pallas.tpu.CostEstimate和jax.experimental.tpu.run_scoped()。两者现在都可以在jax.experimental.pallas中找到。
新功能
添加了一个成本估算工具
pl.estimate_cost(),用于从 JAX 参考函数自动构建内核成本估算。
随 jax 0.4.34 发布(2024 年 10 月 4 日)#
更改
jax.experimental.pallas.debug_print()不再要求所有参数都为标量。参数的限制是后端特定的:非标量参数当前仅在 GPU 上支持,当使用 Triton 时。jax.experimental.pallas.BlockSpec不再支持先前已弃用的参数顺序,其中index_map在block_shape之前。
弃用
出于与
jax.experimental.pallas.mosaic_gpu区分歧义的目的,jax.experimental.pallas.gpu子模块已被弃用。要使用 Triton 后端,请导入jax.experimental.pallas.triton。
新功能
jax.experimental.pallas.pallas_call()现在接受scratch_shapes,这是一个 PyTree,用于指定内核所需的特定于后端的临时对象,例如缓冲区、同步原语等。当使用
pltpu.enable_runtime_assert(True)上下文管理器调用 pallas_call 时,checkify.check()现在可用于插入运行时断言。
随 jax 0.4.33 发布(2024 年 9 月 16 日)#
随 jax 0.4.32 发布(2024 年 9 月 11 日)#
更改
不允许内核函数捕获常量。相反,所有需要的数组都必须作为输入传递,并带有正确的块规范 (#22746)。
新功能
改进了索引映射函数签名错误的错误消息,包括索引映射的名称和源位置。
随 jax 0.4.31 发布(2024 年 7 月 29 日)#
更改
jax.experimental.pallas.BlockSpec现在期望block_shape在index_map之前传递。旧的参数顺序已弃用,将在未来版本中移除。jax.experimental.pallas.GridSpec不再具有in_specs_tree和out_specs_tree字段,并且in_specs和out_specs树现在将值存储为 BlockSpec 的 Pytree。先前,in_specs和out_specs是展平的 (#22552)。从
jax.experimental.pallas.GridSpec中移除了compute_index方法,因为它属于私有方法。类似地,从BlockSpec中移除了get_grid_mapping和unzip_dynamic_bounds(#22593)。修复了解释模式以与涉及填充的 BlockSpec 一起工作 (#22275)。解释模式下的填充将使用 NaN,以帮助调试越界错误,但这在运行自定义内核模式时不存在,不应依赖。
以前可以导入许多旨在作为私有的 API,作为
jax.experimental.pallas.pallas。现在已不再可能。
新功能
添加了 BlockSpec 的文档:网格和 BlockSpecs。
改进了
jax.experimental.pallas.pallas_call()API 的错误消息。为 Pallas TPU 自定义内核添加了形状多态性的初始支持
添加了 TPU 对 checkify 的支持。(#22480)
当块大小不符合 TPU 要求时,添加了更清晰的错误消息。以前,错误来自 Mosaic 后端,没有有用的 Python 堆栈跟踪。
添加了对具有 1D 块的 TPU 降低的支持,并放宽了块大小的要求,至少要有 2 个维度:最后 2 个维度必须分别被 8 和 128 整除,除非它们跨越了相应的数组维度的整个范围。以前,允许块维度跨越整个数组,但仅当最后两个维度的块维度小于 8 和 128 时。
随 JAX 0.4.30 发布(2024 年 6 月 18 日)#
新功能
在解释模式下为
jax.experimental.pallas.pallas_call()添加了 checkify 支持 (#21862)。为 TPU 内核改进了 PRNG 密钥支持 (#21773)。