jax.lax.reshape#
- jax.lax.reshape(operand, new_sizes, dimensions=None, *, out_sharding=None)[源代码]#
封装 XLA 的 Reshape 运算符。
对于插入/删除大小为 1 的维度,建议使用
lax.squeeze
/lax.expand_dims
。这些保留了关于轴标识的信息,这些信息可能对高级转换规则有用。- 参数:
operand (ArrayLike) – 要重塑的数组。
new_sizes (Shape) – 指定结果形状的整数序列。最终数组的大小必须与输入的大小匹配。
dimensions (Sequence[int] | None) – 指定输入形状的排列顺序的可选整数序列。如果指定,则长度必须与
operand.shape
匹配。out_sharding (NamedSharding | P | None)
- 返回:
重塑后的数组。
- 返回类型:
out
示例
从一维到二维的简单重塑
>>> x = jnp.arange(6) >>> y = reshape(x, (2, 3)) >>> y Array([[0, 1, 2], [3, 4, 5]], dtype=int32)
重塑回一维
>>> reshape(y, (6,)) Array([0, 1, 2, 3, 4, 5], dtype=int32)
通过维度排列重塑为一维
>>> reshape(y, (6,), (1, 0)) Array([0, 3, 1, 4, 2, 5], dtype=int32)