jax.custom_batching.custom_vmap.def_vmap#

custom_vmap.def_vmap(vmap_rule)[源代码]#

为此 `custom_vmap` 函数定义 `vmap` 规则。

参数:

vmap_rule (Callable[..., tuple[Any, Any]]) – 实现 vmap 规则的函数。此函数应接受以下参数:(1) 第一个参数是一个整数 axis_size,(2) 一个与函数输入结构相同的布尔值 Pytree,指定每个参数是否是批量的,(3) 批量化的参数。它应返回一个元组,其中包含批量化的输出和一个与输出结构相同的布尔值 Pytree,指定每个输出元素是否是批量的。请参阅 jax.custom_batching.custom_vmap() 的文档以获取一些示例。

返回:

此方法将规则传递下去,返回 vmap_rule 而不作更改。

返回类型:

Callable[…, tuple[Any, Any]]