jax.numpy.squeeze#
- jax.numpy.squeeze(a, axis=None)[源代码]#
移除数组中一个或多个长度为 1 的轴
JAX 对
numpy.sqeeze()的实现,通过jax.lax.squeeze()实现。- 参数:
- 返回:
移除长度为 1 的轴的
a的副本。- 返回类型:
注意事项
与
numpy.squeeze()不同,jax.numpy.squeeze()会返回输入数组的副本而不是视图。但是,在 JIT 下,编译器会在可能的情况下优化掉此类副本,因此这在实践中不会对性能产生影响。另请参阅
jax.numpy.expand_dims():squeeze的逆操作:添加长度为 1 的维度。jax.Array.squeeze(): 通过数组方法实现相同的功能。jax.lax.squeeze(): 对应的 XLA API。jax.numpy.ravel(): 将数组展平成一维形状。jax.numpy.reshape(): 通用的数组重塑。
示例
>>> x = jnp.array([[[0]], [[1]], [[2]]]) >>> x.shape (3, 1, 1)
移除所有长度为 1 的维度
>>> jnp.squeeze(x) Array([0, 1, 2], dtype=int32) >>> _.shape (3,)
显式指定轴时的等效操作
>>> jnp.squeeze(x, axis=(1, 2)) Array([0, 1, 2], dtype=int32)
尝试移除非单位轴会引发错误
>>> jnp.squeeze(x, axis=0) Traceback (most recent call last): ... ValueError: cannot select an axis to squeeze out which has size not equal to one, got shape=(3, 1, 1) and dimensions=(0,)
为了方便起见,此功能也可通过
jax.Array.squeeze()方法获得>>> x.squeeze() Array([0, 1, 2], dtype=int32)