自动向量化#
在上一节中,我们讨论了通过 jax.jit()
函数进行的 JIT 编译。本笔记将讨论 JAX 的另一种变换:通过 jax.vmap()
进行的向量化。
手动向量化#
考虑以下计算两个一维向量卷积的简单代码
import jax
import jax.numpy as jnp
x = jnp.arange(5)
w = jnp.array([2., 3., 4.])
def convolve(x, w):
output = []
for i in range(1, len(x)-1):
output.append(jnp.dot(x[i-1:i+2], w))
return jnp.array(output)
convolve(x, w)
Array([11., 20., 29.], dtype=float32)
假设我们想将此函数应用于一批权重 w
到一批向量 x
。
xs = jnp.stack([x, x])
ws = jnp.stack([w, w])
最原始的选择是简单地在 Python 中循环遍历批次
def manually_batched_convolve(xs, ws):
output = []
for i in range(xs.shape[0]):
output.append(convolve(xs[i], ws[i]))
return jnp.stack(output)
manually_batched_convolve(xs, ws)
Array([[11., 20., 29.],
[11., 20., 29.]], dtype=float32)
这会产生正确的结果,但效率不高。
为了高效地进行批处理计算,您通常需要手动重写函数,以确保其以向量化形式完成。这实现起来并不特别困难,但确实涉及更改函数处理索引、轴和输入其他部分的方式。
例如,我们可以手动重写 convolve()
以支持批次维度上的向量化计算,如下所示
def manually_vectorized_convolve(xs, ws):
output = []
for i in range(1, xs.shape[-1] -1):
output.append(jnp.sum(xs[:, i-1:i+2] * ws, axis=1))
return jnp.stack(output, axis=1)
manually_vectorized_convolve(xs, ws)
Array([[11., 20., 29.],
[11., 20., 29.]], dtype=float32)
随着函数复杂度的增加,这种重新实现可能会变得混乱且容易出错;幸运的是,JAX 提供了另一种方式。
自动向量化#
在 JAX 中,jax.vmap()
变换旨在自动生成函数的向量化实现
auto_batch_convolve = jax.vmap(convolve)
auto_batch_convolve(xs, ws)
Array([[11., 20., 29.],
[11., 20., 29.]], dtype=float32)
它通过类似于 jax.jit()
的方式追踪函数,并自动在每个输入的开头添加批次轴来实现这一点。
如果批次维度不是第一个,您可以使用 in_axes
和 out_axes
参数来指定输入和输出中批次维度的位置。如果所有输入和输出的批次轴相同,则这些参数可以是整数,否则可以是列表。
auto_batch_convolve_v2 = jax.vmap(convolve, in_axes=1, out_axes=1)
xst = jnp.transpose(xs)
wst = jnp.transpose(ws)
auto_batch_convolve_v2(xst, wst)
Array([[11., 11.],
[20., 20.],
[29., 29.]], dtype=float32)
jax.vmap()
也支持仅对其中一个参数进行批处理的情况:例如,如果您想将一组权重 w
与一批向量 x
进行卷积;在这种情况下,in_axes
参数可以设置为 None
batch_convolve_v3 = jax.vmap(convolve, in_axes=[0, None])
batch_convolve_v3(xs, w)
Array([[11., 20., 29.],
[11., 20., 29.]], dtype=float32)
结合变换#
与所有 JAX 变换一样,jax.jit()
和 jax.vmap()
旨在可组合,这意味着您可以将一个 vmap 过的函数用 jit
包装,或者将一个 jit 过的函数用 vmap
包装,并且一切都会正常工作
jitted_batch_convolve = jax.jit(auto_batch_convolve)
jitted_batch_convolve(xs, ws)
Array([[11., 20., 29.],
[11., 20., 29.]], dtype=float32)