JAX Internals: primitives

Introduction to JAX primitives

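A JAX primitive is the basic computational unit of a JAX program. This document explains the interface that a JAX primitive must support to allow JAX to perform all its transformations (this is not a how-to guide).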

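For example, the multiply-add operation can be implemented in terms of the low-level jax.lax.* primitives (which are like XLA operator wrappers) or jax.extend.core.Primitive("multiply_add"), as demonstrated further below.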

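JAX is able to take sequences of such primitive operations and transform them via its composable transformations of Python functions, such as jax.jit(), jax.grad() and jax.vmap(). JAX implements these transformations in a JAX-traceable way. This means that when a Python function is executed, the only operations it applies to the data are either: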

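  • Inspections of data attributes: data information, such as shape or type; or

  • JAX primitives: these are the JAX special operations covered in this tutorial.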
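A minimal sketch of this distinction, using a hypothetical function f: jax.make_jaxpr records the operations that f applies to its data. The shape lookup runs as ordinary Python while JAX traces f and is not recorded, whereas the sine and the multiplication are recorded as primitives.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Data-attribute inspection: the shape is known even for abstract values,
    # so this line runs as plain Python and is not recorded in the jaxpr.
    n = x.shape[0]
    # Operations on the data itself are expressed with JAX primitives.
    return jnp.sin(x) * n

closed_jaxpr = jax.make_jaxpr(f)(jnp.ones(3))
print(closed_jaxpr)

# Top-level primitive names recorded while tracing f.
print([eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns])
```

In the printed jaxpr, the value taken from x.shape should appear only as a literal operand of the multiplication.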

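JAX primitives know how to operate on both concrete data values and abstract JAX values. A JAX-traceable function can be invoked by JAX with abstract arguments. For example, the JAX abstract value ShapedArray(float32[2,2]) captures the type and shape of values, but not the concrete data values.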
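Here is a short sketch of working with abstract values. The functions f and g are hypothetical examples. First, jax.eval_shape runs f on a jax.ShapeDtypeStruct, which has a shape and dtype but no data. Second, inside jax.jit the argument g receives is a tracer, and its aval attribute holds the abstract value.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.dot(x, x) + 1.0

# An abstract input: only shape and dtype, no data values.
spec = jax.ShapeDtypeStruct((2, 2), jnp.float32)
out = jax.eval_shape(f, spec)
print(out.shape, out.dtype)

# Inside a transformation the function receives a tracer; its `aval`
# attribute is the abstract value JAX is working with.
seen_avals = []

def g(x):
    seen_avals.append(x.aval)
    return x * 2.0

jax.jit(g)(jnp.ones((2, 2), jnp.float32))
print(seen_avals[0])
```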

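The JAX-transformed functions must themselves be JAX-traceable functions, to make sure that these transformations are composable, for example jax.jit(jax.jacfwd(jax.grad(f))).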
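Because each transformation returns a JAX-traceable function, transformations can be stacked. Here is a sketch with a hypothetical cubic f, whose second derivative is 6x:

```python
import jax

def f(x):
    return x ** 3  # f'(x) = 3 * x**2, f''(x) = 6 * x

# jax.grad returns a traceable function, so jacfwd and jit can be applied on top.
d2f = jax.jit(jax.jacfwd(jax.grad(f)))
print(d2f(3.0))  # 6 * 3 = 18
```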

JAX provides pre-defined primitives corresponding to most XLA operations, e.g., add, matmul, sin, cos, and indexing.
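The pre-defined primitives are Python objects that can be inspected and bound directly. Functions such as lax.add and lax.mul ultimately bind the add_p and mul_p primitives. The sketch below uses arbitrary input arrays:

```python
import jax
import jax.numpy as jnp
from jax import lax

x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([4.0, 5.0, 6.0])
z = jnp.array([0.5, 0.5, 0.5])

# Pre-defined primitives are objects with a name and a `bind` method.
print(lax.mul_p.name, lax.add_p.name)  # mul add

# lax.mul / lax.add end up binding these primitives, so both lines compute x * y + z.
via_wrappers = lax.add(lax.mul(x, y), z)
via_bind = lax.add_p.bind(lax.mul_p.bind(x, y), z)
print(via_wrappers, via_bind)

# Higher-level operations are also expressed as primitives: a matrix
# product shows up as a `dot_general` equation in the jaxpr.
jaxpr_text = str(jax.make_jaxpr(jnp.matmul)(jnp.ones((2, 3)), jnp.ones((3, 4))))
print(jaxpr_text)
```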

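In addition, JAX offers an implementation of NumPy functions in terms of JAX primitives. This means that Python programs using JAX's implementation of NumPy are JAX-traceable and, therefore, transformable. Other libraries can be made JAX-traceable by implementing them in terms of JAX primitives.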
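To illustrate, here is a hypothetical softplus function written twice: once with ordinary NumPy and once with jax.numpy. Only the jax.numpy version can be traced, so only that version works under jax.grad and jax.jit. The exact exception type raised for the NumPy version may vary between JAX versions, so the sketch catches any exception.

```python
import jax
import jax.numpy as jnp
import numpy as np

def softplus_np(x):
    # Ordinary NumPy only understands concrete values.
    return np.log1p(np.exp(x))

def softplus_jnp(x):
    # The same formula written with jax.numpy, i.e. in terms of JAX primitives.
    return jnp.log1p(jnp.exp(x))

print(softplus_np(0.0), softplus_jnp(0.0))  # both log(2), about 0.6931
print(jax.grad(softplus_jnp)(0.0))          # sigmoid(0) = 0.5

try:
    jax.jit(softplus_np)(0.0)
    numpy_version_traceable = True
except Exception as err:  # NumPy cannot turn an abstract tracer into an array
    numpy_version_traceable = False
    print(type(err).__name__)
```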

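Furthermore, the set of JAX primitives is extensible. Instead of reimplementing a function in terms of pre-defined JAX primitives, you can define a new primitive that encapsulates the behavior of the function.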

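Consider the following example: you want to add to JAX support for a multiply-add function with three arguments, defined mathematically as multiply_add(x, y, z) = x * y + z. This function operates on 3 identically-shaped tensors of floating point values and performs the operations pointwise. You can do this by: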

Using existing JAX primitives

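The easiest way to define new functions is to write them in terms of JAX primitives, or in terms of other functions that are themselves written using JAX primitives, for example, those defined in the jax.lax module: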

from jax import lax
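# `jax._src.api` is a private JAX module; the `api.grad`, `api.jit` and `api.jvp`
# used below are also available publicly as `jax.grad`, `jax.jit` and `jax.jvp`.
from jax._src import api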

def multiply_add_lax(x, y, z):
  """Implementation of multiply-add using the `jax.lax` primitives."""
  return lax.add(lax.mul(x, y), z)


def square_add_lax(a, b):
  """A square-add function using the newly defined multiply-add."""
  return multiply_add_lax(a, a, b)

print("square_add_lax = ", square_add_lax(2., 10.))
# Differentiate w.r.t. the first argument
print("grad(square_add_lax) = ", api.grad(square_add_lax, argnums=0)(2.0, 10.))
square_add_lax =  14.0
grad(square_add_lax) =  4.0
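Because multiply_add_lax is built only from jax.lax primitives, other transformations such as jax.vmap and jax.jit also apply without extra work. The sketch below repeats the definition so that the block runs on its own:

```python
import jax
import jax.numpy as jnp
from jax import lax

def multiply_add_lax(x, y, z):
    return lax.add(lax.mul(x, y), z)

xs = jnp.array([1.0, 2.0, 3.0])
ys = jnp.array([4.0, 5.0, 6.0])
zs = jnp.array([10.0, 20.0, 30.0])

# Map over the leading axis of all three arguments.
batched = jax.vmap(multiply_add_lax)(xs, ys, zs)
# Compile the same function with XLA.
compiled = jax.jit(multiply_add_lax)(2.0, 3.0, 4.0)
print(batched)   # values 14, 30, 48
print(compiled)  # 2 * 3 + 4 = 10
```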

To understand how JAX is internally using the primitives, add some helpers for tracing function calls:

#@title Helper functions (execute this cell)
import functools
import traceback

_indentation = 0
def _trace(msg=None):
    """Print a message at current indentation."""
    if msg is not None:
        print("  " * _indentation + msg)

def _trace_indent(msg=None):
    """Print a message and then indent the rest."""
    global _indentation
    _trace(msg)
    _indentation = 1 + _indentation

def _trace_unindent(msg=None):
    """Unindent then print a message."""
    global _indentation
    _indentation = _indentation - 1
    _trace(msg)

def trace(name):
  """A decorator for functions to trace arguments and results."""

  def trace_func(func):  # pylint: disable=missing-docstring
    def pp(v):
        """Print certain values more succinctly"""
        vtype = str(type(v))
        if "jax._src.xla_bridge._JaxComputationBuilder" in vtype:
            return "<JaxComputationBuilder>"
        elif "jaxlib.xla_extension.XlaOp" in vtype:
            return "<XlaOp at 0x{:x}>".format(id(v))
        elif ("partial_eval.JaxprTracer" in vtype or
              "batching.BatchTracer" in vtype or
              "ad.JVPTracer" in vtype):
            return "Traced<{}>".format(v.aval)
        elif isinstance(v, tuple):
            return "({})".format(pp_values(v))
        else:
            return str(v)
    def pp_values(args):
        return ", ".join([pp(arg) for arg in args])
    
    @functools.wraps(func)
    def func_wrapper(*args):
      _trace_indent("call {}({})".format(name, pp_values(args)))
      res = func(*args)
      _trace_unindent("|<- {} = {}".format(name, pp(res)))
      return res

    return func_wrapper

  return trace_func

class expectNotImplementedError(object):
  """Context manager to check for NotImplementedError."""
  def __enter__(self): pass
  def __exit__(self, type, value, tb):
    global _indentation
    _indentation = 0
    if type is NotImplementedError:
      print("\nFound expected exception:")
      traceback.print_exc(limit=3)
      return True
    elif type is None:  # No exception
      assert False, "Expected NotImplementedError"
    else:
      return False

Instead of using jax.lax primitives directly, you can use other functions that are already written in terms of those primitives, such as those in jax.numpy:

import jax.numpy as jnp
import numpy as np

@trace("multiply_add_numpy")
def multiply_add_numpy(x, y, z):
    return jnp.add(jnp.multiply(x, y), z)

@trace("square_add_numpy")
def square_add_numpy(a, b):
    return multiply_add_numpy(a, a, b)

print("\nNormal evaluation:")  
print("square_add_numpy = ", square_add_numpy(2., 10.))
print("\nGradient evaluation:")
print("grad(square_add_numpy) = ", api.grad(square_add_numpy)(2.0, 10.))
Normal evaluation:
call square_add_numpy(2.0, 10.0)
  call multiply_add_numpy(2.0, 2.0, 10.0)
  |<- multiply_add_numpy = 14.0
|<- square_add_numpy = 14.0
square_add_numpy =  14.0

Gradient evaluation:
call square_add_numpy(Traced<ShapedArray(float32[], weak_type=True)>, 10.0)
  call multiply_add_numpy(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, 10.0)
  |<- multiply_add_numpy = Traced<ShapedArray(float32[], weak_type=True)>
|<- square_add_numpy = Traced<ShapedArray(float32[], weak_type=True)>
grad(square_add_numpy) =  4.0

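Notice that in the process of computing jax.grad(), JAX invokes square_add_numpy and multiply_add_numpy with special traced arguments, printed as Traced<ShapedArray(float32[], weak_type=True)> in the output above. The second argument, which is not being differentiated, is still printed as the concrete value 10.0. It is important to remember that a JAX-traceable function must be able to operate not only on concrete arguments but also on the special abstract arguments that JAX may use to abstract the function execution.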
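One way to observe these special arguments is to record the type of each argument a function receives. The sketch below uses a hypothetical square_add that is not wrapped by the trace helper. The exact tracer class name (for example JVPTracer) depends on the JAX version, so the sketch relies only on the "Tracer" suffix.

```python
import jax

seen_types = []

def square_add(a, b):
    # Record the Python type of each argument JAX passes in.
    seen_types.append((type(a).__name__, type(b).__name__))
    return a * a + b

square_add(2.0, 10.0)                        # ordinary call with concrete floats
jax.grad(square_add, argnums=0)(2.0, 10.0)   # differentiate w.r.t. `a`
for entry in seen_types:
    print(entry)
```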

The JAX traceability property is satisfied as long as the function is written in terms of JAX primitives.
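The flip side is that operations which are not JAX primitives break traceability. A common example is Python control flow on a traced value. The sketch below defines two hypothetical ReLU variants. The Python `if` needs a concrete boolean and therefore fails under jax.jit. jnp.where is written with JAX primitives and traces fine.

```python
import jax
import jax.numpy as jnp

def relu_python_if(x):
    # A Python `if` needs a concrete boolean; an abstract tracer has none.
    if x > 0:
        return x
    return 0.0 * x

def relu_traceable(x):
    # jnp.where is itself written with JAX primitives, so it can be traced.
    return jnp.where(x > 0, x, 0.0)

print(relu_python_if(3.0))            # works: concrete Python value
print(jax.jit(relu_traceable)(-2.0))
print(jax.jit(relu_traceable)(3.0))

try:
    jax.jit(relu_python_if)(3.0)
    python_if_traceable = True
except jax.errors.ConcretizationTypeError as err:
    python_if_traceable = False
    print(type(err).__name__)
```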

Defining new JAX primitives

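The right way to add support for multiply-add is in terms of existing JAX primitives, as shown above. However, to demonstrate how JAX primitives work, pretend that you want to add a new primitive to JAX for the multiply-add functionality.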

from jax.extend import core

multiply_add_p = core.Primitive("multiply_add")  # Create the primitive

@trace("multiply_add_prim")
def multiply_add_prim(x, y, z):
  """The JAX-traceable way to use the JAX primitive.
  
  Note that the traced arguments must be passed as positional arguments
  to `bind`. 
  """
  return multiply_add_p.bind(x, y, z)

@trace("square_add_prim")
def square_add_prim(a, b):
  """A square-add function implemented using the new JAX-primitive."""
  return multiply_add_prim(a, a, b)

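If you try to call the newly defined functions, you'll get an error, because you haven't yet told JAX anything about the semantics of the new primitive.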

with expectNotImplementedError():
  square_add_prim(2., 10.)
call square_add_prim(2.0, 10.0)
  call multiply_add_prim(2.0, 2.0, 10.0)

Found expected exception:
Traceback (most recent call last):
  File "/tmp/ipykernel_1105/2844449444.py", line 2, in <module>
    square_add_prim(2., 10.)
  File "/tmp/ipykernel_1105/1393342955.py", line 48, in func_wrapper
    res = func(*args)
  File "/tmp/ipykernel_1105/2637569133.py", line 17, in square_add_prim
    return multiply_add_prim(a, a, b)
NotImplementedError: Evaluation rule for 'multiply_add' not implemented

Primal evaluation rules

@trace("multiply_add_impl")
def multiply_add_impl(x, y, z):
  """Concrete implementation of the primitive.

  This function does not need to be JAX traceable.

  Args:
    x, y, z: The concrete arguments of the primitive. Will only be called with 
      concrete values.

  Returns:
    the concrete result of the primitive.
  """
  # Note: you can use the ordinary (non-JAX) NumPy, which is not JAX-traceable.
  return np.add(np.multiply(x, y), z)

# Now, register the primal implementation with JAX:
multiply_add_p.def_impl(multiply_add_impl)
<function __main__.multiply_add_impl(x, y, z)>
assert square_add_prim(2., 10.) == 14.
call square_add_prim(2.0, 10.0)
  call multiply_add_prim(2.0, 2.0, 10.0)
    call multiply_add_impl(2.0, 2.0, 10.0)
    |<- multiply_add_impl = 14.0
  |<- multiply_add_prim = 14.0
|<- square_add_prim = 14.0

What happens when you use jit

Now, if you try to use jit, you'll get a NotImplementedError:

with expectNotImplementedError():
  api.jit(square_add_prim)(2., 10.)
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>)
  call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>)

Found expected exception:
Traceback (most recent call last):
  File "/tmp/ipykernel_1105/1813425700.py", line 2, in <module>
    api.jit(square_add_prim)(2., 10.)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 182, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py", line 340, in cache_miss
    pgle_profiler) = _python_pjit_helper(fun, jit_info, *args, **kwargs)
NotImplementedError: Abstract evaluation for 'multiply_add' not implemented

Abstract evaluation rules

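To JIT the function, and for other transformations as well, JAX first evaluates it abstractly using only the shape and type of the arguments. This abstract evaluation serves multiple purposes: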

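  • Gets the sequence of JAX primitives that are used in the computation. This sequence will be compiled.

  • Computes the shape and type of all vectors and operations used in the computation.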

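For example, the abstraction of a vector with 3 elements may be ShapedArray(float32[3]), or ConcreteArray([1., 2., 3.]). In the latter case, JAX uses the actual concrete value wrapped as an abstract value.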
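jax.make_jaxpr exposes both results of abstract evaluation: the sequence of primitives, and the abstract value (shape and dtype) of every intermediate. The sketch below uses a hypothetical function f and arbitrary input shapes:

```python
import jax
import jax.numpy as jnp
from jax import lax

def f(x, w):
    return lax.sin(lax.dot(x, w))

closed_jaxpr = jax.make_jaxpr(f)(jnp.ones((2, 3)), jnp.ones((3, 4)))
for eqn in closed_jaxpr.jaxpr.eqns:
    # One equation per primitive application; the output avals are the
    # shapes and dtypes found by abstract evaluation, without real data.
    print(eqn.primitive.name, [str(v.aval) for v in eqn.outvars])

out_aval = closed_jaxpr.out_avals[0]
print(out_aval)
```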

from jax import core

@trace("multiply_add_abstract_eval")
def multiply_add_abstract_eval(xs, ys, zs):
  """Abstract evaluation of the primitive.

  This function does not need to be JAX traceable. It will be invoked with
  abstractions of the actual arguments

  Args:
    xs, ys, zs: Abstractions of the arguments.

  Returns:
    a ShapedArray for the result of the primitive.
  """
  assert xs.shape == ys.shape
  assert xs.shape == zs.shape
  return core.ShapedArray(xs.shape, xs.dtype)

# Now, register the abstract evaluation with JAX:
multiply_add_p.def_abstract_eval(multiply_add_abstract_eval)
<function __main__.multiply_add_abstract_eval(xs, ys, zs)>
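Once an abstract evaluation rule is registered, tracing to a jaxpr already works, because building a jaxpr needs only shapes and dtypes. The error shown next comes from the later compilation step. The self-contained sketch below uses a separate hypothetical primitive named fma_demo, so it does not interact with multiply_add_p. It uses the same jax.core.ShapedArray as the code above.

```python
import jax
import jax.numpy as jnp
from jax import core as jax_core   # provides ShapedArray, as used above
from jax.extend import core        # provides Primitive

fma_p = core.Primitive("fma_demo")  # hypothetical primitive, separate from multiply_add_p

@fma_p.def_abstract_eval
def _fma_abstract_eval(xs, ys, zs):
    # Only abstract values (shape and dtype) are available here.
    assert xs.shape == ys.shape == zs.shape
    return jax_core.ShapedArray(xs.shape, xs.dtype)

def fma(x, y, z):
    return fma_p.bind(x, y, z)

x = jnp.ones((3,), jnp.float32)
# No impl and no lowering rule are registered: tracing only needs abstract eval.
closed_jaxpr = jax.make_jaxpr(fma)(x, x, x)
print(closed_jaxpr)
out_aval = closed_jaxpr.out_avals[0]
print(out_aval)
```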

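If you re-attempt to apply jit, you can inspect how the abstract evaluation proceeds, but you'll get another error, about missing the actual XLA compilation rule: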

with expectNotImplementedError():
  api.jit(square_add_prim)(2., 10.)
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>)
  call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>)
    call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
    |<- multiply_add_abstract_eval = ShapedArray(float32[])
  |<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace>
|<- square_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace>

Found expected exception:
Traceback (most recent call last):
  File "/home/docs/.asdf/installs/python/3.10.15/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/docs/.asdf/installs/python/3.10.15/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
jax._src.source_info_util.JaxStackTraceBeforeTransformation: NotImplementedError: MLIR translation rule for primitive 'multiply_add' not found for platform cpu

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/tmp/ipykernel_1105/1813425700.py", line 2, in <module>
    api.jit(square_add_prim)(2., 10.)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 182, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py", line 340, in cache_miss
    pgle_profiler) = _python_pjit_helper(fun, jit_info, *args, **kwargs)
NotImplementedError: MLIR translation rule for primitive 'multiply_add' not found for platform cpu

XLA Compilation rules

JAX compilation works by compiling each primitive into a graph of XLA operations.

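This is the biggest hurdle to adding new functionality to JAX, because the set of XLA operations is limited, and JAX already has pre-defined primitives for most of them. However, XLA includes a CustomCall operation that can be used to encapsulate arbitrary functionality defined using C++.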

from jax._src.lib.mlir.dialects import hlo

@trace("multiply_add_lowering")
def multiply_add_lowering(ctx, xc, yc, zc):
  """The compilation to XLA of the primitive.

  Given an mlir.ir.Value for each argument, return the mlir.ir.Values for
  the results of the function.

  Does not need to be a JAX-traceable function.
  """
  return [hlo.AddOp(hlo.MulOp(xc, yc), zc).result]

# Now, register the lowering rule with JAX.
# For GPU, see https://jax.net.cn/en/latest/Custom_Operation_for_GPUs.html
from jax.interpreters import mlir

mlir.register_lowering(multiply_add_p, multiply_add_lowering, platform='cpu')
<function __main__.multiply_add_lowering(ctx, xc, yc, zc)>

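You will now succeed in applying jax.jit. Notice below that JAX first evaluates the function abstractly, which triggers the multiply_add_abstract_eval function, and then compiles the set of primitives it has encountered, including multiply_add. At this point, JAX invokes multiply_add_lowering.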

assert api.jit(lambda x, y: square_add_prim(x, y))(2., 10.) == 14.
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>)
  call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>)
    call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
    |<- multiply_add_abstract_eval = ShapedArray(float32[])
  |<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace>
|<- square_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace>
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(context=<jax._src.interpreters.mlir.JaxIrContext object at 0x7370ec3deac0>, module=<jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x7370ec2041d0>, ip=<jaxlib.mlir._mlir_libs._mlir.ir.InsertionPoint object at 0x7370ec204210>, symbol_table=<jaxlib.mlir._mlir_libs._mlir.ir.SymbolTable object at 0x7370ecada790>, platforms=('cpu',), backend=<jaxlib.xla_extension.Client object at 0x7370ee10ecf0>, axis_context=ShardingContext(num_devices=1, device_assignment=None, abstract_mesh=None), keepalives=[], channel_iterator=count(1), host_callbacks=[], shape_poly_state=<jax._src.interpreters.mlir.ShapePolyLoweringState object at 0x7370ecadbf10>, all_default_mem_kind=True, cached_primitive_lowerings={}, traceback_caches=TracebackCaches(traceback_cache={<jaxlib.xla_extension.Traceback object at 0x5f204d642280>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_1105/2637569133.py":12:0 to 0:0) at callsite("func_wrapper"("/tmp/ipykernel_1105/1393342955.py":48:0 to 0:0) at callsite("square_add_prim"("/tmp/ipykernel_1105/2637569133.py":17:0 to 0:0) at callsite("func_wrapper"("/tmp/ipykernel_1105/1393342955.py":48:0 to 0:0) at callsite("<lambda>"("/tmp/ipykernel_1105/1570919344.py":1:0 to 0:0) at callsite("<module>"("/tmp/ipykernel_1105/1570919344.py":1:0 to 0:0) at callsite("run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3579:0 to 0:0) at callsite("run_ast_nodes"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3519:0 to 0:0) at callsite("run_cell_async"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3336:0 to 0:0) at 
"_pseudo_sync_runner"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py":128:0 to 0:0)))))))))))}, location_cache={(<code object multiply_add_prim at 0x7370ec93e080, file "/tmp/ipykernel_1105/2637569133.py", line 5>, 10): loc("multiply_add_prim"("/tmp/ipykernel_1105/2637569133.py":12:0 to 0:0)), (<code object func_wrapper at 0x7370ecad6970, file "/tmp/ipykernel_1105/1393342955.py", line 45>, 24): loc("func_wrapper"("/tmp/ipykernel_1105/1393342955.py":48:0 to 0:0)), (<code object square_add_prim at 0x7370ec93ead0, file "/tmp/ipykernel_1105/2637569133.py", line 14>, 8): loc("square_add_prim"("/tmp/ipykernel_1105/2637569133.py":17:0 to 0:0)), (<code object <lambda> at 0x7370ec976a20, file "/tmp/ipykernel_1105/1570919344.py", line 1>, 6): loc("<lambda>"("/tmp/ipykernel_1105/1570919344.py":1:0 to 0:0)), (<code object <module> at 0x7370ec975420, file "/tmp/ipykernel_1105/1570919344.py", line 1>, 16): loc("<module>"("/tmp/ipykernel_1105/1570919344.py":1:0 to 0:0)), (<code object run_code at 0x7370fe891a50, file "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3543>, 76): loc("run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3579:0 to 0:0)), (<code object run_ast_nodes at 0x7370fe8918f0, file "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3420>, 500): loc("run_ast_nodes"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3519:0 to 0:0)), (<code object run_cell_async at 0x7370fe891580, file "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3185>, 828): 
loc("run_cell_async"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3336:0 to 0:0)), (<code object _pseudo_sync_runner at 0x7370fe758190, file "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py", line 119>, 8): loc("_pseudo_sync_runner"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py":128:0 to 0:0))}, canonical_name_cache={'/tmp/ipykernel_1105/2637569133.py': '/tmp/ipykernel_1105/2637569133.py', '/tmp/ipykernel_1105/1393342955.py': '/tmp/ipykernel_1105/1393342955.py', '/tmp/ipykernel_1105/1570919344.py': '/tmp/ipykernel_1105/1570919344.py', '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py', '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py': '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py'}, is_user_file_cache={'/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/source_info_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/core.py': False, '/tmp/ipykernel_1105/2637569133.py': True, '/tmp/ipykernel_1105/1393342955.py': True, '/tmp/ipykernel_1105/1570919344.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/linear_util.py': False, 
'/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/profiler.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py': True}), lowering_parameters=LoweringParameters(override_lowering_rules=None, global_constant_computation=False, for_export=False, export_ignore_forward_compatibility=False)), name_stack=NameStack(stack=(Scope(name='jit(<lambda>)'), Scope(name='jit(main)'))), primitive=multiply_add, avals_in=[ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)], avals_out=[ShapedArray(float32[])], tokens_in=<jax._src.interpreters.mlir.TokenSet object at 0x7370ec200730>, tokens_out=None, axis_size_env=None, dim_var_values=[], jaxpr_eqn_ctx=JaxprEqnContext(compute_type=None, threefry_partitionable=True, cur_abstract_mesh=AbstractMesh((), axis_types=()), xla_metadata=None), platforms=None), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 1))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7370ec3d4bf0>]
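To see what a lowering rule emits, you can lower a jitted function without running it and print the resulting module with jax.jit(...).lower(...).as_text(). The self-contained sketch below defines a hypothetical primitive, fma_lower_demo. It reuses the same hlo.AddOp/hlo.MulOp lowering, and the same private jax._src.lib.mlir.dialects import, as this document.

```python
import jax
from jax import core as jax_core
from jax.extend import core
from jax.interpreters import mlir
from jax._src.lib.mlir.dialects import hlo  # private module, as used in this document

fma_lower_p = core.Primitive("fma_lower_demo")  # hypothetical primitive name

@fma_lower_p.def_abstract_eval
def _abstract_eval(xs, ys, zs):
    return jax_core.ShapedArray(xs.shape, xs.dtype)

def _lowering(ctx, xc, yc, zc):
    # Emit x * y + z as two StableHLO operations.
    return [hlo.AddOp(hlo.MulOp(xc, yc), zc).result]

# No `platform` argument: register the rule for every backend.
mlir.register_lowering(fma_lower_p, _lowering)

def square_add(a, b):
    return fma_lower_p.bind(a, a, b)

module_text = jax.jit(square_add).lower(2.0, 10.0).as_text()
print(module_text)  # should contain stablehlo.multiply and stablehlo.add
result = jax.jit(square_add)(2.0, 10.0)
print(result)       # 2 * 2 + 10 = 14
```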

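Below is another use of jit, where you compile only with respect to the first argument: static_argnums=1 makes the second argument static. As the first trace line shows, the second argument to square_add_prim is therefore the concrete value 10.0. In the output below, multiply_add_abstract_eval still receives a ShapedArray for its third argument. The concrete value only appears at lowering time, as the stablehlo.constant operand passed to multiply_add_lowering. Because multiply_add_abstract_eval reads only the shape and dtype of its inputs, it works whether or not the actual values are known.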

assert api.jit(lambda x, y: square_add_prim(x, y), 
               static_argnums=1)(2., 10.) == 14.
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, 10.0)
  call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, 10.0)
    call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
    |<- multiply_add_abstract_eval = ShapedArray(float32[])
  |<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace>
|<- square_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace>
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(context=<jax._src.interpreters.mlir.JaxIrContext object at 0x7370ec3df1c0>, module=<jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x7370ec204db0>, ip=<jaxlib.mlir._mlir_libs._mlir.ir.InsertionPoint object at 0x7370ec204c60>, symbol_table=<jaxlib.mlir._mlir_libs._mlir.ir.SymbolTable object at 0x7370ec2010b0>, platforms=('cpu',), backend=<jaxlib.xla_extension.Client object at 0x7370ee10ecf0>, axis_context=ShardingContext(num_devices=1, device_assignment=None, abstract_mesh=None), keepalives=[], channel_iterator=count(1), host_callbacks=[], shape_poly_state=<jax._src.interpreters.mlir.ShapePolyLoweringState object at 0x7370ec201150>, all_default_mem_kind=True, cached_primitive_lowerings={}, traceback_caches=TracebackCaches(traceback_cache={<jaxlib.xla_extension.Traceback object at 0x5f204d6426e0>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_1105/2637569133.py":12:0 to 0:0) at callsite("func_wrapper"("/tmp/ipykernel_1105/1393342955.py":48:0 to 0:0) at callsite("square_add_prim"("/tmp/ipykernel_1105/2637569133.py":17:0 to 0:0) at callsite("func_wrapper"("/tmp/ipykernel_1105/1393342955.py":48:0 to 0:0) at callsite("<lambda>"("/tmp/ipykernel_1105/4165789807.py":1:0 to 0:0) at callsite("<module>"("/tmp/ipykernel_1105/4165789807.py":1:0 to 0:0) at callsite("run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3579:0 to 0:0) at callsite("run_ast_nodes"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3519:0 to 0:0) at callsite("run_cell_async"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3336:0 to 0:0) at 
"_pseudo_sync_runner"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py":128:0 to 0:0)))))))))))}, location_cache={(<code object multiply_add_prim at 0x7370ec93e080, file "/tmp/ipykernel_1105/2637569133.py", line 5>, 10): loc("multiply_add_prim"("/tmp/ipykernel_1105/2637569133.py":12:0 to 0:0)), (<code object func_wrapper at 0x7370ecad6970, file "/tmp/ipykernel_1105/1393342955.py", line 45>, 24): loc("func_wrapper"("/tmp/ipykernel_1105/1393342955.py":48:0 to 0:0)), (<code object square_add_prim at 0x7370ec93ead0, file "/tmp/ipykernel_1105/2637569133.py", line 14>, 8): loc("square_add_prim"("/tmp/ipykernel_1105/2637569133.py":17:0 to 0:0)), (<code object <lambda> at 0x7370ec976550, file "/tmp/ipykernel_1105/4165789807.py", line 1>, 6): loc("<lambda>"("/tmp/ipykernel_1105/4165789807.py":1:0 to 0:0)), (<code object <module> at 0x7370ec9764a0, file "/tmp/ipykernel_1105/4165789807.py", line 1>, 20): loc("<module>"("/tmp/ipykernel_1105/4165789807.py":1:0 to 0:0)), (<code object run_code at 0x7370fe891a50, file "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3543>, 76): loc("run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3579:0 to 0:0)), (<code object run_ast_nodes at 0x7370fe8918f0, file "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3420>, 500): loc("run_ast_nodes"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3519:0 to 0:0)), (<code object run_cell_async at 0x7370fe891580, file "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3185>, 828): 
loc("run_cell_async"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3336:0 to 0:0)), (<code object _pseudo_sync_runner at 0x7370fe758190, file "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py", line 119>, 8): loc("_pseudo_sync_runner"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py":128:0 to 0:0))}, canonical_name_cache={'/tmp/ipykernel_1105/2637569133.py': '/tmp/ipykernel_1105/2637569133.py', '/tmp/ipykernel_1105/1393342955.py': '/tmp/ipykernel_1105/1393342955.py', '/tmp/ipykernel_1105/4165789807.py': '/tmp/ipykernel_1105/4165789807.py', '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py', '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py': '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py'}, is_user_file_cache={'/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/source_info_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/core.py': False, '/tmp/ipykernel_1105/2637569133.py': True, '/tmp/ipykernel_1105/1393342955.py': True, '/tmp/ipykernel_1105/4165789807.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/linear_util.py': False, 
'/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/profiler.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py': True}), lowering_parameters=LoweringParameters(override_lowering_rules=None, global_constant_computation=False, for_export=False, export_ignore_forward_compatibility=False)), name_stack=NameStack(stack=(Scope(name='jit(<lambda>)'), Scope(name='jit(main)'))), primitive=multiply_add, avals_in=[ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)], avals_out=[ShapedArray(float32[])], tokens_in=<jax._src.interpreters.mlir.TokenSet object at 0x7370ec201a50>, tokens_out=None, axis_size_env=None, dim_var_values=[], jaxpr_eqn_ctx=JaxprEqnContext(compute_type=None, threefry_partitionable=True, cur_abstract_mesh=AbstractMesh((), axis_types=()), xla_metadata=None), platforms=None), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(%0 = "stablehlo.constant"() <{value = dense<1.000000e+01> : tensor<f32>}> : () -> tensor<f32>))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7370ec210570>]
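Static arguments are ordinary Python values at trace time, so jax.jit compiles a separate version for each distinct static value. The sketch below uses a hypothetical square_add that logs each trace. Calls that share the static value reuse the cached trace, while a new static value triggers retracing.

```python
import jax

trace_log = []  # one entry per trace of square_add

def square_add(a, b):
    # This Python code runs only while JAX traces the function.
    trace_log.append(type(b).__name__)
    return a * a + b

f = jax.jit(square_add, static_argnums=1)
f(2.0, 10.0)   # first call: trace and compile
f(3.0, 10.0)   # same static value and same abstract `a`: cached, no retrace
f(2.0, 11.0)   # new static value: traced again
print(len(trace_log), trace_log)  # 2 ['float', 'float']
```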

Forward differentiation

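JAX implements forward differentiation in the form of a Jacobian-Vector Product (JVP) (you can learn more about it in Advanced automatic differentiation).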
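Before you write a JVP rule for the new primitive, you can check what tangent it should produce. Let JAX differentiate the lax-based multiply_add_lax from earlier, whose lax primitives already have JVP rules. The sketch below repeats the definitions so that it runs on its own:

```python
import jax
from jax import lax

def multiply_add_lax(x, y, z):
    return lax.add(lax.mul(x, y), z)

def square_add_lax(a, b):
    return multiply_add_lax(a, a, b)

# Evaluate at (a, b) = (2, 10) with tangents (da, db) = (1, 1).
primal_out, tangent_out = jax.jvp(square_add_lax, (2.0, 10.0), (1.0, 1.0))
print(primal_out, tangent_out)

# Product rule for x * y + z with x = y = a:
# xt*y + x*yt + zt = 1*2 + 2*1 + 1 = 5
```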

If you attempt to compute the jvp function, you'll get an error because you have not yet told JAX how to differentiate the multiply_add primitive.

# The second argument is set to `(2., 10.)` values where you
# evaluate the Jacobian, and the third argument `(1., 1.)`
# contains the values of the tangents for the arguments.
with expectNotImplementedError():
  api.jvp(square_add_prim, (2., 10.), (1., 1.))
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>)
  call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>)

Found expected exception:
Traceback (most recent call last):
  File "/tmp/ipykernel_1105/459539105.py", line 5, in <module>
    api.jvp(square_add_prim, (2., 10.), (1., 1.))
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 182, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py", line 1790, in jvp
    return _jvp(lu.wrap_init(fun, debug_info=debug_info("jvp", fun, primals, {})),
NotImplementedError: Differentiation rule for 'multiply_add' not implemented

To fix this, write a JVP rule for the primitive and register it with JAX:

from jax.interpreters import ad

@trace("multiply_add_value_and_jvp")
def multiply_add_value_and_jvp(arg_values, arg_tangents):
  """Evaluates the primal output and the tangents (Jacobian-vector product).

  Given values of the arguments and perturbation of the arguments (tangents), 
  compute the output of the primitive and the perturbation of the output.

  This method must be JAX-traceable. JAX may invoke it with abstract values 
  for the arguments and tangents.

  Args:
    arg_values: A tuple of arguments
    arg_tangents: A tuple with the tangents of the arguments. The tuple has 
      the same length as the arg_values. Some of the tangents may also be the 
      special value `ad.Zero` to specify a zero tangent

  Returns:
     A pair of the primal output and the tangent.
  """
  x, y, z = arg_values
  xt, yt, zt = arg_tangents
  _trace("Primal evaluation:")
  # Now, you have a JAX-traceable computation of the output. 
  # Normally, you can use the multiply add (`ma`) primitive itself to compute the primal output. 
  primal_out = multiply_add_prim(x, y, z)

  _trace("Tangent evaluation:")
  # You must use a JAX-traceable way to compute the tangent. It turns out that 
  # the output tangent can be computed as (xt * y + x * yt + zt),
  # which you can implement in a JAX-traceable way using the same "multiply_add_prim" primitive.

  # You do need to deal specially with `Zero`. Here, you just turn it into a 
  # proper tensor of 0s (of the same shape as 'x'). 
  # An alternative would be to check for `Zero` and perform algebraic 
  # simplification of the output tangent computation.
  def make_zero(tan):
    return lax.zeros_like_array(x) if type(tan) is ad.Zero else tan  

  output_tangent = multiply_add_prim(make_zero(xt), y, multiply_add_prim(x, make_zero(yt), make_zero(zt)))
  return (primal_out, output_tangent)

# Register the forward differentiation rule with JAX:
ad.primitive_jvps[multiply_add_p] = multiply_add_value_and_jvp
# Tangent is: xt*y + x*yt + zt = 1.*2. + 2.*1. + 1. = 5.
assert api.jvp(square_add_prim, (2., 10.), (1., 1.)) == (14., 5.)
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>)
  call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>)
    call multiply_add_value_and_jvp((2.0, 2.0, 10.0), (1.0, 1.0, 1.0))
      Primal evaluation:
      call multiply_add_prim(2.0, 2.0, 10.0)
        call multiply_add_impl(2.0, 2.0, 10.0)
        |<- multiply_add_impl = 14.0
      |<- multiply_add_prim = 14.0
      Tangent evaluation:
      call multiply_add_prim(2.0, 1.0, 1.0)
        call multiply_add_impl(2.0, 1.0, 1.0)
        |<- multiply_add_impl = 3.0
      |<- multiply_add_prim = 3.0
      call multiply_add_prim(1.0, 2.0, 3.0)
        call multiply_add_impl(1.0, 2.0, 3.0)
        |<- multiply_add_impl = 5.0
      |<- multiply_add_prim = 5.0
    |<- multiply_add_value_and_jvp = (14.0, 5.0)
  |<- multiply_add_prim = Traced<ShapedArray(float32[])>
|<- square_add_prim = Traced<ShapedArray(float32[])>
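As a cross-check of the tangent formula xt * y + x * yt + zt used in multiply_add_value_and_jvp, the sketch below applies the public jax.jvp to a lax-based multiply-add. The function is redefined here so the block runs on its own. It does not involve the custom multiply_add_prim, only the standard jax.jvp API, and it uses the same point and tangents that square_add_prim passes to the primitive in the trace above.

```python
import jax
from jax import lax

def multiply_add_lax(x, y, z):
  """Multiply-add built from existing lax primitives (same as earlier)."""
  return lax.add(lax.mul(x, y), z)

# Primals (x, y, z) = (2., 2., 10.) and tangents (xt, yt, zt) = (1., 1., 1.),
# matching the multiply_add_value_and_jvp call in the trace above.
primal_out, tangent_out = jax.jvp(multiply_add_lax, (2., 2., 10.), (1., 1., 1.))

# Expected tangent: xt*y + x*yt + zt = 1*2 + 2*1 + 1 = 5.
print(primal_out, tangent_out)
```

Agreement with (14., 5.) suggests the hand-written JVP rule computes the same derivative as JAX's built-in rules for mul and add.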

JIT of forward differentiation

You can apply jit to the forward differentiation function:

assert api.jit(lambda arg_values, arg_tangents: 
                   api.jvp(square_add_prim, arg_values, arg_tangents))(
         (2., 10.), (1., 1.)) == (14., 5.)
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>)
  call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>)
    call multiply_add_value_and_jvp((Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>), (Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>))
      Primal evaluation:
      call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>)
        call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
        |<- multiply_add_abstract_eval = ShapedArray(float32[])
      |<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace>
      Tangent evaluation:
      call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>)
        call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
        |<- multiply_add_abstract_eval = ShapedArray(float32[])
      |<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace>
      call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[])>with<DynamicJaxprTrace>)
        call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[]))
        |<- multiply_add_abstract_eval = ShapedArray(float32[])
      |<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace>
    |<- multiply_add_value_and_jvp = (Traced<ShapedArray(float32[])>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[])>with<DynamicJaxprTrace>)
  |<- multiply_add_prim = Traced<ShapedArray(float32[])>
|<- square_add_prim = Traced<ShapedArray(float32[])>
call multiply_add_lowering(LoweringRuleContext(..., primitive=multiply_add, avals_in=[ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)], avals_out=[ShapedArray(float32[])], ...), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 1))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7370ec25cb70>]
call multiply_add_lowering(LoweringRuleContext(..., primitive=multiply_add, avals_in=[ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)], avals_out=[ShapedArray(float32[])], ...), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 2), Value(<block argument> of type 'tensor<f32>' at index: 3))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7370ec25cc70>]
call multiply_add_lowering(LoweringRuleContext(..., primitive=multiply_add, avals_in=[ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[])], avals_out=[ShapedArray(float32[])], ...), Value(<block argument> of type 'tensor<f32>' at index: 2), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(%3 = "stablehlo.add"(%2, %arg3) : (tensor<f32>, tensor<f32>) -> tensor<f32>))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7370ec25ccb0>]

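Notice that you first evaluate multiply_add_value_and_jvp abstractly, which in turn abstractly evaluates both the primal and the tangent computations (a total of 3 invocations of the ma primitive). Then you compile the 3 occurrences of the primitive; these are the three multiply_add_lowering calls in the output above.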
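To see the kind of computation that gets traced and then compiled, the sketch below uses the public jax.make_jaxpr on jax.jvp of a lax-based multiply-add. This is an assumption-laden stand-in: it uses lax.mul and lax.add rather than the custom primitive, so the recorded equations are JAX's built-in mul/add primitives instead of three multiply_add occurrences. The exact printed jaxpr may differ between JAX versions.

```python
import jax
from jax import lax

def multiply_add_lax(x, y, z):
  return lax.add(lax.mul(x, y), z)

def jvp_fn(primals, tangents):
  # Forward-mode differentiation: returns (primal_out, tangent_out).
  return jax.jvp(multiply_add_lax, primals, tangents)

# Trace with abstract values only; the jaxpr records both the primal
# and the tangent computations.
closed_jaxpr = jax.make_jaxpr(jvp_fn)((2., 2., 10.), (1., 1., 1.))
print(closed_jaxpr)

# Names of the primitives recorded in the traced computation.
prim_names = [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]

# jit traces the same computation, compiles it, and gives the eager result.
jit_out = jax.jit(jvp_fn)((2., 2., 10.), (1., 1., 1.))
print(jit_out)
```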

Reverse differentiation

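If you now try to use reverse differentiation, you will notice that JAX starts by using multiply_add_value_and_jvp to compute the forward differentiation for abstract values, but then runs into a NotImplementedError.

When computing reverse differentiation, JAX first performs an abstract evaluation of the forward differentiation code multiply_add_value_and_jvp to obtain a trace of the primitives that compute the output tangent.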

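  • Observe that JAX performs this abstract evaluation with concrete values for the differentiation point and abstract values for the tangents.

  • Note that JAX uses the special abstract tangent value Zero for the tangent corresponding to the third argument of ma. This reflects the fact that you do not differentiate with respect to the second argument of square_add_prim, which flows into the third argument of multiply_add_prim.

  • Note also that during the abstract evaluation of the tangent, the value 0.0 is passed as the tangent for the third argument. This is due to the make_zero function used in the definition of multiply_add_value_and_jvp.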
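The two ways of handling a Zero tangent mentioned in multiply_add_value_and_jvp (materializing it as 0s, or simplifying the tangent expression) can be sketched with plain Python floats. SymbolicZero, make_zero, ma_jvp_instantiate and ma_jvp_simplify are hypothetical names used only for illustration; this is not JAX's ad.Zero machinery. Concrete 1.0 tangents stand in for the traced tangents shown in the output above.

```python
class SymbolicZero:
  """Hypothetical stand-in for JAX's ad.Zero: a tangent known to be zero."""

def make_zero(tan, like):
  # Materialize a symbolic zero as a concrete 0, playing the role of
  # lax.zeros_like_array in the primitive's JVP rule.
  return 0.0 * like if isinstance(tan, SymbolicZero) else tan

def multiply_add(x, y, z):
  return x * y + z

def ma_jvp_instantiate(primals, tangents):
  """JVP that turns every symbolic zero into an explicit 0.0."""
  x, y, z = primals
  xt, yt, zt = (make_zero(t, x) for t in tangents)
  return multiply_add(x, y, z), multiply_add(xt, y, multiply_add(x, yt, zt))

def ma_jvp_simplify(primals, tangents):
  """Alternative: drop the terms whose tangent is a symbolic zero."""
  x, y, z = primals
  xt, yt, zt = tangents
  terms = []
  if not isinstance(xt, SymbolicZero):
    terms.append(xt * y)
  if not isinstance(yt, SymbolicZero):
    terms.append(x * yt)
  if not isinstance(zt, SymbolicZero):
    terms.append(zt)
  return multiply_add(x, y, z), (sum(terms) if terms else SymbolicZero())

# As in the trace: the tangent of the third argument (b in square_add_prim) is zero.
primals = (2.0, 2.0, 10.0)
tangents = (1.0, 1.0, SymbolicZero())
print(ma_jvp_instantiate(primals, tangents))  # (14.0, 4.0)
print(ma_jvp_simplify(primals, tangents))     # (14.0, 4.0)
```

Materializing zeros keeps the rule short at the cost of a few useless operations; the simplifying variant avoids them but needs a branch per argument.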

# This is reverse differentiation w.r.t. the first argument of `square_add_prim`
with expectNotImplementedError():
  api.grad(square_add_prim)(2., 10.)
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, 10.0)
  call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, 10.0)
    call multiply_add_value_and_jvp((2.0, 2.0, 10.0), (Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Zero(ShapedArray(float32[], weak_type=True))))
      Primal evaluation:
      call multiply_add_prim(2.0, 2.0, 10.0)
        call multiply_add_impl(2.0, 2.0, 10.0)
        |<- multiply_add_impl = 14.0
      |<- multiply_add_prim = 14.0
      Tangent evaluation:
      call multiply_add_prim(2.0, Traced<ShapedArray(float32[], weak_type=True)>, 0.0)
        call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
        |<- multiply_add_abstract_eval = ShapedArray(float32[])
      |<- multiply_add_prim = Traced<ShapedArray(float32[])>
      call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, 2.0, Traced<ShapedArray(float32[])>)
        call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[]))
        |<- multiply_add_abstract_eval = ShapedArray(float32[])
      |<- multiply_add_prim = Traced<ShapedArray(float32[])>
    |<- multiply_add_value_and_jvp = (14.0, Traced<ShapedArray(float32[])>)
  |<- multiply_add_prim = Traced<ShapedArray(float32[])>
|<- square_add_prim = Traced<ShapedArray(float32[])>

Found expected exception:
Traceback (most recent call last):
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/ad.py", line 456, in get_primitive_transpose
    return primitive_transposes[p]
KeyError: multiply_add

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/docs/.asdf/installs/python/3.10.15/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/docs/.asdf/installs/python/3.10.15/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
jax._src.source_info_util.JaxStackTraceBeforeTransformation: NotImplementedError: Transpose rule (for reverse-mode differentiation) for 'multiply_add' not implemented

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/tmp/ipykernel_1105/2155094905.py", line 3, in <module>
    api.grad(square_add_prim)(2., 10.)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 182, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py", line 437, in grad_f
    _, g = value_and_grad_f(*args, **kwargs)
NotImplementedError: Transpose rule (for reverse-mode differentiation) for 'multiply_add' not implemented

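The error above occurs because JAX is missing one piece it needs to use the forward differentiation code to compute reverse differentiation: a transpose rule for multiply_add.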

Transposition

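As explained above, when computing reverse differentiation, JAX obtains a trace of the primitives that compute the tangent using forward differentiation. JAX then interprets this trace abstractly, backwards, and applies a transposition rule to each primitive.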

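To understand what is going on, consider a simpler example: the function f(x, y) = x * y + y. Suppose you need to differentiate it at the point (2., 4.). JAX will produce the following JVP tangent calculation of ft from the tangents xt and yt of the inputs: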

   a = xt * 4.
   b = 2. * yt
   c = a + b
   ft = c + yt

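By construction, the tangent calculation is always linear in the input tangents. The only non-linear operator that may arise in the tangent calculation is multiplication, and then one of the operands is a constant.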
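The tangent calculation above, and its linearity, can be checked with the public jax.linearize, which returns the primal output together with the linear tangent map at a given point. The combination coefficients 3. and -2. are arbitrary illustrative values.

```python
import jax

def f(x, y):
  return x * y + y

# Linearize f at (2., 4.): f_jvp maps input tangents (xt, yt) to ft.
f_out, f_jvp = jax.linearize(f, 2., 4.)

# From the tangent calculation above: ft = xt * 4. + 2. * yt + yt.
ft_x = f_jvp(1., 0.)   # 4.0
ft_y = f_jvp(0., 1.)   # 3.0

# Linearity: the map applied to a combination of tangents equals
# the same combination of the individual results.
ft_combo = f_jvp(3., -2.)
print(f_out, ft_x, ft_y, ft_combo)
```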

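JAX will produce the reverse differentiation computation by processing the JVP computation backwards. For each operation in the tangent computation, it uses the cotangent of the operation's result to accumulate the cotangents of the variables used by the operation: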

  # Initialize cotangents of inputs and intermediate variables:
  xct = yct = act = bct = cct = 0.
  # Initialize cotangent of the output:
  fct = 1.
  # Process `ft = c + yt`:
  cct += fct
  yct += fct
  # Process `c = a + b`:
  act += cct
  bct += cct
  # Process `b = 2. * yt`:
  yct += 2. * bct
  # Process `a = xt * 4.`:
  xct += act * 4.

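You can verify that this computation produces xct = 4. and yct = 3., which are the partial derivatives of the function f (∂f/∂x = y = 4 and ∂f/∂y = x + 1 = 3 at the point (2., 4.)).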
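The hand-written cotangent accumulation can be compared with JAX's own reverse mode. The sketch below uses the public jax.vjp, seeding the output cotangent with 1 as in fct = 1., and jax.grad over both arguments.

```python
import jax
import jax.numpy as jnp

def f(x, y):
  return x * y + y

# Reverse mode at (2., 4.): vjp returns the primal output and a pullback
# that maps an output cotangent to the input cotangents.
f_out, f_vjp = jax.vjp(f, 2., 4.)
xct, yct = f_vjp(jnp.ones_like(f_out))   # seed fct = 1.

# The same values via grad with respect to both arguments.
gx, gy = jax.grad(f, argnums=(0, 1))(2., 4.)
print(xct, yct, gx, gy)
```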

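For each primitive that may appear in a JVP calculation, JAX knows how to transpose it. Conceptually, if the primitive p(x, y, z) is linear in the arguments y and z for a constant value of x, for example p(x, y, z) = y*cy + z*cz, then the transposition of the primitive is: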

p_transpose(out_ct, x, _, _) = (None, out_ct*cy, out_ct*cz)

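Note that p_transpose takes the cotangent of the primitive's output and a value corresponding to each argument of the primitive. For the linear arguments, the transposition receives an undefined _ value; for the other arguments, it receives the actual constants. The transposition returns a cotangent value for each argument of the primitive, with None returned for the constant arguments.

In particular: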

 add_transpose(out_ct, _, _) = (out_ct, out_ct)
 mult_transpose(out_ct, x, _) = (None, x * out_ct)
 mult_transpose(out_ct, _, y) = (out_ct * y, None)
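These conceptual rules can be checked numerically with the public jax.linear_transpose, which transposes a function that is linear in its inputs (the example inputs are used only for their shapes and dtypes). The constants cy = 3., cz = 5. and x_const = 2. are illustrative values chosen for this sketch, and float32 scalars are used so that the cotangent dtype matches the output dtype.

```python
import jax
import jax.numpy as jnp

cy, cz = 3.0, 5.0   # illustrative constants

def p_linear(y, z):
  # p(x, y, z) = y*cy + z*cz with x held constant: linear in y and z.
  return y * cy + z * cz

zero = jnp.float32(0.)
one = jnp.float32(1.)

# p_transpose(out_ct, x, _, _) = (None, out_ct*cy, out_ct*cz)
p_t = jax.linear_transpose(p_linear, zero, zero)
ct_y, ct_z = p_t(one)            # (3., 5.)

# mult_transpose(out_ct, x, _) = (None, x * out_ct), with x constant.
x_const = 2.0
mult_t = jax.linear_transpose(lambda y: x_const * y, zero)
(ct_mult,) = mult_t(one)         # 2.

# add_transpose(out_ct, _, _) = (out_ct, out_ct)
add_t = jax.linear_transpose(lambda a, b: a + b, zero, zero)
ct_a, ct_b = add_t(one)          # (1., 1.)
print(ct_y, ct_z, ct_mult, ct_a, ct_b)
```

Only cotangents for the linear inputs are returned here, because the constants are closed over instead of being passed as arguments, which is why no None entries appear.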
@trace("multiply_add_transpose")
def multiply_add_transpose(ct, x, y, z):
  """Evaluates the transpose of a linear primitive.

  This method is only used when computing the backward gradient following 
  `value_and_jvp`, and is only needed for primitives that are used in the JVP 
  calculation for some other primitive. You need a transposition for `multiply_add_prim`, 
  because you have used `multiply_add_prim` in the computation of the `output_tangent` in 
  `multiply_add_value_and_jvp`.

  In this case, multiply_add is not a linear primitive. However, it is used linearly 
  w.r.t. tangents in `multiply_add_value_and_jvp`:
       `output_tangent(xt, yt, zt) = multiply_add_prim(xt, y, multiply_add_prim(x, yt, zt))`.

  One of the first two (multiplicative) arguments is always a constant.

  Args:
      ct: The cotangent of the output of the primitive.
      x, y, z: The values of the arguments. The arguments that are used linearly
        get an ad.UndefinedPrimal value. The other arguments get a constant
        value.

  Returns:
      A tuple with the cotangent of the inputs, with the value None
      corresponding to the constant arguments.
  """
  if not ad.is_undefined_primal(x):
    # This use of multiply_add is with a constant "x".
    assert ad.is_undefined_primal(y)
    ct_y = ad.Zero(y.aval) if type(ct) is ad.Zero else multiply_add_prim(x, ct, lax.zeros_like_array(x))
    res = None, ct_y, ct
  else:
    # This use of multiply_add is with a constant "y".
    assert ad.is_undefined_primal(x)
    ct_x = ad.Zero(x.aval) if type(ct) is ad.Zero else multiply_add_prim(ct, y, lax.zeros_like_array(y))
    res = ct_x, None, ct
  return res

ad.primitive_transposes[multiply_add_p] = multiply_add_transpose

Now you can complete the grad run:

assert api.grad(square_add_prim)(2., 10.) == 4.
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, 10.0)
  call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, 10.0)
    call multiply_add_value_and_jvp((2.0, 2.0, 10.0), (Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Zero(ShapedArray(float32[], weak_type=True))))
      Primal evaluation:
      call multiply_add_prim(2.0, 2.0, 10.0)
        call multiply_add_impl(2.0, 2.0, 10.0)
        |<- multiply_add_impl = 14.0
      |<- multiply_add_prim = 14.0
      Tangent evaluation:
      call multiply_add_prim(2.0, Traced<ShapedArray(float32[], weak_type=True)>, 0.0)
        call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
        |<- multiply_add_abstract_eval = ShapedArray(float32[])
      |<- multiply_add_prim = Traced<ShapedArray(float32[])>
      call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, 2.0, Traced<ShapedArray(float32[])>)
        call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[]))
        |<- multiply_add_abstract_eval = ShapedArray(float32[])
      |<- multiply_add_prim = Traced<ShapedArray(float32[])>
    |<- multiply_add_value_and_jvp = (14.0, Traced<ShapedArray(float32[])>)
  |<- multiply_add_prim = Traced<ShapedArray(float32[])>
|<- square_add_prim = Traced<ShapedArray(float32[])>
call multiply_add_transpose(1.0, UndefinedPrimal(ShapedArray(float32[], weak_type=True)), 2.0, UndefinedPrimal(ShapedArray(float32[])))
  call multiply_add_prim(1.0, 2.0, 0.0)
    call multiply_add_impl(1.0, 2.0, 0.0)
    |<- multiply_add_impl = 2.0
  |<- multiply_add_prim = 2.0
|<- multiply_add_transpose = (2.0, None, 1.0)
call multiply_add_transpose(1.0, 2.0, UndefinedPrimal(ShapedArray(float32[], weak_type=True)), 0.0)
  call multiply_add_prim(2.0, 1.0, 0.0)
    call multiply_add_impl(2.0, 1.0, 0.0)
    |<- multiply_add_impl = 2.0
  |<- multiply_add_prim = 2.0
|<- multiply_add_transpose = (None, 2.0, 1.0)

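Notice the two calls to multiply_add_transpose. They correspond to the two uses of multiply_add_prim in the computation of output_tangent in multiply_add_value_and_jvp. The first call to the transpose corresponds to the last use of multiply_add_prim: multiply_add_prim(xt, y, ...), where y is the constant 2.0.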
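To see how the two transpose calls combine into the final gradient, you can replay the trace by hand. In square_add_prim(a, b), both x and y are a, so the tangent of a flows into both xt and yt, and their cotangents are summed. The plain-Python sketch below uses hypothetical helper names and scalars only (no JAX); it mirrors the two branches of multiply_add_transpose and reproduces the values printed in the trace.

```python
def transpose_const_y(ct, y):
    # Branch where x is undefined and y is constant: returns (ct_x, None, ct_z).
    return ct * y, None, ct

def transpose_const_x(ct, x):
    # Branch where y is undefined and x is constant: returns (None, ct_y, ct_z).
    return None, x * ct, ct

def grad_square_add_wrt_a(a, out_ct=1.0):
    # output_tangent = multiply_add(xt, y, multiply_add(x, yt, zt)) with x = y = a.
    # The outer (last) use is transposed first, as in the trace.
    ct_xt, _, ct_inner = transpose_const_y(out_ct, a)
    # The cotangent of the inner multiply_add output feeds the second transpose.
    _, ct_yt, _ = transpose_const_x(ct_inner, a)
    # xt and yt are both the tangent of `a`, so their cotangents add up.
    return ct_xt + ct_yt

print(transpose_const_y(1.0, 2.0))  # (2.0, None, 1.0)
print(transpose_const_x(1.0, 2.0))  # (None, 2.0, 1.0)
print(grad_square_add_wrt_a(2.0))   # 4.0
```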

JIT of reverse differentiation

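Notice that under jit, every evaluation of multiply_add_prim uses only abstract values and goes through multiply_add_abstract_eval. This includes the primal evaluation inside multiply_add_value_and_jvp and the calls made by the transpose. The actual computation is emitted later by multiply_add_lowering. Without JIT, by contrast, the primal evaluation and the transpose ran on concrete values through multiply_add_impl.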

assert api.jit(api.grad(square_add_prim))(2., 10.) == 4.
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>)
  call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>)
    call multiply_add_value_and_jvp((Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>), (Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Zero(ShapedArray(float32[], weak_type=True))))
      Primal evaluation:
      call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>)
        call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
        |<- multiply_add_abstract_eval = ShapedArray(float32[])
      |<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace>
      Tangent evaluation:
      call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>)
        call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
        |<- multiply_add_abstract_eval = ShapedArray(float32[])
      |<- multiply_add_prim = Traced<ShapedArray(float32[])>
      call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[])>)
        call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[]))
        |<- multiply_add_abstract_eval = ShapedArray(float32[])
      |<- multiply_add_prim = Traced<ShapedArray(float32[])>
    |<- multiply_add_value_and_jvp = (Traced<ShapedArray(float32[])>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[])>)
  |<- multiply_add_prim = Traced<ShapedArray(float32[])>
|<- square_add_prim = Traced<ShapedArray(float32[])>
call multiply_add_transpose(Traced<ShapedArray(float32[])>with<DynamicJaxprTrace>, UndefinedPrimal(ShapedArray(float32[], weak_type=True)), Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, UndefinedPrimal(ShapedArray(float32[])))
  call multiply_add_prim(Traced<ShapedArray(float32[])>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>)
    call multiply_add_abstract_eval(ShapedArray(float32[]), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
    |<- multiply_add_abstract_eval = ShapedArray(float32[])
  |<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace>
|<- multiply_add_transpose = (Traced<ShapedArray(float32[])>with<DynamicJaxprTrace>, None, Traced<ShapedArray(float32[])>with<DynamicJaxprTrace>)
call multiply_add_transpose(Traced<ShapedArray(float32[])>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, UndefinedPrimal(ShapedArray(float32[], weak_type=True)), Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>)
  call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[])>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>)
    call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[]), ShapedArray(float32[], weak_type=True))
    |<- multiply_add_abstract_eval = ShapedArray(float32[])
  |<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace>
|<- multiply_add_transpose = (None, Traced<ShapedArray(float32[])>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[])>with<DynamicJaxprTrace>)
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(...), name_stack=NameStack(stack=(Scope(name='jit(square_add_prim)'), Scope(name='jit(main)'), Transform(name='transpose'), Transform(name='jvp'))), primitive=multiply_add, avals_in=[ShapedArray(float32[]), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)], avals_out=[ShapedArray(float32[])], ...), Value(%0 = "stablehlo.constant"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(%1 = "stablehlo.constant"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7370ec932a30>]
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(...), name_stack=NameStack(stack=(Scope(name='jit(square_add_prim)'), Scope(name='jit(main)'), Transform(name='transpose'), Transform(name='jvp'))), primitive=multiply_add, avals_in=[ShapedArray(float32[], weak_type=True), ShapedArray(float32[]), ShapedArray(float32[], weak_type=True)], avals_out=[ShapedArray(float32[])], ...), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(%0 = "stablehlo.constant"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>), Value(%1 = "stablehlo.constant"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7370ec2720f0>]
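In the trace above, each evaluation of multiply_add_prim under jit calls multiply_add_abstract_eval, which sees only shapes and dtypes. As a rough plain-Python illustration of what such a rule computes, an abstract rule for a point-wise multiply-add checks that the operands agree and returns an abstract value of the same shape and dtype without touching any data. ToyShapedArray and toy_multiply_add_abstract_eval are hypothetical names. JAX's real ShapedArray also tracks details such as weak_type.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class ToyShapedArray:
    """Hypothetical stand-in for an abstract value: shape and dtype, no data."""
    shape: tuple
    dtype: str

def toy_multiply_add_abstract_eval(xs, ys, zs):
    # Point-wise op: all operands must share shape and dtype.
    if not (xs.shape == ys.shape == zs.shape):
        raise TypeError(f"shape mismatch: {xs.shape}, {ys.shape}, {zs.shape}")
    if not (xs.dtype == ys.dtype == zs.dtype):
        raise TypeError(f"dtype mismatch: {xs.dtype}, {ys.dtype}, {zs.dtype}")
    # The result has the same shape and dtype; no values are computed.
    return ToyShapedArray(xs.shape, xs.dtype)

scalar = ToyShapedArray((), "float32")
vector = ToyShapedArray((2,), "float32")
print(toy_multiply_add_abstract_eval(scalar, scalar, scalar))  # ToyShapedArray(shape=(), dtype='float32')
print(toy_multiply_add_abstract_eval(vector, vector, vector))  # ToyShapedArray(shape=(2,), dtype='float32')
```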

Batching

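The batching transformation takes a point-wise computation and turns it into a computation on vectors. If you try it right now, you will get a NotImplementedError: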

# The arguments are two vectors instead of two scalars.
with expectNotImplementedError():
  api.vmap(square_add_prim, in_axes=0, out_axes=0)(np.array([2., 3.]),
                                               np.array([10., 20.]))
call square_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)

  call multiply_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)

Found expected exception:
Traceback (most recent call last):
  File "/tmp/ipykernel_1105/1080163607.py", line 3, in <module>
    api.vmap(square_add_prim, in_axes=0, out_axes=0)(np.array([2., 3.]),
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 182, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py", line 1053, in vmap_f
    out_flat = batching.batch(
NotImplementedError: Batching rule for 'multiply_add' not implemented
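Conceptually, applying vmap with in_axes=0 should give the same result as calling the function on each slice along the leading axis and stacking the results. A batching rule lets JAX produce that result with one vectorized computation instead of a Python loop. The plain-Python sketch below spells out that reference semantics for the square-add example. loop_vmap is a hypothetical helper, not a JAX API.

```python
def square_add(a, b):
    # Scalar version of square_add_prim: a * a + b.
    return a * a + b

def loop_vmap(f, *batched_args):
    """Hypothetical reference semantics of vmap with in_axes=0, out_axes=0.

    Calls f once per index along the leading axis and stacks the results.
    """
    sizes = {len(arg) for arg in batched_args}
    if len(sizes) != 1:
        raise ValueError(f"inconsistent batch sizes: {sorted(sizes)}")
    (n,) = sizes
    return [f(*(arg[i] for arg in batched_args)) for i in range(n)]

print(loop_vmap(square_add, [2.0, 3.0], [10.0, 20.0]))  # [14.0, 29.0]
```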

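You need to tell JAX how to evaluate the batched version of the primitive. In this particular case, multiply_add_prim already operates point-wise on input tensors of any dimension, so the batched version can use the same multiply_add_prim implementation.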

from jax.interpreters import batching

@trace("multiply_add_batch")
def multiply_add_batch(vector_arg_values, batch_axes):
  """Computes the batched version of the primitive.
  
  This must be a JAX-traceable function.
  
  Since the `multiply_add primitive` already operates point-wise on arbitrary
  dimension tensors, to batch it you can use the primitive itself. This works as
  long as both the inputs have the same dimensions and are batched along the
  same axes. The result is batched along the axis that the inputs are batched.

  Args:
    vector_arg_values: A tuple of two arguments, each being a tensor of matching
      shape.
    batch_axes: The axes that are being batched. See vmap documentation.

  Returns:
    A tuple of the result, and the result axis that was batched. 
  """
  assert batch_axes[0] == batch_axes[1]
  assert batch_axes[0] == batch_axes[2]
  _trace("Using multiply_add to compute the batch:")
  res = multiply_add_prim(*vector_arg_values)
  return res, batch_axes[0]


batching.primitive_batchers[multiply_add_p] = multiply_add_batch
assert np.allclose(api.vmap(square_add_prim, in_axes=0, out_axes=0)(
  np.array([2., 3.]),
  np.array([10., 20.])),
  [14., 29.])
call square_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)
  call multiply_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)
    call multiply_add_batch(([2. 3.], [2. 3.], [10. 20.]), (0, 0, 0))
      Using multiply_add to compute the batch:
      call multiply_add_prim([2. 3.], [2. 3.], [10. 20.])
        call multiply_add_impl([2. 3.], [2. 3.], [10. 20.])
        |<- multiply_add_impl = [14. 29.]
      |<- multiply_add_prim = [14. 29.]
    |<- multiply_add_batch = ([14. 29.], 0)
  |<- multiply_add_prim = Traced<ShapedArray(float32[])>
|<- square_add_prim = Traced<ShapedArray(float32[])>
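The rule above asserts that all three arguments are batched along the same axis. JAX passes None as the batch axis of an argument that is not batched (for example with in_axes=(0, None) on square_add_prim), and that case would trip those assertions. A more general rule would first bring every argument to a common batch layout, for example by broadcasting the unbatched ones. The sketch below only illustrates that broadcasting step. It is plain Python on lists of scalars, toy_multiply_add_batch is a hypothetical name, and a real JAX rule would operate on arrays and also handle batch axes other than 0.

```python
def toy_multiply_add_batch(args, batch_axes):
    """Hypothetical batching rule for 1-D batches of scalars.

    args: per-argument values; a batched arg is a list, an unbatched one a scalar.
    batch_axes: 0 for batched args, None for unbatched ones.
    """
    sizes = {len(a) for a, ax in zip(args, batch_axes) if ax is not None}
    if len(sizes) != 1:
        raise ValueError("need batched arguments with one consistent batch size")
    (n,) = sizes
    # Broadcast unbatched (scalar) arguments along the batch dimension.
    x, y, z = [a if ax is not None else [a] * n for a, ax in zip(args, batch_axes)]
    # Point-wise multiply-add over the batch; the result is batched along axis 0.
    return [xi * yi + zi for xi, yi, zi in zip(x, y, z)], 0

print(toy_multiply_add_batch(([2.0, 3.0], [2.0, 3.0], 10.0), (0, 0, None)))  # ([14.0, 19.0], 0)
```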

JIT of batching

Below is an example of applying JIT to batching:

assert np.allclose(api.jit(api.vmap(square_add_prim, in_axes=0, out_axes=0))
                    (np.array([2., 3.]),
                     np.array([10., 20.])),
                    [14., 29.])
call square_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)
  call multiply_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)
    call multiply_add_batch((Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace>), (0, 0, 0))
      Using multiply_add to compute the batch:
      call multiply_add_prim(Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace>)
        call multiply_add_abstract_eval(ShapedArray(float32[2]), ShapedArray(float32[2]), ShapedArray(float32[2]))
        |<- multiply_add_abstract_eval = ShapedArray(float32[2])
      |<- multiply_add_prim = Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace>
    |<- multiply_add_batch = (Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace>, 0)
  |<- multiply_add_prim = Traced<ShapedArray(float32[])>
|<- square_add_prim = Traced<ShapedArray(float32[])>
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(...), name_stack=NameStack(stack=(Scope(name='jit(square_add_prim)'), Scope(name='jit(main)'), Transform(name='vmap'))), primitive=multiply_add, avals_in=[ShapedArray(float32[2]), ShapedArray(float32[2]), ShapedArray(float32[2])], avals_out=[ShapedArray(float32[2])], ...), Value(<block argument> of type 'tensor<2xf32>' at index: 0), Value(<block argument> of type 'tensor<2xf32>' at index: 0), Value(<block argument> of type 'tensor<2xf32>' at index: 1))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7370ec273c30>]