jax.experimental.pallas.BlockSpec#
- 类 jax.experimental.pallas.BlockSpec(block_shape=None, index_map=None, indexing_mode=None, pipeline_mode=None, *, memory_space=None)[源]#
指定数组在每次内核调用时应如何切片。
block_shape 是一个由 int | None 或 BlockDim 类型(例如 pl.Element、pl.Squeezed、pl.Blocked、pl.BoundedSlice)组成的序列。这些类型中的每一种都指定了块维度的尺寸。None 用于指定从内核中挤出的维度。BlockDim 类型允许对维度的索引进行更精细的控制。index_map 需要返回一个与 block_shape 长度相同的元组,其中每个条目都取决于 BlockDim 的类型。
参见 BlockSpec,亦即如何分块输入 以及各个 BlockDim 类型的文档字符串以获取更多详细信息。
- 参数:
block_shape (序列[BlockDim | int | None] | None)
index_map (可调用对象[..., Any] | None)
indexing_mode (Any | None)
pipeline_mode (Buffered | None)
memory_space (Any | None)
- __init__(block_shape=None, index_map=None, indexing_mode=None, pipeline_mode=None, *, memory_space=None)#
- 参数:
block_shape (序列[BlockDim | int | None] | None)
index_map (可调用对象[..., Any] | None)
indexing_mode (Any | None)
pipeline_mode (Buffered | None)
memory_space (Any | None)
- 返回类型:
无
方法
__init__
([block_shape, index_map, ...])replace
(**changes)返回一个新对象,其中指定字段已替换为新值。
to_block_mapping
(origin, array_aval, *, ...)属性
block_shape
index_map
indexing_mode
memory_space
pipeline_mode