jax.custom_batching.custom_vmap.def_vmap#

custom_vmap.def_vmap(vmap_rule)[source]#

为此 custom_vmap 函数定义 vmap 规则。

参数:

vmap_rule (Callable[..., tuple[Any, Any]]) – 一个实现 vmap 规则的函数。此函数应接受以下参数:(1) 一个整数 axis_size 作为其第一个参数,(2) 一个布尔 pytree,其结构与函数的输入相同,指定是否对每个参数进行批处理,以及 (3) 批处理的参数。它应返回一个元组,其中包含批处理的输出和一个布尔 pytree,其结构与输出相同,指定是否对每个输出元素进行批处理。有关示例,请参阅 jax.custom_batching.custom_vmap() 的文档。

返回:

此方法传递规则,返回未更改的 vmap_rule

返回类型:

Callable[…, tuple[Any, Any]]