jax.debug.inspect_array_sharding#
- jax.debug.inspect_array_sharding(value, *, callback)[source]#
允许在 JIT 编译的函数内部检查数组分片。
此函数在接收到数组的 Pytree 时,会回调每个数组的分片,并在
pjit
编译的计算中工作,从而可以检查选择的中间分片。当分片信息可用时,调用
callback
的策略是尽可能早。 这意味着,如果在没有任何转换的情况下调用inspect_array_callback
,则回调将立即发生,因为我们已准备好数组及其分片。 在jax.jit
内部,回调将在降低时间发生,这意味着您可以使用 AOT API (jit(f).lower(...)
) 触发回调。 当在pjit
内部时,回调发生在编译时,因为分片由 XLA 决定。 您可以使用 JAX 的 AOT API (pjit(f).lower(...).compile()
) 触发回调。 在所有情况下,都将通过运行函数来触发回调,因为运行函数需要先降低和编译它。 但是,一旦函数被编译和缓存,回调将不再发生。此函数是实验性的,其行为在未来可能会发生变化。
- 参数:
value – JAX 数组的 Pytree。
callback (Callable[[Sharding], None]) – 一个可调用对象,它接受一个
Sharding
并且不返回值。
在以下示例中,我们打印出
pjit
编译的计算中一个中间值的分片>>> import jax >>> import jax.numpy as jnp >>> from jax.experimental.pjit import pjit >>> from jax.sharding import Mesh, PartitionSpec >>> >>> x = jnp.arange(8, dtype=jnp.float32) >>> def f_(x): ... x = jnp.sin(x) ... jax.debug.inspect_array_sharding(x, callback=print) ... return jnp.square(x) >>> f = pjit(f_, in_shardings=PartitionSpec('dev'), ... out_shardings=PartitionSpec('dev')) >>> with Mesh(jax.devices(), ('dev',)): ... f.lower(x).compile() ... NamedSharding(mesh={'dev': 8}, partition_spec=PartitionSpec(('dev',),))