jax.numpy.vectorize#

jax.numpy.vectorize(pyfunc, *, excluded=frozenset({}), signature=None)[source]#

定义一个支持广播的向量化函数。

vectorize() 是一个便捷封装,用于定义带广播功能的向量化函数,其风格类似于 NumPy 的泛化通用函数。它允许定义在任意前导维度上自动重复的函数,而无需函数实现关心如何处理高维输入。

jax.numpy.vectorize() 具有与 numpy.vectorize 相同的接口,但它是一种自动批处理变换(vmap())的语法糖,而不是一个 Python 循环。这应该会显著提高效率,但其实现必须使用作用于 JAX 数组的函数来编写。

参数:
  • pyfunc – 要向量化的函数。

  • excluded – 可选的整数集合,表示函数不会对其进行向量化的位置参数。这些参数将未经修改地直接传递给 pyfunc

  • signature – 可选的泛化通用函数签名,例如,(m,n),(n)->(m) 用于向量化矩阵-向量乘法。如果提供,pyfunc 将使用(并预期返回)形状由相应核心维度大小确定的数组进行调用。默认情况下,pyfunc 被假定接受标量数组作为输入;如果 signatureNone,则 pyfunc 可以生成任意形状的输出。

返回:

给定函数的向量化版本。

示例

这里有一些如何使用 vectorize() 编写向量化线性代数例程的示例

>>> from functools import partial
>>> @partial(jnp.vectorize, signature='(k),(k)->(k)')
... def cross_product(a, b):
...   assert a.shape == b.shape and a.ndim == b.ndim == 1
...   return jnp.array([a[1] * b[2] - a[2] * b[1],
...                     a[2] * b[0] - a[0] * b[2],
...                     a[0] * b[1] - a[1] * b[0]])
>>> @partial(jnp.vectorize, signature='(n,m),(m)->(n)')
... def matrix_vector_product(matrix, vector):
...   assert matrix.ndim == 2 and matrix.shape[1:] == vector.shape
...   return matrix @ vector

这些函数只被编写用于处理 1D 或 2D 数组(assert 语句永远不会被违反),但通过 vectorize,它们支持具有 NumPy 风格广播的任意维度输入,例如,

>>> cross_product(jnp.ones(3), jnp.ones(3)).shape
(3,)
>>> cross_product(jnp.ones((2, 3)), jnp.ones(3)).shape
(2, 3)
>>> cross_product(jnp.ones((1, 2, 3)), jnp.ones((2, 1, 3))).shape
(2, 2, 3)
>>> matrix_vector_product(jnp.ones(3), jnp.ones(3))  
Traceback (most recent call last):
ValueError: input with shape (3,) does not have enough dimensions for all
core dimensions ('n', 'k') on vectorized function with excluded=frozenset()
and signature='(n,k),(k)->(k)'
>>> matrix_vector_product(jnp.ones((2, 3)), jnp.ones(3)).shape
(2,)
>>> matrix_vector_product(jnp.ones((2, 3)), jnp.ones((4, 3))).shape
(4, 2)

请注意,这与 jnp.matmul 具有不同的语义

>>> jnp.matmul(jnp.ones((2, 3)), jnp.ones((4, 3)))  
Traceback (most recent call last):
TypeError: dot_general requires contracting dimensions to have the same shape, got [3] and [4].