jax.debug.inspect_array_sharding#

jax.debug.inspect_array_sharding(value, *, callback)[source]#

允许在 JIT 编译的函数中检查数组分片。

此函数在提供数组的 Pytree 时,会回调每个数组的分片,并在 pjit 编译的计算中工作,从而可以检查选择的中间分片。

当分片信息可用时,调用 callback 的策略是*尽可能早*。这意味着如果调用 inspect_array_callback 而没有任何转换,则回调将立即发生,因为我们已经准备好了数组及其分片。在 jax.jit 内部,回调将在降低时间发生,这意味着您可以使用 AOT API (jit(f).lower(...)) 触发回调。当在 pjit 内部时,回调在*编译时*发生,因为分片由 XLA 确定。您可以使用 JAX 的 AOT API (pjit(f).lower(...).compile()) 触发回调。在所有情况下,都将通过运行该函数来触发回调,因为运行该函数需要先降低和编译它。但是,一旦函数被编译和缓存,回调将不再发生。

此函数是实验性的,其行为将来可能会发生变化。

参数:
  • value – JAX 数组的 Pytree。

  • callback (Callable[[Sharding], None]) – 可调用对象,它接受一个 Sharding 并且不返回值。

在以下示例中,我们打印出 pjit 编译的计算中的中间值的分片

>>> import jax
>>> import jax.numpy as jnp
>>> from jax.experimental.pjit import pjit
>>> from jax.sharding import Mesh, PartitionSpec
>>>
>>> x = jnp.arange(8, dtype=jnp.float32)
>>> def f_(x):
...   x = jnp.sin(x)
...   jax.debug.inspect_array_sharding(x, callback=print)
...   return jnp.square(x)
>>> f = pjit(f_, in_shardings=PartitionSpec('dev'),
...          out_shardings=PartitionSpec('dev'))
>>> with Mesh(jax.devices(), ('dev',)):
...   f.lower(x).compile()  
...
NamedSharding(mesh={'dev': 8}, partition_spec=PartitionSpec(('dev',),))