Ref:用于数据管道和内存控制的可变数组#

JAX Array 是不可变的,代表数学值。不可变性可以使代码更易于推理,并且对于优化编译、并行化、重构和自动微分等转换很有用。

但不可变性也有局限性

  • 表达能力 — 传递中间数据或维护状态,例如用于规范化统计数据或指标,可能会显得笨重;

  • 性能 — 更难推理性能,例如内存生命周期和就地更新。

Refs 可以提供帮助!它们代表了可以就地读取和写入的可变数组。这些数组引用与 JAX 转换兼容,例如 jax.jitjax.grad

import jax
import jax.numpy as jnp

x_ref = jax.new_ref(jnp.zeros(3))  # new array ref, with initial value [0., 0., 0.]

@jax.jit
def f():
  x_ref[1] += 1.  # indexed add-update

print(x_ref)  # Ref([0., 0., 0.])
f()
f()
print(x_ref)  # Ref([0., 2., 0.])
Ref([0., 0., 0.], dtype=float32)
Ref([0., 2., 0.], dtype=float32)

索引语法遵循 NumPy。对于一个名为 x_refRef,我们可以通过编写 x_ref[...] 将其整个值读取到一个 Array 中,并使用 x_ref[...] = A 为某个 Array 值表达式 A 来写入其整个值。

def g(x):
  x_ref = jax.new_ref(0.)
  x_ref[...] = jnp.sin(x)
  return x_ref[...]

print(jax.grad(g)(1.0))  # 0.54
0.5403023

Ref 是与 Array 不同的类型,并且具有一些重要的约束和限制。特别是,索引读写几乎是您可以使用 Ref 做的唯一事情。引用不能在期望 Array 的地方传递。

x_ref = jax.new_ref(1.0)
try:
  jnp.sin(x_ref)  # error! can't do math on refs
except Exception as e:
  print(e)
Attempting to pass a Ref Ref{float32[]} to a primitive: sin -- did you forget to unpack ([...]) the ref?

要进行数学运算,您需要先读取 ref 的值,例如 jnp.sin(x_ref[...])

那么您 *可以* 用 Ref 做什么?继续阅读以了解详细信息和一些有用的示例。

API#

如果您曾经使用过 Pallas,那么 Ref 应该看起来很熟悉。一个很大的区别是您可以自己直接使用 jax.new_ref 创建新的 Ref

from jax import Array, Ref

def array_ref(init_val: Array) -> Ref:
  """Introduce a new reference with given initial value."""

jax.freeze 是它的反义词,它使给定的 ref 无效(因此之后访问它将是一个错误),并产生其最终值。

def freeze(ref: Ref) -> Array:
  """Invalidate given reference and produce its final value."""

在创建和销毁它们之间,您可以对 refs 执行索引读写。您可以使用函数 jax.ref.getjax.ref.swap 进行读取和写入,但通常您只会使用 NumPy 风格的数组索引语法。

import types
Index = int | slice | Array | types.EllipsisType
Indexer = Index | tuple[Index, ...]

def get(ref: Ref, idx: Indexer) -> Array:
  """Returns `ref[idx]` for NumPy-style indexer `idx`."""

def swap(ref: Ref, idx: Indexer, val: Array) -> Array:
  """Performs `newval, ref[idx] = ref[idx], val` and returns `newval`."""

在这里,Indexer 可以是任何 NumPy 索引表达式。

x_ref = jax.new_ref(jnp.arange(12.).reshape(3, 4))

# int indexing
row = x_ref[0]
x_ref[1] = row

# tuple indexing
val = x_ref[1, 2]
x_ref[2, 3] = val

# slice indexing
col = x_ref[:, 1]
x_ref[0, :3] = col

# advanced int array indexing
vals = x_ref[jnp.array([0, 0, 1]), jnp.array([1, 2, 3])]
x_ref[jnp.array([1, 2, 1]), jnp.array([0, 0, 1])] = vals

Arrays 一样,索引主要遵循 NumPy 的行为,除了越界索引,它 以 JAX Arrays 的常规方式行为

纯函数与纯函数#

一个以 ref 作为参数(显式或通过词法闭包)的函数被认为是*纯函数*。例如:

# takes ref as an argument => impure
@jax.jit
def impure1(x_ref, y_ref):
  x_ref[...] = y_ref[...]

