jax.experimental.pallas.num_programs#

jax.experimental.pallas.num_programs(axis)[源代码]#

返回给定轴上网格的大小。

参数:

axis (int)

返回类型:

int | jax.Array