jax.experimental.pallas.mosaic_gpu.kernel

jax.experimental.pallas.mosaic_gpu.kernel(body, out_shape, *, scratch_shapes=(), compiler_params=None, **mesh_kwargs)[source]
Parameters:
  • body (Callable[..., None])

  • out_shape (object)

  • scratch_shapes (pallas_core.ScratchShapeTree)

  • compiler_params (pallas_core.CompilerParams | None)

  • mesh_kwargs (object)