jax.Array.at#
- abstract property Array.at[源代码]#
用于索引更新功能的辅助属性。
at属性提供了原地修改数组的函数式纯粹等价物。具体来说
替代语法
等效的原地表达式
x = x.at[idx].set(y)x[idx] = yx = x.at[idx].add(y)x[idx] += yx = x.at[idx].subtract(y)x[idx] -= yx = x.at[idx].multiply(y)x[idx] *= yx = x.at[idx].divide(y)x[idx] /= yx = x.at[idx].power(y)x[idx] **= yx = x.at[idx].min(y)x[idx] = minimum(x[idx], y)x = x.at[idx].max(y)x[idx] = maximum(x[idx], y)x = x.at[idx].apply(ufunc)ufunc.at(x, idx)x = x.at[idx].get()x = x[idx]所有
x.at表达式都不会修改原始的x;相反,它们会返回一个修改后的x的副本。然而,在jit()编译的函数内部,像x = x.at[idx].set(y)这样的表达式可以保证被原地应用。与 NumPy 的原地操作(例如
x[idx] += y)不同,如果多个索引指向同一个位置,所有的更新都会被应用(NumPy 只会应用最后一个更新,而不是应用所有更新)。冲突更新的应用顺序是实现定义的,并且可能是非确定性的(例如,由于某些硬件平台上的并发)。默认情况下,JAX 假定所有索引都在边界内。可以通过
mode参数指定其他越界索引语义(见下文)。- 参数:
mode –
指定越界索引模式的字符串。选项是
"promise_in_bounds":(默认)用户承诺索引在边界内。不会执行额外的检查。实际上,这意味着get()中的越界索引将被截断,而set()、add()等中的越界索引将被丢弃。"clip": 将越界索引夹紧到有效范围。"drop": 忽略越界索引。"fill":"drop"的别名。对于 get(),可选的fill_value参数指定当mode为'fill'时返回的值。否则忽略。默认值为非精确类型的NaN,有符号类型的最大负值,无符号类型的最大正值,以及布尔值的True。
有关更多详细信息,请参阅
jax.lax.GatherScatterMode。wrap_negative_indices – 如果为 True(默认),则负索引指示从数组末尾开始的位置,类似于 Python 和 NumPy 的索引。如果为 False,则负索引被视为越界,并根据
mode参数进行处理。fill_value – 仅适用于
get()方法:当mode为'fill'时,为越界切片返回的填充值。否则忽略。默认值为非精确类型的NaN,有符号类型的最大负值,无符号类型的最大正值,以及布尔值的True。indices_are_sorted – 如果为 True,实现将假定传递给
at[]的(规范化后的)索引是升序排列的,这可以在某些后端上实现更高效的执行。如果为 True 但索引实际上并未排序,则输出未定义。unique_indices – 如果为 True,实现将假定传递给
at[]的(规范化后的)索引是唯一的,这可以在某些后端上实现更高效的执行。如果为 True 但索引实际上并非唯一,则输出未定义。
示例
>>> x = jnp.arange(5.0) >>> x Array([0., 1., 2., 3., 4.], dtype=float32) >>> x.at[2].get() Array(2., dtype=float32) >>> x.at[2].add(10) Array([ 0., 1., 12., 3., 4.], dtype=float32)
默认情况下,越界索引在更新中被忽略,但此行为可以通过
mode参数控制>>> x.at[10].add(10) # dropped Array([0., 1., 2., 3., 4.], dtype=float32) >>> x.at[20].add(10, mode='clip') # clipped Array([ 0., 1., 2., 3., 14.], dtype=float32)
对于
get(),越界索引默认被截断>>> x.at[20].get() # out-of-bounds indices clipped Array(4., dtype=float32) >>> x.at[20].get(mode='fill') # out-of-bounds indices filled with NaN Array(nan, dtype=float32) >>> x.at[20].get(mode='fill', fill_value=-1) # custom fill value Array(-1., dtype=float32)
负索引从数组末尾开始计数,但可以通过将
wrap_negative_indices = False来禁用此行为>>> x.at[-1].set(99) Array([ 0., 1., 2., 3., 99.], dtype=float32) >>> x.at[-1].set(99, wrap_negative_indices=False, mode='drop') # dropped! Array([0., 1., 2., 3., 4.], dtype=float32)