jax.lax.broadcast_shapes#
- jax.lax.broadcast_shapes(*shapes)[源代码]#
返回 NumPy 广播 shapes 后产生的形状。
这遵循 NumPy 广播 的规则。
- 参数:
shapes – 一个或多个整数元组,包含要广播的数组的形状。
- 返回:
一个整数元组,表示广播后的形状。
- 引发:
ValueError – 如果形状不兼容广播。
另请参阅
jax.numpy.broadcast_shapes()
:JAX NumPy 命名空间中类似的 API
示例
一些广播兼容形状的示例
>>> jnp.broadcast_shapes((1,), (4,)) (4,) >>> jnp.broadcast_shapes((3, 1), (4,)) (3, 4) >>> jnp.broadcast_shapes((3, 1), (1, 4), (5, 1, 1)) (5, 3, 4)
尝试广播不兼容形状时出错
>>> jnp.broadcast_shapes((3, 1), (4, 1)) Traceback (most recent call last): ValueError: Incompatible shapes for broadcasting: shapes=[(3, 1), (4, 1)]