jax.experimental.pallas 模块#

Pallas 模块,一个用于自定义内核的 JAX 扩展。

请参阅 Pallas 文档:https://jax.net.cn/en/latest/pallas.html

后端#

#

BlockSpec([block_shape, index_map, ...])

指定数组应如何为每个内核调用进行切片。

GridSpec([grid, in_specs, out_specs, ...])

编码了 jax.experimental.pallas.pallas_call() 的网格参数。

Slice(start, size[, stride])

一个带有起始索引和大小的切片。

函数#

pallas_call(kernel, out_shape, *[, ...])

在某些输入上调用 Pallas 内核。

program_id(axis)

返回网格给定轴上的内核执行位置。

num_programs(axis)

返回网格给定轴上的大小。

load(x_ref_or_view, idx, *[, mask, other, ...])

从给定索引加载数组并返回。

store(x_ref_or_view, idx, val, *[, mask, ...])

在给定索引处存储值。

swap(x_ref_or_view, idx, val, *[, mask, ...])

交换给定索引处的值并返回旧值。

broadcast_to(a, shape)

debug_print(fmt, *args)

从 Pallas 内核内部打印值。

dot(a, b[, trans_a, trans_b, allow_tf32, ...])

max_contiguous(x, values)

multiple_of(x, values)

run_scoped(f, *types[, collective_axes])

使用分配的引用调用函数并返回结果。

when(condition)