jax.lax.broadcast_shapes#

jax.lax.broadcast_shapes(*shapes)[源代码]#

返回 NumPy 广播 shapes 后产生的形状。

这遵循 NumPy 广播 的规则。

参数:

shapes – 一个或多个整数元组,包含要广播的数组的形状。

返回:

一个整数元组,表示广播后的形状。

引发:

ValueError – 如果形状不兼容广播。

另请参阅

示例

一些广播兼容形状的示例

>>> jnp.broadcast_shapes((1,), (4,))
(4,)
>>> jnp.broadcast_shapes((3, 1), (4,))
(3, 4)
>>> jnp.broadcast_shapes((3, 1), (1, 4), (5, 1, 1))
(5, 3, 4)

尝试广播不兼容形状时出错

>>> jnp.broadcast_shapes((3, 1), (4, 1))  
Traceback (most recent call last):
ValueError: Incompatible shapes for broadcasting: shapes=[(3, 1), (4, 1)]