jax.numpy.cross#

jax.numpy.cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None)[source]#

计算两个数组的(批量)叉积。

JAX 实现的 numpy.cross()

这会计算 2 维或 3 维叉积,

\[c = a \times b\]

在 3 维中,c 是长度为 3 的数组。在 2 维中,c 是标量。

参数:
  • a – N 维数组。 a.shape[axisa] 指示叉积的维度,必须为 2 或 3。

  • b – N 维数组。必须具有 b.shape[axisb] == a.shape[axisb],并且 ab 的其他维度必须是广播兼容的。

  • axisa (int) – 指定 a 沿其计算叉积的轴。

  • axisb (int) – 指定 b 沿其计算叉积的轴。

  • axisc (int) – 指定 c 沿其存储叉积结果的轴。

  • axis (int | None) – 如果指定,这将使用单个值覆盖 axisaaxisbaxisc

返回:

包含 ab 沿指定轴的(批量)叉积的数组 c

另请参阅

示例

2 维叉积返回一个标量

>>> a = jnp.array([1, 2])
>>> b = jnp.array([3, 4])
>>> jnp.cross(a, b)
Array(-2, dtype=int32)

3 维叉积返回一个长度为 3 的向量

>>> a = jnp.array([1, 2, 3])
>>> b = jnp.array([4, 5, 6])
>>> jnp.cross(a, b)
Array([-3,  6, -3], dtype=int32)

对于多维输入,默认情况下沿最后一个轴计算叉积。这是一个批量 3 维叉积,作用于输入的行

>>> a = jnp.array([[1, 2, 3],
...                [3, 4, 3]])
>>> b = jnp.array([[2, 3, 2],
...                [4, 5, 6]])
>>> jnp.cross(a, b)
Array([[-5,  4, -1],
       [ 9, -6, -1]], dtype=int32)

指定 axis=0 使其成为批量 2 维叉积,作用于输入的列

>>> jnp.cross(a, b, axis=0)
Array([-2, -2, 12], dtype=int32)

等效地,我们可以独立指定输入 ab 以及输出 c 的轴

>>> jnp.cross(a, b, axisa=0, axisb=0, axisc=0)
Array([-2, -2, 12], dtype=int32)