jax.numpy.broadcast_shapes#
- jax.numpy.broadcast_shapes(*shapes)[源代码]#
将输入形状广播到公共输出形状。
JAX 实现的
numpy.broadcast_shapes()
。JAX 使用 NumPy 样式的广播规则,您可以在NumPy 广播中阅读更多相关信息。- 参数:
shapes – 0 个或多个以整数序列形式指定的形状
- 返回:
广播后的形状,以整数元组形式表示。
另请参阅
jax.numpy.broadcast_arrays()
: 将数组广播到公共形状。jax.numpy.broadcast_to()
: 将数组广播到指定的形状。
示例
一些兼容的形状
>>> jnp.broadcast_shapes((1,), (4,)) (4,) >>> jnp.broadcast_shapes((3, 1), (4,)) (3, 4) >>> jnp.broadcast_shapes((3, 1), (1, 4), (5, 1, 1)) (5, 3, 4)
不兼容的形状
>>> jnp.broadcast_shapes((3, 1), (4, 1)) Traceback (most recent call last): ValueError: Incompatible shapes for broadcasting: shapes=[(3, 1), (4, 1)]