jax.numpy.stack#
- jax.numpy.stack(arrays, axis=0, out=None, dtype=None)[源代码]#
沿新轴连接数组。
JAX 实现的
numpy.stack()
。- 参数:
- 返回:
堆叠结果。
- 返回类型:
另请参阅
jax.numpy.unstack()
:stack
的逆运算。jax.numpy.concatenate()
: 沿现有轴连接。jax.numpy.vstack()
: 垂直堆叠,即沿轴 0 堆叠。jax.numpy.hstack()
: 水平堆叠,即沿轴 1 堆叠。jax.numpy.dstack()
: 深度堆叠,即沿轴 2 堆叠。jax.numpy.column_stack()
: 堆叠列。
示例
>>> x = jnp.array([1, 2, 3]) >>> y = jnp.array([4, 5, 6]) >>> jnp.stack([x, y]) Array([[1, 2, 3], [4, 5, 6]], dtype=int32) >>> jnp.stack([x, y], axis=1) Array([[1, 4], [2, 5], [3, 6]], dtype=int32)
unstack()
执行逆运算>>> arr = jnp.stack([x, y], axis=1) >>> x, y = jnp.unstack(arr, axis=1) >>> x Array([1, 2, 3], dtype=int32) >>> y Array([4, 5, 6], dtype=int32)