秩提升警告#
NumPy 广播规则允许参数从一个秩(数组轴的数量)自动提升到另一个秩。这种行为在预期情况下可能很方便,但也可能导致意想不到的错误,即隐式的秩提升掩盖了底层的形状错误。
这是一个秩提升的例子
>>> from jax import numpy as jnp
>>> x = jnp.arange(12).reshape(4, 3)
>>> y = jnp.array([0, 1, 0])
>>> x + y
Array([[ 0, 2, 2],
[ 3, 5, 5],
[ 6, 8, 8],
[ 9, 11, 11]], dtype=int32)
为了避免潜在的意外,jax.numpy
是可配置的,这样需要秩提升的表达式可以导致警告、错误,或者像常规 NumPy 一样被允许。该配置选项名为 jax_numpy_rank_promotion
,并且它可以接受字符串值 allow
、warn
和 raise
。默认设置为 allow
,这允许秩提升而不会发出警告或错误。raise
设置会在秩提升时引发错误,而 warn
设置会在秩提升首次发生时发出警告。
秩提升可以通过 jax.numpy_rank_promotion()
上下文管理器在本地启用或禁用。
with jax.numpy_rank_promotion("warn"):
z = x + y
此配置也可以通过几种方式全局设置。其中一种是在代码中使用 jax.config
。
import jax
jax.config.update("jax_numpy_rank_promotion", "warn")
您还可以使用环境变量 JAX_NUMPY_RANK_PROMOTION
来设置此选项,例如设置为 JAX_NUMPY_RANK_PROMOTION='warn'
。最后,在使用 absl-py
时,该选项可以通过命令行标志设置。