jax.custom_batching.custom_vmap.def_vmap#

custom_vmap.def_vmap(vmap_rule)[source]#

为此 `custom_vmap` 函数定义 `vmap` 规则。

参数:

vmap_rule (Callable[..., tuple[Any, Any]]) – 一个实现 vmap 规则的函数。该函数应接受以下参数:(1) 一个整数 axis_size 作为其第一个参数,(2) 一个布尔值的 pytree,其结构与函数的输入相同,用于指定每个参数是否被批处理,以及 (3) 被批处理的参数。它应返回一个元组,其中包含批处理后的输出以及一个布尔值的 pytree,其结构与输出相同,用于指定每个输出元素是否被批处理。请参阅 jax.custom_batching.custom_vmap() 的文档以获取一些示例。

返回:

此方法将规则直接传递,并返回 vmap_rule 不变。

返回类型:

Callable[…, tuple[Any, Any]]