jax.experimental.pallas.program_id#

jax.experimental.pallas.program_id(axis)[源]#

返回沿着给定网格轴的内核执行位置。

例如,对于一个二维 grid 在内核执行中,对应于网格坐标 (1, 2)program_id(axis=0) 返回 1,而 program_id(axis=1) 返回 2

返回值是一个形状为 () 且数据类型为 int32 的数组。

参数:

axis (int) – 网格轴,用于计算程序在此轴上的位置。

返回类型:

jax.Array