# closes over ref => impure
y_ref = jax.new_ref(0)

@jax.jit
def impure2(x):
  y_ref[...] = x

如果一个函数仅在内部使用 refs,它仍然被认为是*纯函数*。纯洁性取决于调用者。例如:

# internal refs => still pure
@jax.jit
def pure1(x):
  ref = jax.new_ref(x)
  ref[...] = ref[...] + ref[...]
  return ref[...]

纯函数,即使是那些在内部使用 refs 的函数,也是熟悉的:例如,它们可以像往常一样与 jax.gradjax.vmapjax.shard_map 等转换一起工作。

纯函数按 Python 程序顺序进行排序。

限制#

Refs 是二等公民,这意味着对其使用有限制。

  • 不能从 jit 装饰的函数或高阶原语(如 jax.lax.scanjax.lax.while_loopjax.lax.cond)的**主体中返回 refs。

  • 不能将 ref 多次作为参数传递jit 装饰的函数或高阶原语。

  • 只能在创建作用域中 freeze

  • 没有高阶 ref(ref-to-ref)。

例如,这些是错误的:

x_ref = jax.new_ref(0.)

# can't return refs
@jax.jit
def err1(x_ref):
  x_ref[...] = 5.
  return x_ref  # error!
try:
  err1(x_ref)
except Exception as e:
  print(e)

# can't pass a ref as an argument more than once
@jax.jit
def err2(x_ref, y_ref):
  ...
try:
  err2(x_ref, x_ref)  # error!
except Exception as e:
  print(e)

# can't pass and close over the same ref
@jax.jit
def err3(y_ref):
  y_ref[...] = x_ref[...]
try:
  err3(x_ref)  # error!
except Exception as e:
  print(e)

# can only freeze in creation scope
@jax.jit
def err4(x_ref):
  jax.freeze(x_ref)
try:
  err4(x_ref)  # error!
except Exception as e:
  print(e)
function err1 at /tmp/ipykernel_1367/3915325362.py:4 traced for jit returned a mutable array reference of type Ref{float32[]} at output tree path result, but mutable array references cannot be returned.

The returned mutable array was passed in as the argument x_ref.
only one reference to a mutable array may be passed as an argument to a function, but when tracing err2 at /tmp/ipykernel_1367/3915325362.py:14 for jit the mutable array reference of type Ref{float32[]} appeared at both x_ref and y_ref.
when tracing err3 at /tmp/ipykernel_1367/3915325362.py:23 for jit, a mutable array reference of type Ref{float32[]} was both closed over and passed as the argument y_ref
list index out of range

这些限制是为了排除别名,即两个 ref 可能指向相同的可变内存,这使得程序更难推理和转换。较弱的限制也足够了,因此随着我们改进 JAX 验证没有别名的能力,其中一些限制可能会被解除。

此外,还存在源于未定义语义的限制,例如在存在并行性或重构的情况下。

  • 不能 vmapshard_map 闭包了 refs 的函数。

  • 不能将 jax.remat/jax.checkpoint 应用于纯函数。

例如,以下是您可以使用和不能使用 vmap 配合纯函数的方法:

# vmap over ref args is okay
def dist(x, y, out_ref):
  assert x.ndim == y.ndim == 1
  assert out_ref.ndim == 0
  out_ref[...] = jnp.sum((x - y) ** 2)

vecs = jnp.arange(12.).reshape(3, 4)
out_ref = jax.new_ref(jnp.zeros((3, 3)))
jax.vmap(jax.vmap(dist, (0, None, 0)), (None, 0, 0))(vecs, vecs, out_ref)  # ok!
print(out_ref)
Ref([[  0.,  64., 256.],
       [ 64.,   0.,  64.],
       [256.,  64.,   0.]], dtype=float32)
# vmap with a closed-over ref is not
x_ref = jax.new_ref(0.)

def err5(x):
  x_ref[...] = x

try:
  jax.vmap(err5)(jnp.arange(3.))  # error!
except Exception as e:
  print(e)
performing a set/swap operation with vmapped value on an unbatched array reference of type Ref{float32[]}. Move the array reference to be an argument to the vmapped function?

后者是一个错误,因为不清楚在运行 jax.vmap(err5)x_ref 的值应该是多少。

