jax.ref.set#
- jax.ref.set(ref, idx, value)[源代码]#
原地修改 Ref 中的值。
这等同于对于 NumPy 风格的索引器
idx
,执行ref[idx] = value
。有关可变数组 Ref 的更多信息,请参阅 Ref 指南。- 参数:
ref (AbstractRef | TransformedRef) – 一个
jax.ref.Ref
对象。返回时,缓冲区将通过此操作进行修改。idx (Indexer | tuple[Indexer, ...] | None) – 一个 NumPy 风格的索引器
value (Array) – 一个
jax.Array
对象(注意,不是jax.ref.Ref
)包含要设置到数组中的值。
- 返回:
无
- 返回类型:
无
示例
>>> import jax >>> ref = jax.new_ref(jax.numpy.zeros(5)) >>> jax.ref.set(ref, 1, 10.0) >>> ref Ref([ 0., 10., 0., 0., 0.], dtype=float32)
通过索引语法实现等效操作
>>> ref = jax.new_ref(jax.numpy.zeros(5)) >>> ref[1] = 10.0 >>> ref Ref([ 0., 10., 0., 0., 0.], dtype=float32)
使用
...
来设置标量 ref 的值>>> ref = jax.new_ref(jax.numpy.int32(0)) >>> ref[...] = 4 >>> ref Ref(4, dtype=int32)