jax.numpy.broadcast_shapes#

jax.numpy.broadcast_shapes(*shapes)[源代码]#

将输入形状广播到公共输出形状。

JAX 实现的 numpy.broadcast_shapes()。JAX 使用 NumPy 样式的广播规则,您可以在NumPy 广播中阅读更多相关信息。

参数:

shapes – 0 个或多个以整数序列形式指定的形状

返回:

广播后的形状,以整数元组形式表示。

另请参阅

示例

一些兼容的形状

>>> jnp.broadcast_shapes((1,), (4,))
(4,)
>>> jnp.broadcast_shapes((3, 1), (4,))
(3, 4)
>>> jnp.broadcast_shapes((3, 1), (1, 4), (5, 1, 1))
(5, 3, 4)

不兼容的形状

>>> jnp.broadcast_shapes((3, 1), (4, 1))  
Traceback (most recent call last):
ValueError: Incompatible shapes for broadcasting: shapes=[(3, 1), (4, 1)]