jax.numpy.arange#
- jax.numpy.arange(start, stop=None, step=None, dtype=None, *, device=None, out_sharding=None)[源代码]#
创建一组等间距的值。
JAX 对
numpy.arange()的实现,使用jax.lax.iota()实现。类似于 Python 的
range()函数,它可以接受几种不同的位置签名。jnp.arange(stop):生成从 0 到stop的值,步长为 1。jnp.arange(start, stop):生成从start到stop的值,步长为 1。jnp.arange(start, stop, step):生成从start到stop的值,步长为step。
与 Python 的
range()函数一样,起始值是包含在内的,而停止值是不包含在内的。- 参数:
start (ArrayLike | DimSize) – 区间的起始值,包含在内。
stop (ArrayLike | DimSize | None) – 可选的区间结束值,不包含在内。如果未指定,则
(start, stop) = (0, start)step (ArrayLike | None) – 可选的区间步长。默认为 1。
dtype (DTypeLike | None) – 返回数组的可选数据类型;如果未指定,则通过 start、stop 和 step 的类型提升来确定。
device (xc.Device | Sharding | None) – (可选) 创建的数组将被提交到的
Device或Sharding。out_sharding (NamedSharding | P | None) – (可选) 创建的数组将被提交到的
NamedSharding或P。如果使用显式分片(https://jax.net.cn/en/latest/notebooks/explicit-sharding.html),请使用 out_sharding 参数。
- 返回:
从
start到stop,以step分隔的等间距值数组。- 返回类型:
注意
使用带有浮点
step参数的arange可能会由于浮点误差的累积而导致意外结果,尤其是在使用float8_*和bfloat16等低精度数据类型时。为避免精度误差,请考虑生成整数范围,然后将其缩放到所需范围。例如,而不是这样做:jnp.arange(-1, 1, 0.01, dtype='bfloat16')
生成一个整数序列并对其进行缩放可能更准确:
(jnp.arange(-100, 100) * 0.01).astype('bfloat16')
示例
单参数版本仅指定
stop值。>>> jnp.arange(4) Array([0, 1, 2, 3], dtype=int32)
传递浮点
stop值会导致浮点结果。>>> jnp.arange(4.0) Array([0., 1., 2., 3.], dtype=float32)
双参数版本指定
start和stop,step=1。>>> jnp.arange(1, 6) Array([1, 2, 3, 4, 5], dtype=int32)
三参数版本指定
start、stop和step。>>> jnp.arange(0, 2, 0.5) Array([0. , 0.5, 1. , 1.5], dtype=float32)
另请参阅
jax.numpy.linspace():生成固定数量的等间距值。jax.lax.iota():直接在 XLA 中生成整数序列。