Rank promotion warning#
NumPy broadcasting rules 允许将参数从一个秩(数组轴的数量)自动提升到另一个秩。当这种行为符合预期时,它可能非常方便,但也可能导致意外的错误,因为静默的秩提升会掩盖潜在的形状错误。
这是秩提升的一个例子
>>> from jax import numpy as jnp
>>> x = jnp.arange(12).reshape(4, 3)
>>> y = jnp.array([0, 1, 0])
>>> x + y
Array([[ 0, 2, 2],
[ 3, 5, 5],
[ 6, 8, 8],
[ 9, 11, 11]], dtype=int32)
为了避免潜在的意外情况,jax.numpy 是可配置的,因此需要秩提升的表达式可以产生警告、错误,或者像常规 NumPy 一样被允许。配置选项名为 jax_numpy_rank_promotion,它可以接受字符串值 allow、warn 和 raise。默认设置是 allow,它允许秩提升而不发出警告或错误。raise 设置在发生秩提升时会引发错误,而 warn 在第一次出现秩提升时会发出警告。
可以使用 jax.numpy_rank_promotion() 上下文管理器在本地启用或禁用秩提升
with jax.numpy_rank_promotion("warn"):
z = x + y
此配置也可以通过多种方式全局设置。一种方法是在代码中使用 jax.config
import jax
jax.config.update("jax_numpy_rank_promotion", "warn")
您还可以使用环境变量 JAX_NUMPY_RANK_PROMOTION 来设置该选项,例如 JAX_NUMPY_RANK_PROMOTION='warn'。最后,在使用 absl-py 时,可以通过命令行标志设置该选项。