jax.ref.get#
- jax.ref.get(ref, idx=None)[源]#
从 Ref 中读取一个值。
这等同于对于 NumPy 风格的索引器
idx
的ref[idx]
。有关可变数组 Ref 的更多信息,请参阅 Ref 指南。- 参数:
ref (Any) – 一个
jax.ref.Ref
对象。idx (Indexer | tuple[Indexer, ...] | None) – 一个 NumPy 风格的索引器
- 返回:
一个
jax.Array
对象(注意,不是jax.ref.Ref
),其中包含可变引用的索引元素。- 返回类型:
示例
>>> import jax >>> ref = jax.new_ref(jax.numpy.arange(5)) >>> jax.ref.get(ref, slice(1, 3)) Array([1, 2], dtype=int32)
通过索引语法实现等效操作
>>> ref[1:3] Array([1, 2], dtype=int32)
使用
...
来提取整个缓冲区>>> ref[...] Array([0, 1, 2, 3, 4], dtype=int32)