jax.numpy.hstack#
- jax.numpy.hstack(tup, dtype=None)[源代码]#
水平堆叠数组。
numpy.hstack()
的 JAX 实现。对于一个或多个维度的数组,这等效于
jax.numpy.concatenate()
,其中axis=1
。- 参数:
- 返回值:
堆叠结果。
- 返回类型:
另请参阅
jax.numpy.stack()
: 沿任意轴堆叠jax.numpy.concatenate()
: 沿现有轴连接。jax.numpy.vstack()
: 垂直堆叠,即沿轴 0。jax.numpy.dstack()
: 深度堆叠,即沿轴 2。
示例
标量值
>>> jnp.hstack([1, 2, 3]) Array([1, 2, 3], dtype=int32, weak_type=True)
1D 数组
>>> x = jnp.arange(3) >>> y = jnp.ones(3) >>> jnp.hstack([x, y]) Array([0., 1., 2., 1., 1., 1.], dtype=float32)
2D 数组
>>> x = x.reshape(3, 1) >>> y = y.reshape(3, 1) >>> jnp.hstack([x, y]) Array([[0., 1.], [1., 1.], [2., 1.]], dtype=float32)