jax.tree_util.Partial#
- class jax.tree_util.Partial(func, *args, **kw)#
一个在 pytrees 中工作的 functools.partial 版本。
使用它进行部分函数求值,使其与 JAX 的转换兼容,例如,
Partial(func, *args, **kwargs)
。(您需要显式选择此行为,因为我们不想给 functools.partial 与普通函数闭包不同的语义。)
例如,这是一个以类似于
functools.partial
的方式使用Partial
的基本示例>>> import jax.numpy as jnp >>> add_one = Partial(jnp.add, 1) >>> add_one(2) Array(3, dtype=int32, weak_type=True)
Pytree 兼容性意味着生成的 partial 函数可以作为参数传递到转换后的 JAX 函数中,这对于标准的
functools.partial
函数是不可能的>>> from jax import jit >>> @jit ... def call_func(f, *args): ... return f(*args) ... >>> call_func(add_one, 2) Array(3, dtype=int32, weak_type=True)
将零个参数传递给
Partial
实际上会包装原始函数,使其成为 JAX 转换函数中的有效参数>>> call_func(Partial(jnp.add), 1, 2) Array(3, dtype=int32, weak_type=True)
如果我们直接将
jnp.add
传递给call_func
,则会导致TypeError
。请注意,如果在跟踪值的上下文中使用了
Partial
的结果,则在传递给部分求值函数时,会导致所有绑定参数都被跟踪>>> print_zero = Partial(print, 0) >>> print_zero() 0 >>> call_func(print_zero) JitTracer<~int32[]>
- __init__()#
方法
__init__
()属性
args
未来 partial 调用的参数元组
func
要在未来 partial 调用中使用的函数对象
keywords
未来 partial 调用的关键字参数字典