jax.numpy.cross#
- jax.numpy.cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None)[源代码]#
计算两个数组的(批量)叉积。
numpy.cross()
的 JAX 实现。计算二维或三维叉积,
\[c = a \times b\]在 3 维空间中,
c
是一个长度为 3 的数组。在 2 维空间中,c
是一个标量。- 参数:
- 返回:
数组
c
,包含a
和b
沿指定轴的(批处理)叉积。
另请参阅
jax.numpy.linalg.cross()
:一个与数组 API 兼容的函数,用于计算 3 维向量的叉积。
示例
二维叉积返回一个标量
>>> a = jnp.array([1, 2]) >>> b = jnp.array([3, 4]) >>> jnp.cross(a, b) Array(-2, dtype=int32)
三维叉积返回一个长度为 3 的向量
>>> a = jnp.array([1, 2, 3]) >>> b = jnp.array([4, 5, 6]) >>> jnp.cross(a, b) Array([-3, 6, -3], dtype=int32)
对于多维输入,默认情况下沿最后一个轴计算叉积。这是一个批处理的三维叉积,对输入的行进行操作
>>> a = jnp.array([[1, 2, 3], ... [3, 4, 3]]) >>> b = jnp.array([[2, 3, 2], ... [4, 5, 6]]) >>> jnp.cross(a, b) Array([[-5, 4, -1], [ 9, -6, -1]], dtype=int32)
指定 axis=0 使其成为批处理的二维叉积,对输入的列进行操作
>>> jnp.cross(a, b, axis=0) Array([-2, -2, 12], dtype=int32)
等效地,我们可以独立指定输入
a
和b
的轴以及输出c
的轴>>> jnp.cross(a, b, axisa=0, axisb=0, axisc=0) Array([-2, -2, 12], dtype=int32)