jax.lax.composite#

jax.lax.composite(decomposition, name, version=0)[source]#

复合操作,其语义由分解函数定义。

复合操作是一种高阶 JAX 函数,它封装了由其他 JAX 函数组成(复合)的操作。op 的语义由 decomposition 函数实现。换句话说,定义的复合函数可以用其分解实现替换,而不会改变封装操作的语义。

编译器可以通过其 nameversionkwargs 和 dtypes 来识别特定的复合操作,从而发出更高效的代码,可能利用特定于硬件的指令或优化。如果编译器无法识别复合操作,它会回退到编译 decomposition 函数。

考虑一个“正切”复合操作。它的 decomposition 函数可以实现为 sin(x) / cos(x)。一个具有硬件意识的编译器可以识别“正切”复合操作,并发出单个 tangent 指令,而不是三个单独的指令(sindividecos)。对于没有专用正切支持的硬件,它会回退到编译分解。

这对于保留在降低过程中会丢失的高级抽象非常有用,从而可以在低级 IR 中更容易地进行模式匹配。

参数:
  • decomposition (Callable) – 实现复合 op 语义的函数。

  • name (str) – 封装操作的名称。

  • version (int) – 可选的 int,用于指示复合操作的语义更改。

返回:

返回一个复合函数。请注意,此函数的位置参数应解释为输入,关键字参数应解释为 op 的属性。任何将 None 作为值传递的关键字参数将从 composite_attributes 中省略。

返回类型:

可调用对象

示例

正切内核

>>> def my_tangent_composite(x):
...   return lax.composite(
...     lambda x: lax.sin(x) / lax.cos(x), name="my.tangent"
...   )(x)
>>>
>>> pi = jnp.pi
>>> x = jnp.array([0.0, pi / 4, 3 * pi / 4, pi])
>>> with jnp.printoptions(precision=3, suppress=True):
...   print(my_tangent_composite(x))
...   print(lax.tan(x))
[ 0.  1. -1.  0.]
[ 0.  1. -1.  0.]

创建复合操作的推荐方法是通过装饰器。在函数签名中使用 /* 来明确指定位置参数和关键字参数

>>> @partial(lax.composite, name="my.softmax")
... def my_softmax_composite(x, /, *, axis):
...   return jax.nn.softmax(x, axis)