jax.numpy.resize#
- jax.numpy.resize(a, new_shape)[源代码]#
返回一个具有指定形状的新数组。
JAX 对
numpy.resize()的实现。- 参数:
a (ArrayLike) – 输入数组或标量。
new_shape (Shape) – int 或 int 元组。指定重塑后数组的形状。
- 返回:
一个具有指定形状的重塑后数组。如果重塑后数组大于原始数组,则
a的元素会在重塑后数组中重复。- 返回类型:
另请参阅
jax.numpy.reshape(): 返回数组的重塑副本。jax.numpy.repeat(): 从重复的元素构建数组。
示例
>>> x = jnp.array([1, 2, 3, 4, 5, 6, 7, 8, 9]) >>> jnp.resize(x, (3, 3)) Array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=int32) >>> jnp.resize(x, (3, 4)) Array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 1, 2, 3]], dtype=int32) >>> jnp.resize(4, (3, 2)) Array([[4, 4], [4, 4], [4, 4]], dtype=int32, weak_type=True)