jax.experimental.pallas.BlockSpec#

jax.experimental.pallas.BlockSpec(block_shape=None, index_map=None, indexing_mode=None, pipeline_mode=None, *, memory_space=None)[源]#

指定数组在每次内核调用时应如何切片。

block_shape 是一个由 int | NoneBlockDim 类型(例如 pl.Elementpl.Squeezedpl.Blockedpl.BoundedSlice)组成的序列。这些类型中的每一种都指定了块维度的尺寸。None 用于指定从内核中挤出的维度。BlockDim 类型允许对维度的索引进行更精细的控制。index_map 需要返回一个与 block_shape 长度相同的元组,其中每个条目都取决于 BlockDim 的类型。

参见 BlockSpec,亦即如何分块输入 以及各个 BlockDim 类型的文档字符串以获取更多详细信息。

参数:
  • block_shape (序列[BlockDim | int | None] | None)

  • index_map (可调用对象[..., Any] | None)

  • indexing_mode (Any | None)

  • pipeline_mode (Buffered | None)

  • memory_space (Any | None)

__init__(block_shape=None, index_map=None, indexing_mode=None, pipeline_mode=None, *, memory_space=None)#
参数:
  • block_shape (序列[BlockDim | int | None] | None)

  • index_map (可调用对象[..., Any] | None)

  • indexing_mode (Any | None)

  • pipeline_mode (Buffered | None)

  • memory_space (Any | None)

返回类型:

方法

__init__([block_shape, index_map, ...])

replace(**changes)

返回一个新对象,其中指定字段已替换为新值。

to_block_mapping(origin, array_aval, *, ...)

属性

block_shape

index_map

indexing_mode

memory_space

pipeline_mode