jax.experimental.pallas.program_id#

jax.experimental.pallas.program_id(axis)[源代码]#

返回内核执行在给定网格轴上的位置。

例如,对于内核执行中的 2D grid,对应于网格坐标 (1, 2)program_id(axis=0) 返回 1program_id(axis=1) 返回 2

返回的值是一个形状为 (),dtype 为 int32 的数组。

参数:

axis (int) – 网格的轴,沿该轴计数程序。

返回类型:

jax.Array