jax.debug.inspect_array_sharding#
- jax.debug.inspect_array_sharding(value, *, callback)[源代码]#
启用对 JIT 编译函数内部数组分片(sharding)的检查。
此函数在接收到数组的 Pytree 后,会调用提供的回调函数并传入每个数组的分片信息。它可以在
jax.jit编译的计算中工作,从而允许检查选择的中间分片。调用
callback的时机策略是尽可能早地在分片信息可用时。这意味着如果inspect_array_callback在没有任何转换的情况下被调用,回调将立即发生,因为我们已准备好数组及其分片信息。在jax.jit内部,回调将在降低(lowering)时发生,这意味着您可以使用 AOT API(jit(f).lower(...))来触发回调。在jax.jit内部,回调在编译时发生,因为分片信息由 XLA 确定。您可以通过使用 JAX 的 AOT API(jax.jit(f).lower(...).compile())来触发回调。在所有情况下,回调都将通过运行函数来触发,因为运行函数首先涉及其降低和编译。但是,一旦函数被编译并缓存,回调将不再发生。此函数是实验性的,其行为将来可能会发生变化。
- 参数:
value – JAX 数组的 Pytree。
callback (Callable[[Sharding], None]) – 一个可调用对象,它接收一个
Sharding对象,并且不返回任何值。
在以下示例中,我们打印了
jax.jit编译计算中中间值的分片信息。>>> import jax >>> import jax.numpy as jnp >>> from jax.sharding import Mesh, PartitionSpec >>> >>> x = jnp.arange(8, dtype=jnp.float32) >>> def f_(x): ... x = jnp.sin(x) ... jax.debug.inspect_array_sharding(x, callback=print) ... return jnp.square(x) >>> f = jax.jit(f_, in_shardings=PartitionSpec('dev'), ... out_shardings=PartitionSpec('dev')) >>> with jax.set_mesh(Mesh(jax.devices(), ('dev',))): ... f.lower(x).compile() ... NamedSharding(mesh={'dev': 8}, partition_spec=PartitionSpec(('dev',),))