jax.experimental.pallas
模块#
Pallas 模块,一个用于自定义内核的 JAX 扩展。
请参阅 Pallas 文档:https://jax.net.cn/en/latest/pallas.html。
后端#
类#
|
指定数组应如何为每个内核调用进行切片。 |
|
编码了 |
|
一个带有起始索引和大小的切片。 |
函数#
|
在某些输入上调用 Pallas 内核。 |
|
返回网格给定轴上的内核执行位置。 |
|
返回网格给定轴上的大小。 |
|
从给定索引加载数组并返回。 |
|
在给定索引处存储值。 |
|
交换给定索引处的值并返回旧值。 |
|
|
|
从 Pallas 内核内部打印值。 |
|
|
|
|
|
|
|
使用分配的引用调用函数并返回结果。 |
|