jax.custom_batching.custom_vmap.def_vmap#
- custom_vmap.def_vmap(vmap_rule)[源代码]#
为此 `custom_vmap` 函数定义 `vmap` 规则。
- 参数:
vmap_rule (Callable[..., tuple[Any, Any]]) – 实现 vmap 规则的函数。此函数应接受以下参数:(1) 第一个参数是一个整数
axis_size,(2) 一个与函数输入结构相同的布尔值 Pytree,指定每个参数是否是批量的,(3) 批量化的参数。它应返回一个元组,其中包含批量化的输出和一个与输出结构相同的布尔值 Pytree,指定每个输出元素是否是批量的。请参阅jax.custom_batching.custom_vmap()的文档以获取一些示例。- 返回:
此方法将规则传递下去,返回
vmap_rule而不作更改。- 返回类型:
Callable[…, tuple[Any, Any]]