jax.ref.set#

jax.ref.set(ref, idx, value)[源代码]#

原地修改 Ref 中的值。

这等同于对于 NumPy 风格的索引器 idx,执行 ref[idx] = value。有关可变数组 Ref 的更多信息,请参阅 Ref 指南

参数:
  • ref (AbstractRef | TransformedRef) – 一个 jax.ref.Ref 对象。返回时,缓冲区将通过此操作进行修改。

  • idx (Indexer | tuple[Indexer, ...] | None) – 一个 NumPy 风格的索引器

  • value (Array) – 一个 jax.Array 对象(注意,不是 jax.ref.Ref)包含要设置到数组中的值。

返回:

返回类型:

示例

>>> import jax
>>> ref = jax.new_ref(jax.numpy.zeros(5))
>>> jax.ref.set(ref, 1, 10.0)
>>> ref
Ref([ 0., 10.,  0.,  0.,  0.], dtype=float32)

通过索引语法实现等效操作

>>> ref = jax.new_ref(jax.numpy.zeros(5))
>>> ref[1] = 10.0
>>> ref
Ref([ 0., 10.,  0.,  0.,  0.], dtype=float32)

使用 ... 来设置标量 ref 的值

>>> ref = jax.new_ref(jax.numpy.int32(0))
>>> ref[...] = 4
>>> ref
Ref(4, dtype=int32)