jax.experimental.pallas.pallas_call#

jax.experimental.pallas.pallas_call(kernel, out_shape, *, grid_spec=None, grid=(), in_specs=NoBlockSpec, out_specs=NoBlockSpec, scratch_shapes=(), input_output_aliases={}, debug=False, interpret=False, name=None, compiler_params=None, cost_estimate=None, backend=None, metadata=None)[来源]#

在某些输入上调用 Pallas 内核。

详见 Pallas 快速入门

参数:
  • kernel (Callable[..., None]) – 内核函数,为每个输入和输出接收一个 Ref。Ref 的形状由对应的 in_specsout_specs 中的 block_shape 给出。

  • out_shape (Any) – 一个 PyTree,包含 jax.ShapeDtypeStruct,描述输出的形状和数据类型。

  • grid_spec (GridSpec | None) – 指定 gridin_specsout_specsscratch_shapes 的另一种方式。如果提供此参数,则不能同时提供其他这些参数。

  • grid (TupleGrid) – 迭代空间,表示为整数元组。内核将执行的次数与 prod(grid) 相同。详见 grid(又称循环中的内核)

  • in_specs (BlockSpecTree) – 一个 PyTree,包含 jax.experimental.pallas.BlockSpec,其结构与位置参数匹配。in_specs 的默认值指定所有输入的整个数组,例如 pl.BlockSpec(x.shape, lambda *indices: (0,) * x.ndim)。详见 BlockSpec(又称如何分块输入)

  • out_specs (BlockSpecTree) – 一个 PyTree,包含 jax.experimental.pallas.BlockSpec,其结构与输出匹配。out_specs 的默认值指定整个数组,例如 pl.BlockSpec(x.shape, lambda *indices: (0,) * x.ndim)。详见 BlockSpec(又称如何分块输入)

  • scratch_shapes (ScratchShapeTree) – 一个 PyTree,包含内核所需的后端特定临时对象,例如临时缓冲区、同步原语等。

  • input_output_aliases (Mapping[int, int]) – 一个字典,将某些输入的索引映射到与其别名的输出的索引。这些索引位于展平的输入和输出中。

  • debug (bool) – 如果为 True,Pallas 将在处理内核时打印其各种中间形式。

  • interpret (Any) – 将 pallas_call 解释为 jax.jit 对网格的扫描,其中扫描体是将内核降级为 JAX 函数的结果。这不需要 TPU 或 GPU,并且是唯一在 CPU 上运行 Pallas 内核的方式。这对于调试很有用。

  • name (str | None) – 如果存在,则指定此内核调用在调试和错误消息中使用的名称。我们将定义内核函数的文件和行附加到此名称,例如:{name} for kernel function {kernel_name} at {file}:{line}。如果缺失,则使用 {kernel_name} at {file}:{line}

  • compiler_params (Mapping[Backend, pallas_core.CompilerParams] | pallas_core.CompilerParams | None) – 可选的编译器参数。其值应为后端特定的数据类(jax.experimental.pallas.tpu.CompilerParamsjax.experimental.pallas.triton.CompilerParamsjax.experimental.pallas.mosaic_gpu.CompilerParams)或者是一个字典,将后端名称映射到相应的平台特定数据类。

  • backend (Backend | None) – 可选字符串字面量之一:"mosaic_tpu""triton""mosaic_gpu",用于确定要使用的后端。None 表示让 Pallas 决定。

  • metadata (dict[str, str] | None) – 一个可选的字典,包含有关内核的信息,这些信息将以 JSON 格式序列化到 HLO 中。可用于调试和分析。

  • cost_estimate (CostEstimate | None)

返回:

一个函数,可使用多个位置数组参数调用以调用 Pallas 内核。

返回类型:

Callable[…, Any]