Asynchronous dispatch
JAX uses asynchronous dispatch to hide Python overheads. Consider the following program:
>>> import numpy as np
>>> import jax.numpy as jnp
>>> from jax import random
>>> x = random.uniform(random.key(0), (1000, 1000))
>>> # Printing the result (i.e. evaluating `repr(result)` or `str(result)`)
>>> # will block until the value is ready.
>>> jnp.dot(x, x) + 3.
Array([[258.01971436, 249.64862061, 257.13372803, ...,
        236.67948914, 250.68939209, 241.36853027],
       [265.65979004, 256.28912354, 262.18252563, ...,
        242.03181458, 256.16757202, 252.44122314],
       [262.38916016, 255.72747803, 261.23059082, ...,
        240.83563232, 255.41094971, 249.62471008],
       ...,
       [259.15814209, 253.09197998, 257.72174072, ...,
        242.23876953, 250.72680664, 247.16642761],
       [271.22662354, 261.91204834, 265.33398438, ...,
        248.26651001, 262.05389404, 261.33700562],
       [257.16134644, 254.7543335, 259.08300781, ..., 241.59848022,
        248.62597656, 243.22348022]], dtype=float32)
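When an operation such as jnp.dot(x, x) is executed, JAX does not wait for the operation to complete before returning control to the Python program. Instead, JAX returns a jax.Array value, which is a future: a value that will be produced in the future on an accelerator device but is not necessarily available immediately. We can inspect the shape or type of a jax.Array without waiting for the computation that produced it to complete, and we can even pass it to another JAX computation, as we do with the addition operation here. Only when we actually inspect the value of the array from the host, for example by printing it or by converting it into a plain old numpy.ndarray, does JAX force the Python code to wait for the computation to complete.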
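As a minimal sketch (assuming a recent JAX release in which `random.key` is available, as in the example above), the following shows that `shape` and `dtype` can be read from a freshly dispatched array right away, that the pending array can be passed to another operation, and that converting it with `np.asarray` is the point where the host actually waits. The variable names are illustrative only.

```python
import numpy as np
import jax.numpy as jnp
from jax import random

x = random.uniform(random.key(0), (1000, 1000))

# Dispatch returns immediately with a jax.Array that may still be computing.
y = jnp.dot(x, x)

# Metadata is known without waiting for the computation to finish.
print(y.shape, y.dtype)  # (1000, 1000) float32 with default settings

# The pending array can be fed straight into another JAX computation.
z = y + 3.

# Converting to a NumPy array reads the values on the host, so this blocks
# until both the matrix multiplication and the addition have finished.
z_host = np.asarray(z)
print(type(z_host).__name__, z_host.shape)  # ndarray (1000, 1000)
```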
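Asynchronous dispatch is useful because it allows Python code to "run ahead" of the accelerator device, keeping Python code out of the critical path. Provided the Python code enqueues work on the device faster than the device can execute it, and provided the Python code does not actually need to inspect the output of a computation on the host, a Python program can enqueue arbitrary amounts of work and avoid having the accelerator wait.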
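The run-ahead pattern can be sketched as a loop that only enqueues work and synchronizes once at the end. The function `step` below is a hypothetical stand-in for one unit of device work, and the matrix size and iteration count are arbitrary assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp
from jax import random


@jax.jit
def step(v):
    # Hypothetical stand-in for one unit of device work.
    return jnp.tanh(v @ v.T) * 0.5


v = random.normal(random.key(1), (256, 256))

# Each call only enqueues work; Python returns immediately and can queue
# the next step while the device is still busy with earlier ones.
for _ in range(100):
    v = step(v)

# The host waits only here, once, for all queued work to finish.
v.block_until_ready()
```

Reading a value on the host inside the loop, for example `float(v[0, 0])`, would force the Python program to wait on every iteration and give up this overlap.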
Asynchronous dispatch has a slightly surprising consequence for microbenchmarks.
>>> %time jnp.dot(x, x)
CPU times: user 267 µs, sys: 93 µs, total: 360 µs
Wall time: 269 µs
Array([[255.01972961, 246.64862061, 254.13371277, ...,
        233.67948914, 247.68939209, 238.36853027],
       [262.65979004, 253.28910828, 259.18252563, ...,
        239.03181458, 253.16757202, 249.44122314],
       [259.38916016, 252.72747803, 258.23059082, ...,
        237.83563232, 252.41094971, 246.62471008],
       ...,
       [256.15814209, 250.09197998, 254.72172546, ...,
        239.23876953, 247.72680664, 244.16642761],
       [268.22662354, 258.91204834, 262.33398438, ...,
        245.26651001, 259.05389404, 258.33700562],
       [254.16134644, 251.7543335, 256.08300781, ..., 238.59848022,
        245.62597656, 240.22348022]], dtype=float32)
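269µs is a surprisingly short time for a 1000x1000 matrix multiplication on CPU! However, it turns out that asynchronous dispatch is misleading us: we are not timing the execution of the matrix multiplication, only the time it takes to dispatch the work. To measure the true cost of the operation, we must either read the value on the host (e.g., convert it to a plain old host-side numpy array), or call the block_until_ready() method on a jax.Array value to wait for the computation that produced it to complete.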
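A small timing helper along these lines can make the distinction explicit. This is a sketch, not a JAX API: `mean_runtime` and its `to_host` flag are hypothetical names, it assumes `f` returns a single `jax.Array`, and it performs one untimed warm-up call because a `jax.jit`-wrapped function is compiled on its first call.

```python
import time

import numpy as np
import jax
import jax.numpy as jnp
from jax import random


def mean_runtime(f, *args, n_iters=10, to_host=False):
    """Hypothetical helper: average wall-clock seconds per call of f(*args).

    to_host=False waits with block_until_ready(); to_host=True instead copies
    the result into a NumPy array, which also includes the transfer cost.
    """
    def run():
        out = f(*args)
        return np.asarray(out) if to_host else out.block_until_ready()

    run()  # warm-up call so that compilation is not included in the timing
    start = time.perf_counter()
    for _ in range(n_iters):
        run()
    return (time.perf_counter() - start) / n_iters


x = random.uniform(random.key(0), (1000, 1000))
matmul = jax.jit(lambda a: jnp.dot(a, a))

t_block = mean_runtime(matmul, x)
t_host = mean_runtime(matmul, x, to_host=True)
print(f"block_until_ready: {t_block * 1e3:.2f} ms, np.asarray: {t_host * 1e3:.2f} ms")
```

Both variants wait for the computation to finish; the `to_host=True` variant additionally pays for copying the result into host memory. Repeating the measurement several times, for example with the standard-library `timeit` module, reduces noise.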
>>> %time np.asarray(jnp.dot(x, x))
CPU times: user 61.1 ms, sys: 0 ns, total: 61.1 ms
Wall time: 8.09 ms
array([[255.01973, 246.64862, 254.13371, ..., 233.67949, 247.68939,
        238.36853],
       [262.6598 , 253.28911, 259.18253, ..., 239.03181, 253.16757,
        249.44122],
       [259.38916, 252.72748, 258.2306 , ..., 237.83563, 252.41095,
        246.62471],
       ...,
       [256.15814, 250.09198, 254.72173, ..., 239.23877, 247.7268 ,
        244.16643],
       [268.22662, 258.91205, 262.33398, ..., 245.26651, 259.0539 ,
        258.337  ],
       [254.16135, 251.75433, 256.083  , ..., 238.59848, 245.62598,
        240.22348]], dtype=float32)
>>> %time jnp.dot(x, x).block_until_ready()
CPU times: user 50.3 ms, sys: 928 µs, total: 51.2 ms
Wall time: 4.92 ms
Array([[255.01972961, 246.64862061, 254.13371277, ...,
        233.67948914, 247.68939209, 238.36853027],
       [262.65979004, 253.28910828, 259.18252563, ...,
        239.03181458, 253.16757202, 249.44122314],
       [259.38916016, 252.72747803, 258.23059082, ...,
        237.83563232, 252.41094971, 246.62471008],
       ...,
       [256.15814209, 250.09197998, 254.72172546, ...,
        239.23876953, 247.72680664, 244.16642761],
       [268.22662354, 258.91204834, 262.33398438, ...,
        245.26651001, 259.05389404, 258.33700562],
       [254.16134644, 251.7543335, 256.08300781, ..., 238.59848022,
        245.62597656, 240.22348022]], dtype=float32)
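Blocking without transferring the result back to Python is usually faster, and is often the best choice when writing microbenchmarks of computation times.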
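When a function returns several arrays, for example a dict or tuple, every leaf has to be waited on before the timer stops. The sketch below assumes a JAX version that provides `jax.block_until_ready`, which blocks on each array leaf of a pytree without copying anything to host memory; the function `stats` is a hypothetical example.

```python
import time

import jax
import jax.numpy as jnp
from jax import random


@jax.jit
def stats(a):
    # Hypothetical function returning a pytree (a dict of arrays).
    return {"mean": a.mean(axis=0), "gram": a.T @ a}


x = random.uniform(random.key(0), (1000, 1000))
jax.block_until_ready(stats(x))  # warm-up: compile outside the timed region

start = time.perf_counter()
# Wait for every array in the returned dict, without a host transfer.
out = jax.block_until_ready(stats(x))
elapsed = time.perf_counter() - start
print(f"{elapsed * 1e3:.2f} ms")
```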