jax.numpy.broadcast_to#

jax.numpy.broadcast_to(array, shape, *, out_sharding=None)[source]#

将数组广播到指定形状。

JAX 对 numpy.broadcast_to() 的实现。JAX 使用 NumPy 风格的广播规则,您可以在 NumPy 广播机制 中了解更多信息。

参数:
  • array (ArrayLike) – 要广播的数组。

  • shape (DimSize | Shape) – 数组将要广播到的形状。

  • out_sharding (NamedSharding | P | None)

返回:

广播到指定形状的数组副本。

返回类型:

Array

另请参阅

示例

>>> x = jnp.int32(1)
>>> jnp.broadcast_to(x, (1, 4))
Array([[1, 1, 1, 1]], dtype=int32)
>>> x = jnp.array([1, 2, 3])
>>> jnp.broadcast_to(x, (2, 3))
Array([[1, 2, 3],
       [1, 2, 3]], dtype=int32)
>>> x = jnp.array([[2], [4]])
>>> jnp.broadcast_to(x, (2, 4))
Array([[2, 2, 2, 2],
       [4, 4, 4, 4]], dtype=int32)