jax.lax.composite#
- jax.lax.composite(decomposition, name, version=0)[source]#
复合操作,其语义由分解函数定义。
复合操作是一个高阶 JAX 函数,它封装了一个由其他 JAX 函数组成的(复合的)操作。操作的语义由
decomposition
函数实现。换句话说,定义的复合函数可以被其分解后的实现替换,而不会改变封装操作的语义。编译器可以通过其
name
、version
、kwargs
和 dtypes 识别特定的复合操作,从而发出更高效的代码,可能利用硬件特定的指令或优化。如果编译器无法识别复合操作,它将回退到编译decomposition
函数。考虑一个 “tangent” 复合操作。它的
decomposition
函数可以实现为sin(x) / cos(x)
。一个具有硬件感知能力的编译器可以识别 “tangent” 复合操作,并发出单个tangent
指令,而不是三个独立的指令 (sin
、divide
和cos
)。对于没有专用 tangent 支持的硬件,它将回退到编译分解。这对于保留在降低过程中可能会丢失的高级抽象非常有用,这使得在低级 IR 中更容易进行模式匹配。
- 参数:
- 返回:
返回一个复合函数。请注意,此函数的位置参数应解释为输入,而关键字参数应解释为操作的属性。任何以
None
值传递的关键字参数将从composite_attributes
中省略。- 返回类型:
Callable
示例
Tangent 内核
>>> def my_tangent_composite(x): ... return lax.composite( ... lambda x: lax.sin(x) / lax.cos(x), name="my.tangent" ... )(x) >>> >>> pi = jnp.pi >>> x = jnp.array([0.0, pi / 4, 3 * pi / 4, pi]) >>> with jnp.printoptions(precision=3, suppress=True): ... print(my_tangent_composite(x)) ... print(lax.tan(x)) [ 0. 1. -1. 0.] [ 0. 1. -1. 0.]
创建复合操作的推荐方法是通过装饰器。在函数签名中使用
/
和*
来显式声明位置参数和关键字参数>>> @partial(lax.composite, name="my.softmax") ... def my_softmax_composite(x, /, *, axis): ... return jax.nn.softmax(x, axis)