jax.lax.composite#

jax.lax.composite(decomposition, name, version=0)[source]#

复合操作,其语义由分解函数定义。

复合操作是一个高阶 JAX 函数,它封装了一个由其他 JAX 函数组成的(复合的)操作。操作的语义由 decomposition 函数实现。换句话说,定义的复合函数可以被其分解后的实现替换,而不会改变封装操作的语义。

编译器可以通过其 nameversionkwargs 和 dtypes 识别特定的复合操作,从而发出更高效的代码,可能利用硬件特定的指令或优化。如果编译器无法识别复合操作,它将回退到编译 decomposition 函数。

考虑一个 “tangent” 复合操作。它的 decomposition 函数可以实现为 sin(x) / cos(x)。一个具有硬件感知能力的编译器可以识别 “tangent” 复合操作,并发出单个 tangent 指令,而不是三个独立的指令 (sindividecos)。对于没有专用 tangent 支持的硬件,它将回退到编译分解。

这对于保留在降低过程中可能会丢失的高级抽象非常有用,这使得在低级 IR 中更容易进行模式匹配。

参数:
  • decomposition (Callable) – 实现复合操作语义的函数。

  • name (str) – 封装操作的名称。

  • version (int) – 可选整数,用于指示复合操作的语义更改。

返回:

返回一个复合函数。请注意,此函数的位置参数应解释为输入,而关键字参数应解释为操作的属性。任何以 None 值传递的关键字参数将从 composite_attributes 中省略。

返回类型:

Callable

示例

Tangent 内核

>>> def my_tangent_composite(x):
...   return lax.composite(
...     lambda x: lax.sin(x) / lax.cos(x), name="my.tangent"
...   )(x)
>>>
>>> pi = jnp.pi
>>> x = jnp.array([0.0, pi / 4, 3 * pi / 4, pi])
>>> with jnp.printoptions(precision=3, suppress=True):
...   print(my_tangent_composite(x))
...   print(lax.tan(x))
[ 0.  1. -1.  0.]
[ 0.  1. -1.  0.]

创建复合操作的推荐方法是通过装饰器。在函数签名中使用 /* 来显式声明位置参数和关键字参数

>>> @partial(lax.composite, name="my.softmax")
... def my_softmax_composite(x, /, *, axis):
...   return jax.nn.softmax(x, axis)