jax.numpy.broadcast_to#
- jax.numpy.broadcast_to(array, shape, *, out_sharding=None)[source]#
将数组广播到指定形状。
JAX 对
numpy.broadcast_to()
的实现。JAX 使用 NumPy 风格的广播规则,您可以在 NumPy 广播机制 中了解更多信息。- 参数:
array (ArrayLike) – 要广播的数组。
shape (DimSize | Shape) – 数组将要广播到的形状。
out_sharding (NamedSharding | P | None)
- 返回:
广播到指定形状的数组副本。
- 返回类型:
另请参阅
jax.numpy.broadcast_arrays()
: 将数组广播到共同形状。jax.numpy.broadcast_shapes()
: 将输入形状广播到共同形状。
示例
>>> x = jnp.int32(1) >>> jnp.broadcast_to(x, (1, 4)) Array([[1, 1, 1, 1]], dtype=int32)
>>> x = jnp.array([1, 2, 3]) >>> jnp.broadcast_to(x, (2, 3)) Array([[1, 2, 3], [1, 2, 3]], dtype=int32)
>>> x = jnp.array([[2], [4]]) >>> jnp.broadcast_to(x, (2, 4)) Array([[2, 2, 2, 2], [4, 4, 4, 4]], dtype=int32)