jax.numpy.squeeze#

jax.numpy.squeeze(a, axis=None)[源代码]#

移除数组中一个或多个长度为 1 的轴

JAX 对 numpy.sqeeze() 的实现,通过 jax.lax.squeeze() 实现。

参数:
  • a (ArrayLike) – 输入数组

  • axis (int | Sequence[int] | None) – 指定要移除的轴的整数或整数序列。如果指定的任何轴的长度不为 1,则会引发错误。如果未指定,则移除 a 中所有长度为 1 的轴。

返回:

移除长度为 1 的轴的 a 的副本。

返回类型:

Array

注意事项

numpy.squeeze() 不同,jax.numpy.squeeze() 会返回输入数组的副本而不是视图。但是,在 JIT 下,编译器会在可能的情况下优化掉此类副本,因此这在实践中不会对性能产生影响。

另请参阅

示例

>>> x = jnp.array([[[0]], [[1]], [[2]]])
>>> x.shape
(3, 1, 1)

移除所有长度为 1 的维度

>>> jnp.squeeze(x)
Array([0, 1, 2], dtype=int32)
>>> _.shape
(3,)

显式指定轴时的等效操作

>>> jnp.squeeze(x, axis=(1, 2))
Array([0, 1, 2], dtype=int32)

尝试移除非单位轴会引发错误

>>> jnp.squeeze(x, axis=0)  
Traceback (most recent call last):
  ...
ValueError: cannot select an axis to squeeze out which has size not equal to one, got shape=(3, 1, 1) and dimensions=(0,)

为了方便起见,此功能也可通过 jax.Array.squeeze() 方法获得

>>> x.squeeze()
Array([0, 1, 2], dtype=int32)