jax.experimental.pallas.mosaic_gpu.kernel

jax.experimental.pallas.mosaic_gpu.kernel(body, out_shape, *, scratch_shapes=(), compiler_params=None, **mesh_kwargs)

Parameters:
  body (Callable[..., None])
  out_shape (object)
  scratch_shapes (pallas_core.ScratchShapeTree)
  compiler_params (pallas_core.CompilerParams | None)
  mesh_kwargs (object)
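The reference above lists only parameter types. To show roughly how the parameters fit together, here is a plain-Python model of the calling convention. It is not JAX code. The convention is based on our reading of the JAX source, so please check it against your installed version. Under that reading, `body` receives the input refs first, then one ref per output in `out_shape`, then one ref per entry in `scratch_shapes`. The returned wrapper yields a single output when `out_shape` is not a tuple or list, and a tuple of outputs otherwise.

The model makes several simplifying assumptions:
- `FakeRef` and `model_kernel` are hypothetical names.
- Refs are 1-D Python lists, and shapes are plain lengths.
- `compiler_params` and `mesh_kwargs` are accepted but ignored.
- The body runs once, rather than across a GPU mesh.

```python
# Hypothetical plain-Python model of the calling convention assumed for
# jax.experimental.pallas.mosaic_gpu.kernel. It does not use JAX or a GPU.
from typing import Callable, Sequence, Union

Shape = int  # a 1-D buffer length stands in for a real shape/dtype spec


class FakeRef:
    """Stand-in for a Pallas ref: a mutable 1-D buffer indexable with `...`."""

    def __init__(self, data: Sequence[float]):
        self.data = list(data)

    def __getitem__(self, idx):
        # ref[...] reads the whole buffer; other indices read single elements.
        return list(self.data) if idx is Ellipsis else self.data[idx]

    def __setitem__(self, idx, value):
        # ref[...] = seq overwrites the whole buffer in place.
        if idx is Ellipsis:
            if len(value) != len(self.data):
                raise ValueError("shape mismatch")
            self.data[:] = list(value)
        else:
            self.data[idx] = value


def model_kernel(body: Callable[..., None],
                 out_shape: Union[Shape, Sequence[Shape]],
                 *,
                 scratch_shapes: Sequence[Shape] = (),
                 compiler_params=None,
                 **mesh_kwargs):
    # compiler_params and mesh_kwargs are accepted but ignored in this model;
    # the real function uses them to compile and launch the kernel.
    unwrap = not isinstance(out_shape, (tuple, list))
    out_shapes = (out_shape,) if unwrap else tuple(out_shape)

    def wrapper(*operands):
        in_refs = [FakeRef(x) for x in operands]
        # Zero-filled here for determinism; do not assume the real kernel
        # initializes its output buffers.
        out_refs = [FakeRef([0] * n) for n in out_shapes]
        scratch_refs = [FakeRef([0] * n) for n in scratch_shapes]
        # Assumed order: inputs, then outputs, then scratch buffers.
        # The real kernel runs body across the GPU mesh; this calls it once.
        body(*in_refs, *out_refs, *scratch_refs)
        outs = tuple(r.data for r in out_refs)
        return outs[0] if unwrap else outs

    return wrapper


def add_then_double(x_ref, y_ref, o_ref, tmp_ref):
    # Results are written through refs instead of being returned.
    tmp_ref[...] = [a + b for a, b in zip(x_ref[...], y_ref[...])]
    o_ref[...] = [2 * v for v in tmp_ref[...]]


f = model_kernel(add_then_double, 3, scratch_shapes=(3,))
print(f([1, 2, 3], [10, 20, 30]))  # [22, 44, 66]
```

The body writes its outputs through refs instead of returning them. That is why `body` is typed `Callable[..., None]` in the parameter list above.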