jax.numpy.vectorize#
- jax.numpy.vectorize(pyfunc, *, excluded=frozenset({}), signature=None)[source]#
定义一个带有广播机制的向量化函数。
vectorize()
是一个方便的包装器,用于定义带有广播机制的向量化函数,风格类似于 NumPy 的 generalized universal functions。它允许定义在任何前导维度上自动重复的函数,而无需函数的实现考虑如何处理更高维度的输入。jax.numpy.vectorize()
具有与numpy.vectorize
相同的接口,但它是自动批处理转换 (vmap()
) 的语法糖,而不是 Python 循环。这应该效率更高,但实现必须用作用于 JAX 数组的函数来编写。- 参数:
pyfunc – 要向量化的函数。
excluded – 可选的整数集合,表示函数不会被向量化的位置参数。这些参数将直接传递给
pyfunc
,不做修改。signature – 可选的广义 universal function 签名,例如,
(m,n),(n)->(m)
用于向量化矩阵-向量乘法。如果提供,pyfunc
将被调用,并期望返回形状由相应核心维度大小给出的数组。默认情况下,pyfunc 假定接受标量数组作为输入和输出。
- 返回:
给定函数的向量化版本。
示例
以下是一些关于如何使用
vectorize()
编写向量化线性代数例程的示例>>> from functools import partial
>>> @partial(jnp.vectorize, signature='(k),(k)->(k)') ... def cross_product(a, b): ... assert a.shape == b.shape and a.ndim == b.ndim == 1 ... return jnp.array([a[1] * b[2] - a[2] * b[1], ... a[2] * b[0] - a[0] * b[2], ... a[0] * b[1] - a[1] * b[0]])
>>> @partial(jnp.vectorize, signature='(n,m),(m)->(n)') ... def matrix_vector_product(matrix, vector): ... assert matrix.ndim == 2 and matrix.shape[1:] == vector.shape ... return matrix @ vector
这些函数仅编写用于处理 1D 或 2D 数组(
assert
语句永远不会被违反),但通过 vectorize,它们支持具有 NumPy 风格广播的任意维度输入,例如:>>> cross_product(jnp.ones(3), jnp.ones(3)).shape (3,) >>> cross_product(jnp.ones((2, 3)), jnp.ones(3)).shape (2, 3) >>> cross_product(jnp.ones((1, 2, 3)), jnp.ones((2, 1, 3))).shape (2, 2, 3) >>> matrix_vector_product(jnp.ones(3), jnp.ones(3)) Traceback (most recent call last): ValueError: input with shape (3,) does not have enough dimensions for all core dimensions ('n', 'k') on vectorized function with excluded=frozenset() and signature='(n,k),(k)->(k)' >>> matrix_vector_product(jnp.ones((2, 3)), jnp.ones(3)).shape (2,) >>> matrix_vector_product(jnp.ones((2, 3)), jnp.ones((4, 3))).shape (4, 2)
请注意,这与 jnp.matmul 的语义不同
>>> jnp.matmul(jnp.ones((2, 3)), jnp.ones((4, 3))) Traceback (most recent call last): TypeError: dot_general requires contracting dimensions to have the same shape, got [3] and [4].