Pallas 更新日志#
这是针对 jax.experimental.pallas
的更改列表。有关 JAX 的完整更新日志,请参阅此处。
未发布#
弃用
pl.atomic_*
API 已移至jax.experimental.pallas.triton
。通过jax.experimental.pallas
访问它们已被弃用。
与 jax 0.7.0 一同发布#
新功能
添加了新装饰器
jax.experimental.pallas.loop()
,它允许将无状态循环编写为函数。为
jax.experimental.pallas.tpu.emit_pipeline()
添加了新的多缓冲和预读功能。输入缓冲区现在可以使用超过 2 个缓冲区进行多缓冲,并支持预读选项,以获取任意数量的网格迭代提前的块,而不是紧随其后的迭代。此外,流水线状态现在可以保存在寄存器中,以减少标量内存使用。
弃用
jax.experimental.pallas.triton.TritonCompilerParams
已更名为jax.experimental.pallas.triton.CompilerParams
。旧名称已被弃用,并将在未来版本中移除。jax.experimental.pallas.tpu.TPUCompilerParams
和jax.experimental.pallas.tpu.TPUMemorySpace
已更名为jax.experimental.pallas.tpu.CompilerParams
和jax.experimental.pallas.tpu.MemorySpace
。旧名称已被弃用,并将在未来版本中移除。
与 jax 0.6.1 一同发布#
移除
移除了先前已弃用的
jax.experimental.pallas.gpu
。要使用 Triton 后端,请导入jax.experimental.pallas.triton
。
更改
jax.experimental.pallas.BlockSpec()
现在除了block_shape
中的整数/None 外,还接受特殊类型。indexing_mode
已被移除。要实现“Unblocked”模式,请为每个需要不阻塞索引的条目,将pl.Element(size)
传递给block_shape
。jax.experimental.pallas.pallas_call()
现在要求compiler_params
是后端特定的数据类,而不是参数到值的映射。jax.experimental.pallas.debug_check()
现在在 TPU 和 Mosaic GPU 上都受支持。以前,此功能仅在 TPU 上受支持,并且需要使用jax.experimental.checkify
中的 API。请注意,除非设置了jax.experimental.pallas.enable_debug_checks
,否则不会执行调试检查。
与 jax 0.5.0 一同发布#
新功能
在 TPU 上为
jax.experimental.pallas.debug_print()
添加了向量支持。
与 jax 0.4.37 一同发布#
新功能
在 Triton 后端中,为
dot
降低添加了对DotAlgorithmPreset
精度参数的支持。
与 jax 0.4.36 一同发布(2024年12月6日)#
与 jax 0.4.35 一同发布(2024年10月22日)#
移除
移除了先前已弃用的别名
jax.experimental.pallas.tpu.CostEstimate
和jax.experimental.tpu.run_scoped()
。两者现在都可在jax.experimental.pallas
中使用。
新功能
添加了一个成本估算工具
pl.estimate_cost()
,用于从 JAX 参考函数自动构建内核成本估算。
与 jax 0.4.34 一同发布(2024年10月4日)#
更改
jax.experimental.pallas.debug_print()
不再要求所有参数都是标量。参数的限制是后端特定的:非标量参数目前仅在使用 Triton 时在 GPU 上受支持。jax.experimental.pallas.BlockSpec
不再支持先前已弃用的参数顺序,即index_map
位于block_shape
之前。
弃用
已弃用
jax.experimental.pallas.gpu
子模块,以避免与jax.experimental.pallas.mosaic_gpu
混淆。要使用 Triton 后端,请导入jax.experimental.pallas.triton
。
新功能
jax.experimental.pallas.pallas_call()
现在接受scratch_shapes
,这是一个 PyTree,用于指定内核所需的后端特定临时对象,例如缓冲区、同步原语等。现在可以使用
checkify.check()
在使用pltpu.enable_runtime_assert(True)
上下文管理器调用 pallas_call 时插入运行时断言。
与 jax 0.4.33 一同发布(2024年9月16日)#
与 jax 0.4.32 一同发布(2024年9月11日)#
更改
内核函数不允许闭包常量。相反,所有需要的数组必须作为输入传递,并带有适当的块规范 (#22746)。
新功能
改进了索引映射函数签名错误时的错误消息,以包含索引映射的名称和源代码位置。
与 jax 0.4.31 一同发布(2024年7月29日)#
更改
jax.experimental.pallas.BlockSpec
现在要求block_shape
传递在index_map
之前。旧的参数顺序已被弃用,并将在未来版本中移除。jax.experimental.pallas.GridSpec
不再包含in_specs_tree
和out_specs_tree
字段,并且in_specs
和out_specs
树现在将值存储为 BlockSpec 的 pytrees。以前,in_specs
和out_specs
是扁平化的 (#22552)。jax.experimental.pallas.GridSpec
的compute_index
方法已被移除,因为它是私有的。同样,get_grid_mapping
和unzip_dynamic_bounds
已从BlockSpec
中移除 (#22593)。修复了解释模式以使其与涉及填充的 BlockSpec 一起工作 (#22275)。解释模式下的填充将使用 NaN,以帮助调试越界错误,但此行为在自定义内核模式下不存在,不应依赖。
以前,许多旨在作为私有的 API 可以导入为
jax.experimental.pallas.pallas
。现在已不再可能。
新功能
添加了 BlockSpec 的文档:网格和块规范。
改进了
jax.experimental.pallas.pallas_call()
API 的错误消息。为 TPU 添加了
lax.shift_right_arithmetic
(#22279) 和lax.erf_inv
(#22310) 的降低规则。为 Pallas TPU 自定义内核添加了形状多态性的初步支持
(#22084).添加了对 checkify 的 TPU 支持。 (#22480)
当块大小不符合 TPU 要求时,添加了更清晰的错误消息。以前,错误来自 Mosaic 后端,并且没有有用的 Python 堆栈跟踪。
添加了对使用 1D 块进行 TPU 降低的支持,并放宽了至少具有 2 个维度的块大小要求:最后 2 个维度必须分别能被 8 和 128 整除,除非它们跨越整个相应的数组维度。以前,只有当最后两个维度中的块维度分别小于 8 和 128 时,才允许跨越整个数组的块维度。
与 JAX 0.4.30 一同发布(2024年6月18日)#
新功能
在解释模式下,为
jax.experimental.pallas.pallas_call()
添加了 checkify 支持 (#21862)。改进了对 TPU 内核的 PRNG 密钥支持 (#21773)。