Ref
:用于数据管道和内存控制的可变数组#
JAX Array
是不可变的,代表数学值。不可变性可以使代码更易于推理,并且对于优化编译、并行化、重构和自动微分等转换很有用。
但不可变性也有局限性
表达能力 — 传递中间数据或维护状态,例如用于规范化统计数据或指标,可能会显得笨重;
性能 — 更难推理性能,例如内存生命周期和就地更新。
Ref
s 可以提供帮助!它们代表了可以就地读取和写入的可变数组。这些数组引用与 JAX 转换兼容,例如 jax.jit
和 jax.grad
。
import jax
import jax.numpy as jnp
x_ref = jax.new_ref(jnp.zeros(3)) # new array ref, with initial value [0., 0., 0.]
@jax.jit
def f():
x_ref[1] += 1. # indexed add-update
print(x_ref) # Ref([0., 0., 0.])
f()
f()
print(x_ref) # Ref([0., 2., 0.])
Ref([0., 0., 0.], dtype=float32)
Ref([0., 2., 0.], dtype=float32)
索引语法遵循 NumPy。对于一个名为 x_ref
的 Ref
,我们可以通过编写 x_ref[...]
将其整个值读取到一个 Array
中,并使用 x_ref[...] = A
为某个 Array
值表达式 A
来写入其整个值。
def g(x):
x_ref = jax.new_ref(0.)
x_ref[...] = jnp.sin(x)
return x_ref[...]
print(jax.grad(g)(1.0)) # 0.54
0.5403023
Ref
是与 Array
不同的类型,并且具有一些重要的约束和限制。特别是,索引读写几乎是您可以使用 Ref
做的唯一事情。引用不能在期望 Array
的地方传递。
x_ref = jax.new_ref(1.0)
try:
jnp.sin(x_ref) # error! can't do math on refs
except Exception as e:
print(e)
Attempting to pass a Ref Ref{float32[]} to a primitive: sin -- did you forget to unpack ([...]) the ref?
要进行数学运算,您需要先读取 ref 的值,例如 jnp.sin(x_ref[...])
。
那么您 *可以* 用 Ref
做什么?继续阅读以了解详细信息和一些有用的示例。
API#
如果您曾经使用过 Pallas,那么 Ref
应该看起来很熟悉。一个很大的区别是您可以自己直接使用 jax.new_ref
创建新的 Ref
。
from jax import Array, Ref
def array_ref(init_val: Array) -> Ref:
"""Introduce a new reference with given initial value."""
jax.freeze
是它的反义词,它使给定的 ref 无效(因此之后访问它将是一个错误),并产生其最终值。
def freeze(ref: Ref) -> Array:
"""Invalidate given reference and produce its final value."""
在创建和销毁它们之间,您可以对 refs 执行索引读写。您可以使用函数 jax.ref.get
和 jax.ref.swap
进行读取和写入,但通常您只会使用 NumPy 风格的数组索引语法。
import types
Index = int | slice | Array | types.EllipsisType
Indexer = Index | tuple[Index, ...]
def get(ref: Ref, idx: Indexer) -> Array:
"""Returns `ref[idx]` for NumPy-style indexer `idx`."""
def swap(ref: Ref, idx: Indexer, val: Array) -> Array:
"""Performs `newval, ref[idx] = ref[idx], val` and returns `newval`."""
在这里,Indexer
可以是任何 NumPy 索引表达式。
x_ref = jax.new_ref(jnp.arange(12.).reshape(3, 4))
# int indexing
row = x_ref[0]
x_ref[1] = row
# tuple indexing
val = x_ref[1, 2]
x_ref[2, 3] = val
# slice indexing
col = x_ref[:, 1]
x_ref[0, :3] = col
# advanced int array indexing
vals = x_ref[jnp.array([0, 0, 1]), jnp.array([1, 2, 3])]
x_ref[jnp.array([1, 2, 1]), jnp.array([0, 0, 1])] = vals
与 Array
s 一样,索引主要遵循 NumPy 的行为,除了越界索引,它 以 JAX Array
s 的常规方式行为。
纯函数与纯函数#
一个以 ref 作为参数(显式或通过词法闭包)的函数被认为是*纯函数*。例如:
# takes ref as an argument => impure
@jax.jit
def impure1(x_ref, y_ref):
x_ref[...] = y_ref[...]
# closes over ref => impure
y_ref = jax.new_ref(0)
@jax.jit
def impure2(x):
y_ref[...] = x
如果一个函数仅在内部使用 refs,它仍然被认为是*纯函数*。纯洁性取决于调用者。例如:
# internal refs => still pure
@jax.jit
def pure1(x):
ref = jax.new_ref(x)
ref[...] = ref[...] + ref[...]
return ref[...]
纯函数,即使是那些在内部使用 refs 的函数,也是熟悉的:例如,它们可以像往常一样与 jax.grad
、jax.vmap
、jax.shard_map
等转换一起工作。
纯函数按 Python 程序顺序进行排序。
限制#
Ref
s 是二等公民,这意味着对其使用有限制。
不能从
jit
装饰的函数或高阶原语(如jax.lax.scan
、jax.lax.while_loop
或jax.lax.cond
)的**主体中返回 refs。不能将 ref 多次作为参数传递给
jit
装饰的函数或高阶原语。只能在创建作用域中
freeze
。没有高阶 ref(ref-to-ref)。
例如,这些是错误的:
x_ref = jax.new_ref(0.)
# can't return refs
@jax.jit
def err1(x_ref):
x_ref[...] = 5.
return x_ref # error!
try:
err1(x_ref)
except Exception as e:
print(e)
# can't pass a ref as an argument more than once
@jax.jit
def err2(x_ref, y_ref):
...
try:
err2(x_ref, x_ref) # error!
except Exception as e:
print(e)
# can't pass and close over the same ref
@jax.jit
def err3(y_ref):
y_ref[...] = x_ref[...]
try:
err3(x_ref) # error!
except Exception as e:
print(e)
# can only freeze in creation scope
@jax.jit
def err4(x_ref):
jax.freeze(x_ref)
try:
err4(x_ref) # error!
except Exception as e:
print(e)
function err1 at /tmp/ipykernel_1367/3915325362.py:4 traced for jit returned a mutable array reference of type Ref{float32[]} at output tree path result, but mutable array references cannot be returned.
The returned mutable array was passed in as the argument x_ref.
only one reference to a mutable array may be passed as an argument to a function, but when tracing err2 at /tmp/ipykernel_1367/3915325362.py:14 for jit the mutable array reference of type Ref{float32[]} appeared at both x_ref and y_ref.
when tracing err3 at /tmp/ipykernel_1367/3915325362.py:23 for jit, a mutable array reference of type Ref{float32[]} was both closed over and passed as the argument y_ref
list index out of range
这些限制是为了排除别名,即两个 ref 可能指向相同的可变内存,这使得程序更难推理和转换。较弱的限制也足够了,因此随着我们改进 JAX 验证没有别名的能力,其中一些限制可能会被解除。
此外,还存在源于未定义语义的限制,例如在存在并行性或重构的情况下。
不能
vmap
或shard_map
闭包了 refs 的函数。不能将
jax.remat
/jax.checkpoint
应用于纯函数。
例如,以下是您可以使用和不能使用 vmap
配合纯函数的方法:
# vmap over ref args is okay
def dist(x, y, out_ref):
assert x.ndim == y.ndim == 1
assert out_ref.ndim == 0
out_ref[...] = jnp.sum((x - y) ** 2)
vecs = jnp.arange(12.).reshape(3, 4)
out_ref = jax.new_ref(jnp.zeros((3, 3)))
jax.vmap(jax.vmap(dist, (0, None, 0)), (None, 0, 0))(vecs, vecs, out_ref) # ok!
print(out_ref)
Ref([[ 0., 64., 256.],
[ 64., 0., 64.],
[256., 64., 0.]], dtype=float32)
# vmap with a closed-over ref is not
x_ref = jax.new_ref(0.)
def err5(x):
x_ref[...] = x
try:
jax.vmap(err5)(jnp.arange(3.)) # error!
except Exception as e:
print(e)
performing a set/swap operation with vmapped value on an unbatched array reference of type Ref{float32[]}. Move the array reference to be an argument to the vmapped function?
后者是一个错误,因为不清楚在运行 jax.vmap(err5)
后 x_ref
的值应该是多少。
Ref
s 和自动微分#
自动微分可以像以前一样应用于纯函数,即使它们在内部使用数组引用。例如:
@jax.jit
def pure2(x):
ref = jax.new_ref(x)
ref[...] = ref[...] + ref[...]
return ref[...]
print(jax.grad(pure1)(3.0)) # 2.0
2.0
如果数组引用仅用于管道传输且不参与微分,自动微分也可以应用于接受数组引用的函数。
# error
def err6(x, some_plumbing_ref):
y = x + x
some_plumbing_ref[...] += y
return y
# fine
def foo(x, some_plumbing_ref):
y = x + x
some_plumbing_ref[...] += jax.lax.stop_gradient(y)
return y
您可以将管道 refs 与 custom_vjp
结合使用,将数据从微分函数的反向传播中提取出来。
# First, define the helper `stash_grads`:
@jax.custom_vjp
def stash_grads(grads_ref, x):
return x
def stash_grads_fwd(grads_ref, x):
return x, grads_ref
def stash_grads_bwd(grads_ref, g):
grads_ref[...] = g
return None, g
stash_grads.defvjp(stash_grads_fwd, stash_grads_bwd)
# Now, use `stash_grads` to stash intermediate gradients:
def f(x, grads_ref):
x = jnp.sin(x)
x = stash_grads(grads_ref, x)
return x
grads_ref = jax.new_ref(0.)
f(1., grads_ref)
print(grads_ref)
Ref(0., dtype=float32, weak_type=True)
请注意,stash_grads_fwd
在此处返回一个 Ref
。这是 custom_vjp
前向规则的一个特殊允许:它实际上是用于指示哪些 ref 参数应由前向和后向规则共享的语法。因此,由前向规则返回的任何 refs 都必须是该前向规则的参数。
Ref
s 和性能#
在顶层,调用 jit
装饰的函数时,Ref
s 消除了捐赠的需要,因为它们实际上总是被捐赠的。
@jax.jit
def sin_inplace(x_ref):
x_ref[...] = jnp.sin(x_ref[...])
x_ref = jax.new_ref(jnp.arange(3.))
print(x_ref.unsafe_buffer_pointer(), x_ref)
sin_inplace(x_ref)
print(x_ref.unsafe_buffer_pointer(), x_ref)
102660576838912 Ref([0., 1., 2.], dtype=float32)
102660576838912 Ref([0. , 0.84147096, 0.9092974 ], dtype=float32)
在这里,sin_inplace
以就地方式操作,更新 x_ref
的底层缓冲区,使其地址保持不变。
在 jit
下,您应该期望数组引用指向固定的缓冲区地址,并且索引更新以就地方式执行。
临时说明:目前,从 Python 分派到接受 Ref
输入的纯 jit
编译函数比分派到纯 jit
编译函数要慢,因为它采用了一条不太优化的路径。
foreach
,一种编写 scan
的新方法#
您可能知道,jax.lax.scan
是一个具有内置的扫描输入和输出固定访问模式的循环构造。访问模式是为了自动微分的原因而内置的:如果我们而是直接切片不可变输入,反向模式自动微分将最终创建一个独热梯度并对它们求和,这可能在渐近效率上很低。请参阅 Dex 论文的第 5.3.3 节。
但是,读取 Ref
s 的切片没有这个效率问题:当我们应用反向模式自动微分时,我们总是生成就地累加操作。因此,我们不再需要受 scan
的固定访问模式的限制。我们可以编写更灵活的循环,例如具有非顺序访问的循环。
此外,可用的变异允许一些语法技巧,例如在此 foreach
装饰器的示例中:
import jax
import jax.numpy as jnp
from jax.lax import scan
def foreach(*args):
def decorator(body):
return scan(lambda _, elts: (None, body(*elts)), None, args)[1]
return decorator
r = jax.new_ref(0)
xs = jnp.arange(10)
@foreach(xs)
def ys(x):
r[...] += x
return x * 2
print(r) # Ref(45, dtype=int32)
print(ys) # [ 0 2 4 6 8 10 12 14 16 18]
Ref(45, dtype=int32)
[ 0 2 4 6 8 10 12 14 16 18]
在这里,循环立即运行,就地更新 r
并将 ys
绑定为映射结果。