jax.numpy.column_stack#
- jax.numpy.column_stack(tup)[源代码]#
按列堆叠数组。
JAX 实现
numpy.column_stack()。对于有两个或多个维度的数组,这等效于
jax.numpy.concatenate()配合axis=1。- 参数:
- 返回:
堆叠结果。
- 返回类型:
另请参阅
jax.numpy.stack(): 沿任意轴堆叠jax.numpy.concatenate(): 沿现有轴进行连接。jax.numpy.vstack(): 垂直堆叠,即沿轴 0 堆叠。jax.numpy.hstack(): 水平堆叠,即沿轴 1 堆叠。jax.numpy.dstack(): 深度堆叠,即沿轴 2 堆叠。
示例
标量值
>>> jnp.column_stack([1, 2, 3]) Array([[1, 2, 3]], dtype=int32, weak_type=True)
一维数组
>>> x = jnp.arange(3) >>> y = jnp.ones(3) >>> jnp.column_stack([x, y]) Array([[0., 1.], [1., 1.], [2., 1.]], dtype=float32)
二维数组
>>> x = x.reshape(3, 1) >>> y = y.reshape(3, 1) >>> jnp.column_stack([x, y]) Array([[0., 1.], [1., 1.], [2., 1.]], dtype=float32)