jax.custom_batching.custom_vmap.def_vmap#
- custom_vmap.def_vmap(vmap_rule)[source]#
为此 custom_vmap 函数定义 vmap 规则。
- 参数:
vmap_rule (Callable[..., tuple[Any, Any]]) – 一个实现 vmap 规则的函数。此函数应接受以下参数:(1) 一个整数
axis_size
作为其第一个参数,(2) 一个布尔 pytree,其结构与函数的输入相同,指定是否对每个参数进行批处理,以及 (3) 批处理的参数。它应返回一个元组,其中包含批处理的输出和一个布尔 pytree,其结构与输出相同,指定是否对每个输出元素进行批处理。有关示例,请参阅jax.custom_batching.custom_vmap()
的文档。- 返回:
此方法传递规则,返回未更改的
vmap_rule
。- 返回类型:
Callable[…, tuple[Any, Any]]