jax.lax.broadcast_shapes#
- jax.lax.broadcast_shapes(*shapes)[source]#
返回从 shapes 的 NumPy 广播产生的形状。
这遵循 NumPy 广播的规则。
- 参数:
shapes – 一个或多个整数元组,包含要广播的数组的形状。
- 返回:
一个表示广播后形状的整数元组。
- 引发:
ValueError – 如果形状不兼容广播。
另请参阅
jax.numpy.broadcast_shapes()
: JAX NumPy 命名空间中的类似 API
示例
一些广播兼容形状的例子
>>> jnp.broadcast_shapes((1,), (4,)) (4,) >>> jnp.broadcast_shapes((3, 1), (4,)) (3, 4) >>> jnp.broadcast_shapes((3, 1), (1, 4), (5, 1, 1)) (5, 3, 4)
尝试广播不兼容形状时出错
>>> jnp.broadcast_shapes((3, 1), (4, 1)) Traceback (most recent call last): ValueError: Incompatible shapes for broadcasting: shapes=[(3, 1), (4, 1)]