jax.numpy.cross#

jax.numpy.cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None)[source]#

计算两个数组的(批量)叉积。

numpy.cross() 的 JAX 实现。

这将计算二维或三维叉积,

\[c = a \times b\]

在 3 维中,c 是长度为 3 的数组。在 2 维中,c 是标量。

参数:
  • a – N 维数组。 a.shape[axisa] 指示叉积的维度,并且必须为 2 或 3。

  • b – N 维数组。必须具有 b.shape[axisb] == a.shape[axisb],并且 ab 的其他维度必须是广播兼容的。

  • axisa (int) – 指定 a 的轴,沿该轴计算叉积。

  • axisb (int) – 指定 b 的轴,沿该轴计算叉积。

  • axisc (int) – 指定 c 的轴,叉积结果将存储在该轴上。

  • axis (int | None) – 如果指定,则此参数将使用单个值覆盖 axisaaxisbaxisc

返回:

包含 ab 沿指定轴的(批量)叉积的数组 c

参见

示例

二维叉积返回标量

>>> a = jnp.array([1, 2])
>>> b = jnp.array([3, 4])
>>> jnp.cross(a, b)
Array(-2, dtype=int32)

三维叉积返回长度为 3 的向量

>>> a = jnp.array([1, 2, 3])
>>> b = jnp.array([4, 5, 6])
>>> jnp.cross(a, b)
Array([-3,  6, -3], dtype=int32)

对于多维输入,默认情况下,叉积沿最后一个轴计算。这是一个批量三维叉积,对输入的行进行运算

>>> a = jnp.array([[1, 2, 3],
...                [3, 4, 3]])
>>> b = jnp.array([[2, 3, 2],
...                [4, 5, 6]])
>>> jnp.cross(a, b)
Array([[-5,  4, -1],
       [ 9, -6, -1]], dtype=int32)

指定 axis=0 使其成为批量二维叉积,对输入的列进行运算

>>> jnp.cross(a, b, axis=0)
Array([-2, -2, 12], dtype=int32)

等效地,我们可以独立指定输入 ab 以及输出 c 的轴

>>> jnp.cross(a, b, axisa=0, axisb=0, axisc=0)
Array([-2, -2, 12], dtype=int32)