jax.experimental.pallas.mosaic_gpu.WGMMAAccumulatorRef
- class jax.experimental.pallas.mosaic_gpu.WGMMAAccumulatorRef(shape: 'tuple[int, int]', dtype: 'jnp.dtype' = <class 'jax.numpy.float32'>, _init: 'Any' = <jax._src.state.types.Uninitialized object>)
- __init__(shape, dtype=<class 'jax.numpy.float32'>, _init=<jax._src.state.types.Uninitialized object>)
Methods
- __init__(shape[, dtype, _init])
- get_ref_aval()
- init(array)

Attributes
- shape