jax.experimental.pallas 模块#
用于 Pallas 的模块,Pallas 是 JAX 用于自定义内核的扩展。
请参阅 Pallas 文档,网址为 https://jax.net.cn/en/latest/pallas.html。
后端#
类#
|
指定数组在每次内核调用时应如何被切片。 |
|
为 |
|
一个带有起始索引和大小的切片。 |
函数#
|
在某些输入上调用 Pallas 内核。 |
|
返回沿给定网格轴的内核执行位置。 |
|
返回沿给定轴的网格大小。 |
|
从给定索引加载数组并返回。 |
|
在给定索引处存储一个值。 |
|
交换给定索引处的值并返回旧值。 |
|
|
|
从 Pallas 内核内部打印值。 |
|
|
|
|
|
|
|
使用分配的引用调用函数并返回结果。 |
|
当满足条件时调用被装饰的函数。 |