Pallas 发行说明#

这是 jax.experimental.pallas 特有的更改列表。有关 JAX 的整体发行说明,请参见 此处

未发布#

  • 移除

    • 移除了先前已弃用的 jax.experimental.pallas.tpu.TPUCompilerParamsjax.experimental.pallas.tpu.TPUMemorySpacejax.experimental.pallas.tpu.TritonCompilerParams

随 jax 0.7.1 发布#

  • 新功能

    • pltpu.make_async_remote_copypltpu.semaphore_signaldevice_id 参数现在允许用户传入一个仅指定通信轴上设备索引的字典,而不是完整的坐标。它还支持 TPU 核心 ID 索引。

    • jax.debug.print 现在可在 Pallas 内核中使用,并且是推荐的打印方式。

  • 弃用

随 jax 0.7.0 发布#

  • 新功能

    • 添加了一个新的装饰器 jax.experimental.pallas.loop(),允许将无状态循环写成函数。

    • jax.experimental.pallas.tpu.emit_pipeline() 添加了新的多缓冲和前瞻功能。输入缓冲区现在可以进行多缓冲,拥有超过 2 个缓冲区,并支持前瞻选项来获取比立即下一个迭代任意数量网格迭代的块。此外,管道状态现在可以保存在寄存器中以减少标量内存使用。

  • 弃用

    • jax.experimental.pallas.triton.TritonCompilerParams 已重命名为 jax.experimental.pallas.triton.CompilerParams。旧名称已弃用,将在未来版本中移除。

    • jax.experimental.pallas.tpu.TPUCompilerParamsjax.experimental.pallas.tpu.TPUMemorySpace 已重命名为 jax.experimental.pallas.tpu.CompilerParamsjax.experimental.pallas.tpu.MemorySpace。旧名称已弃用,将在未来版本中移除。

随 jax 0.6.1 发布#

  • 移除

  • 更改

    • jax.experimental.pallas.BlockSpec() 现在除了整数/None 外,还接受 block_shape 中的特殊类型。indexing_mode 已移除。要实现“非块化”,请为每个需要非块化索引的条目将 pl.Element(size) 传递给 block_shape

    • jax.experimental.pallas.pallas_call() 现在要求 compiler_params 是一个特定于后端的 dataclass,而不是参数到值的映射。

    • jax.experimental.pallas.debug_check() 现在同时支持 TPU 和 Mosaic GPU。以前,此功能仅在 TPU 上受支持,并且需要使用 jax.experimental.checkify 中的 API。请注意,除非设置了 jax.experimental.pallas.enable_debug_checks,否则不会执行调试检查。

随 jax 0.5.0 发布#

随 jax 0.4.37 发布#

  • 新功能

    • 为 Triton 后端上的 dot 降低添加了 DotAlgorithmPreset 精度参数。

随 jax 0.4.36 发布(2024 年 12 月 6 日)#

随 jax 0.4.35 发布(2024 年 10 月 22 日)#

  • 移除

    • 移除了先前已弃用的别名 jax.experimental.pallas.tpu.CostEstimatejax.experimental.tpu.run_scoped()。两者现在都可以在 jax.experimental.pallas 中找到。

  • 新功能

    • 添加了一个成本估算工具 pl.estimate_cost(),用于从 JAX 参考函数自动构建内核成本估算。

随 jax 0.4.34 发布(2024 年 10 月 4 日)#

随 jax 0.4.33 发布(2024 年 9 月 16 日)#

随 jax 0.4.32 发布(2024 年 9 月 11 日)#

  • 更改

    • 不允许内核函数捕获常量。相反,所有需要的数组都必须作为输入传递,并带有正确的块规范 (#22746)。

  • 新功能

    • 改进了索引映射函数签名错误的错误消息,包括索引映射的名称和源位置。

随 jax 0.4.31 发布(2024 年 7 月 29 日)#

  • 更改

    • jax.experimental.pallas.BlockSpec 现在期望 block_shapeindex_map 之前传递。旧的参数顺序已弃用,将在未来版本中移除。

    • jax.experimental.pallas.GridSpec 不再具有 in_specs_treeout_specs_tree 字段,并且 in_specsout_specs 树现在将值存储为 BlockSpec 的 Pytree。先前,in_specsout_specs 是展平的 (#22552)。

    • jax.experimental.pallas.GridSpec 中移除了 compute_index 方法,因为它属于私有方法。类似地,从 BlockSpec 中移除了 get_grid_mappingunzip_dynamic_bounds (#22593)。

    • 修复了解释模式以与涉及填充的 BlockSpec 一起工作 (#22275)。解释模式下的填充将使用 NaN,以帮助调试越界错误,但这在运行自定义内核模式时不存在,不应依赖。

    • 以前可以导入许多旨在作为私有的 API,作为 jax.experimental.pallas.pallas。现在已不再可能。

  • 新功能

    • 添加了 BlockSpec 的文档:网格和 BlockSpecs

    • 改进了 jax.experimental.pallas.pallas_call() API 的错误消息。

    • 为 Pallas TPU 自定义内核添加了形状多态性的初始支持

    • 为 TPU 内核改进了 PRNG 密钥支持 (#21773)。
      (#22084).

    • 添加了 TPU 对 checkify 的支持。(#22480

    • 当块大小不符合 TPU 要求时,添加了更清晰的错误消息。以前,错误来自 Mosaic 后端,没有有用的 Python 堆栈跟踪。

    • 添加了对具有 1D 块的 TPU 降低的支持,并放宽了块大小的要求,至少要有 2 个维度:最后 2 个维度必须分别被 8 和 128 整除,除非它们跨越了相应的数组维度的整个范围。以前,允许块维度跨越整个数组,但仅当最后两个维度的块维度小于 8 和 128 时。

随 JAX 0.4.30 发布(2024 年 6 月 18 日)#