Refs 和自动微分#

自动微分可以像以前一样应用于纯函数,即使它们在内部使用数组引用。例如:

@jax.jit
def pure2(x):
  ref = jax.new_ref(x)
  ref[...] = ref[...] + ref[...]
  return ref[...]

print(jax.grad(pure1)(3.0))  # 2.0
2.0

如果数组引用仅用于管道传输且不参与微分,自动微分也可以应用于接受数组引用的函数。

# error
def err6(x, some_plumbing_ref):
  y = x + x
  some_plumbing_ref[...] += y
  return y

# fine
def foo(x, some_plumbing_ref):
  y = x + x
  some_plumbing_ref[...] += jax.lax.stop_gradient(y)
  return y

您可以将管道 refs 与 custom_vjp 结合使用,将数据从微分函数的反向传播中提取出来。

# First, define the helper `stash_grads`:

@jax.custom_vjp
def stash_grads(grads_ref, x):
  return x

def stash_grads_fwd(grads_ref, x):
  return x, grads_ref

def stash_grads_bwd(grads_ref, g):
  grads_ref[...] = g
  return None, g

stash_grads.defvjp(stash_grads_fwd, stash_grads_bwd)
# Now, use `stash_grads` to stash intermediate gradients:

def f(x, grads_ref):
  x = jnp.sin(x)
  x = stash_grads(grads_ref, x)
  return x

grads_ref = jax.new_ref(0.)
f(1., grads_ref)
print(grads_ref)
Ref(0., dtype=float32, weak_type=True)

请注意,stash_grads_fwd 在此处返回一个 Ref。这是 custom_vjp 前向规则的一个特殊允许:它实际上是用于指示哪些 ref 参数应由前向和后向规则共享的语法。因此,由前向规则返回的任何 refs 都必须是该前向规则的参数。

Refs 和性能#

在顶层,调用 jit 装饰的函数时,Refs 消除了捐赠的需要,因为它们实际上总是被捐赠的。

@jax.jit
def sin_inplace(x_ref):
  x_ref[...] = jnp.sin(x_ref[...])

x_ref = jax.new_ref(jnp.arange(3.))
print(x_ref.unsafe_buffer_pointer(), x_ref)
sin_inplace(x_ref)
print(x_ref.unsafe_buffer_pointer(), x_ref)
102660576838912 Ref([0., 1., 2.], dtype=float32)
102660576838912 Ref([0.        , 0.84147096, 0.9092974 ], dtype=float32)

在这里,sin_inplace 以就地方式操作,更新 x_ref 的底层缓冲区,使其地址保持不变。

jit 下,您应该期望数组引用指向固定的缓冲区地址,并且索引更新以就地方式执行。

临时说明:目前,从 Python 分派到接受 Ref 输入的纯 jit 编译函数比分派到纯 jit 编译函数要慢,因为它采用了一条不太优化的路径。

foreach,一种编写 scan 的新方法#

您可能知道,jax.lax.scan 是一个具有内置的扫描输入和输出固定访问模式的循环构造。访问模式是为了自动微分的原因而内置的:如果我们而是直接切片不可变输入,反向模式自动微分将最终创建一个独热梯度并对它们求和,这可能在渐近效率上很低。请参阅 Dex 论文的第 5.3.3 节

但是,读取 Refs 的切片没有这个效率问题:当我们应用反向模式自动微分时,我们总是生成就地累加操作。因此,我们不再需要受 scan 的固定访问模式的限制。我们可以编写更灵活的循环,例如具有非顺序访问的循环。

此外,可用的变异允许一些语法技巧,例如在此 foreach 装饰器的示例中:

import jax
import jax.numpy as jnp
from jax.lax import scan

def foreach(*args):
  def decorator(body):
    return scan(lambda _, elts: (None, body(*elts)), None, args)[1]
  return decorator
r = jax.new_ref(0)
xs = jnp.arange(10)

@foreach(xs)
def ys(x):
  r[...] += x
  return x * 2

print(r)   # Ref(45, dtype=int32)
print(ys)  # [ 0  2  4  6  8 10 12 14 16 18]
Ref(45, dtype=int32)
[ 0  2  4  6  8 10 12 14 16 18]

在这里,循环立即运行,就地更新 r 并将 ys 绑定为映射结果。