jax.numpy.ndarray.at#

abstract property ndarray.at[来源]#

用于索引更新功能的辅助属性。

at 属性提供了一种纯函数式的方法来等效地进行原地数组修改。

具体来说

替代语法

等效的原地表达式

x = x.at[idx].set(y)

x[idx] = y

x = x.at[idx].add(y)

x[idx] += y

x = x.at[idx].subtract(y)

x[idx] -= y

x = x.at[idx].multiply(y)

x[idx] *= y

x = x.at[idx].divide(y)

x[idx] /= y

x = x.at[idx].power(y)

x[idx] **= y

x = x.at[idx].min(y)

x[idx] = minimum(x[idx], y)

x = x.at[idx].max(y)

x[idx] = maximum(x[idx], y)

x = x.at[idx].apply(ufunc)

ufunc.at(x, idx)

x = x.at[idx].get()

x = x[idx]

所有 x.at 表达式都不会修改原始的 x;相反,它们会返回一个修改后的 x 的副本。然而,在 jit() 编译的函数内部,诸如 x = x.at[idx].set(y) 这样的表达式保证会被原地应用。

与 NumPy 的原地操作(如 x[idx] += y)不同,如果多个索引指向同一个位置,所有更新都会被应用(NumPy 只会应用最后一个更新,而不是应用所有更新)。冲突更新的应用顺序是实现定义的,并且可能是非确定性的(例如,在某些硬件平台上由于并发)。

默认情况下,JAX 假定所有索引都在边界内。可以通过 mode 参数指定替代的越界索引语义(见下文)。

参数:
  • mode

    指定越界索引模式的字符串。选项包括:

    • "promise_in_bounds":(默认)用户保证索引在边界内。不会进行额外的检查。实际上,这意味着 get() 中的越界索引将被裁剪,而 set()add() 等中的越界索引将被丢弃。

    • "clip": 将越界索引限制在有效范围内。

    • "drop": 忽略越界索引。

    • "fill": "drop" 的别名。对于 get(),可选的 fill_value 参数指定当 mode'fill' 时返回的值。否则被忽略。

    有关更多详细信息,请参阅 jax.lax.GatherScatterMode

  • wrap_negative_indices – 如果为 True(默认),则负数索引表示从数组末尾开始的位置,类似于 Python 和 NumPy 的索引。如果为 False,则负数索引被视为越界,并根据 mode 参数进行处理。

  • fill_value – 仅适用于 get() 方法:当 mode'fill' 时,用于返回越界切片的填充值。否则被忽略。对于非精确类型,默认为 NaN;对于有符号类型,默认为最大负值;对于无符号类型,默认为最大正值;对于布尔类型,默认为 True

  • indices_are_sorted – 如果为 True,则实现将假定传递给 at[] 的(已规范化的)索引是升序排序的,这可能在某些后端上导致更高效的执行。如果为 True 但索引实际上并未排序,则输出是未定义的。

  • unique_indices – 如果为 True,则实现将假定传递给 at[] 的(已规范化的)索引是唯一的,这可能在某些后端上导致更高效的执行。如果为 True 但索引实际上并非唯一,则输出是未定义的。

示例

>>> x = jnp.arange(5.0)
>>> x
Array([0., 1., 2., 3., 4.], dtype=float32)
>>> x.at[2].get()
Array(2., dtype=float32)
>>> x.at[2].add(10)
Array([ 0.,  1., 12.,  3.,  4.], dtype=float32)

默认情况下,越界索引在更新时会被忽略,但此行为可以通过 mode 参数控制。

>>> x.at[10].add(10)  # dropped
Array([0., 1., 2., 3., 4.], dtype=float32)
>>> x.at[20].add(10, mode='clip')  # clipped
Array([ 0.,  1.,  2.,  3., 14.], dtype=float32)

对于 get(),默认情况下越界索引会被裁剪。

>>> x.at[20].get()  # out-of-bounds indices clipped
Array(4., dtype=float32)
>>> x.at[20].get(mode='fill')  # out-of-bounds indices filled with NaN
Array(nan, dtype=float32)
>>> x.at[20].get(mode='fill', fill_value=-1)  # custom fill value
Array(-1., dtype=float32)

负数索引从数组末尾开始计数,但可以通过将 wrap_negative_indices = False 来禁用此行为。

>>> x.at[-1].set(99)
Array([ 0.,  1.,  2.,  3., 99.], dtype=float32)
>>> x.at[-1].set(99, wrap_negative_indices=False, mode='drop')  # dropped!
Array([0., 1., 2., 3., 4.], dtype=float32)