jax.ref.get#

jax.ref.get(ref, idx=None)[源]#

从 Ref 中读取一个值。

这等同于对于 NumPy 风格的索引器 idxref[idx]。有关可变数组 Ref 的更多信息,请参阅 Ref 指南

参数:
  • ref (Any) – 一个 jax.ref.Ref 对象。

  • idx (Indexer | tuple[Indexer, ...] | None) – 一个 NumPy 风格的索引器

返回:

一个 jax.Array 对象(注意,不是 jax.ref.Ref),其中包含可变引用的索引元素。

返回类型:

Array

示例

>>> import jax
>>> ref = jax.new_ref(jax.numpy.arange(5))
>>> jax.ref.get(ref, slice(1, 3))
Array([1, 2], dtype=int32)

通过索引语法实现等效操作

>>> ref[1:3]
Array([1, 2], dtype=int32)

使用 ... 来提取整个缓冲区

>>> ref[...]
Array([0, 1, 2, 3, 4], dtype=int32)