jax.experimental.pallas.mosaic_gpu.WGMMAAccumulatorRef

class jax.experimental.pallas.mosaic_gpu.WGMMAAccumulatorRef(shape: 'tuple[int, int]', dtype: 'jnp.dtype' = <class 'jax.numpy.float32'>, _init: 'Any' = <jax._src.state.types.Uninitialized object>)
Parameters:
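  • shape (tuple[int, int]) – Two-dimensional shape of the accumulator.

  • dtype (jnp.dtype) – Element type of the accumulator. Defaults to jax.numpy.float32.

  • _init (Any) – Initial value of the accumulator. Private; set it through init() rather than passing it directly.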

__init__(shape, dtype=<class 'jax.numpy.float32'>, _init=<jax._src.state.types.Uninitialized object>)
Parameters:
  • shape (tuple[int, int])

  • dtype (jnp.dtype)

  • _init (Any)

Return type:

None

Methods

__init__(shape[, dtype, _init])

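get_ref_aval(): Return the abstract reference type used when the accumulator is allocated.

init(array): Create an accumulator with the shape and dtype of array, initialized to array.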

Attributes

shape