jax.numpy.cross#
- jax.numpy.cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None)[source]#
计算两个数组的(批量)叉积。
numpy.cross()
的 JAX 实现。这将计算二维或三维叉积,
\[c = a \times b\]在 3 维中,
c
是长度为 3 的数组。在 2 维中,c
是标量。- 参数:
- 返回:
包含
a
和b
沿指定轴的(批量)叉积的数组c
。
参见
jax.numpy.linalg.cross()
:一个数组 API 兼容函数,用于计算 3 向量的叉积。
示例
二维叉积返回标量
>>> a = jnp.array([1, 2]) >>> b = jnp.array([3, 4]) >>> jnp.cross(a, b) Array(-2, dtype=int32)
三维叉积返回长度为 3 的向量
>>> a = jnp.array([1, 2, 3]) >>> b = jnp.array([4, 5, 6]) >>> jnp.cross(a, b) Array([-3, 6, -3], dtype=int32)
对于多维输入,默认情况下,叉积沿最后一个轴计算。这是一个批量三维叉积,对输入的行进行运算
>>> a = jnp.array([[1, 2, 3], ... [3, 4, 3]]) >>> b = jnp.array([[2, 3, 2], ... [4, 5, 6]]) >>> jnp.cross(a, b) Array([[-5, 4, -1], [ 9, -6, -1]], dtype=int32)
指定 axis=0 使其成为批量二维叉积,对输入的列进行运算
>>> jnp.cross(a, b, axis=0) Array([-2, -2, 12], dtype=int32)
等效地,我们可以独立指定输入
a
和b
以及输出c
的轴>>> jnp.cross(a, b, axisa=0, axisb=0, axisc=0) Array([-2, -2, 12], dtype=int32)