jax.numpy.expand_dims#

jax.numpy.expand_dims(a, axis)[源]#

在数组中插入长度为 1 的维度

JAX 对 numpy.expand_dims() 的实现,通过 jax.lax.expand_dims() 实现。

参数:
  • a (ArrayLike) – 输入数组

  • axis (int | Sequence[int]) – 一个整数或整数序列,指定要添加的轴的位置。

返回:

a 的副本,带有添加的维度。

返回类型:

数组

注意事项

numpy.expand_dims() 不同的是,jax.numpy.expand_dims() 将返回输入数组的一个副本,而不是视图。然而,在 JIT 下,编译器在可能的情况下会优化掉这些副本,因此这在实践中不会对性能产生影响。

另请参阅

示例

>>> x = jnp.array([1, 2, 3])
>>> x.shape
(3,)

扩展前导维度

>>> jnp.expand_dims(x, 0)
Array([[1, 2, 3]], dtype=int32)
>>> _.shape
(1, 3)

扩展尾随维度

>>> jnp.expand_dims(x, 1)
Array([[1],
       [2],
       [3]], dtype=int32)
>>> _.shape
(3, 1)

扩展多个维度

>>> jnp.expand_dims(x, (0, 1, 3))
Array([[[[1],
         [2],
         [3]]]], dtype=int32)
>>> _.shape
(1, 1, 3, 1)

通过使用 None 进行索引,也可以更简洁地扩展维度

>>> x[None]  # equivalent to jnp.expand_dims(x, 0)
Array([[1, 2, 3]], dtype=int32)
>>> x[:, None]  # equivalent to jnp.expand_dims(x, 1)
Array([[1],
       [2],
       [3]], dtype=int32)
>>> x[None, None, :, None]  # equivalent to jnp.expand_dims(x, (0, 1, 3))
Array([[[[1],
         [2],
         [3]]]], dtype=int32)