jax.custom_batching.custom_vmap.def_vmap#
- custom_vmap.def_vmap(vmap_rule)[source]#
为此 `custom_vmap` 函数定义 `vmap` 规则。
- 参数:
vmap_rule (Callable[..., tuple[Any, Any]]) – 一个实现 vmap 规则的函数。该函数应接受以下参数:(1) 一个整数
axis_size
作为其第一个参数,(2) 一个布尔值的 pytree,其结构与函数的输入相同,用于指定每个参数是否被批处理,以及 (3) 被批处理的参数。它应返回一个元组,其中包含批处理后的输出以及一个布尔值的 pytree,其结构与输出相同,用于指定每个输出元素是否被批处理。请参阅jax.custom_batching.custom_vmap()
的文档以获取一些示例。- 返回:
此方法将规则直接传递,并返回
vmap_rule
不变。- 返回类型:
Callable[…, tuple[Any, Any]]