jax.numpy.apply_along_axis#

jax.numpy.apply_along_axis(func1d, axis, arr, *args, **kwargs)[源]#

沿某个轴将函数应用于一维数组切片。

numpy.apply_along_axis() 的 JAX 实现。尽管 NumPy 是迭代实现此功能的,但 JAX 是通过 jax.vmap() 实现此功能的,因此 func1d 必须与 vmap 兼容。

参数:
  • func1d (可调用对象) – 一个可调用函数,其签名是 func1d(arr, /, *args, **kwargs),其中 *args**kwargs 是传递给 apply_along_axis() 的额外位置参数和关键字参数。

  • axis (整型) – 沿其应用函数的整数轴。

  • arr (类数组) – 要应用函数的数组。

  • args – 额外的​位置参数和关键字参数会传递给 func1d

  • kwargs – 额外的​位置参数和关键字参数会传递给 func1d

返回:

沿指定轴应用 `func1d` 的结果。

返回类型:

数组

另请参阅

示例

一个二维的简单示例,函数按行或按列应用

>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> def func1d(x):
...   return jnp.sum(x ** 2)
>>> jnp.apply_along_axis(func1d, 0, x)
Array([17, 29, 45], dtype=int32)
>>> jnp.apply_along_axis(func1d, 1, x)
Array([14, 77], dtype=int32)

对于二维输入,这可以通过使用 jax.vmap() 等效地表达,但请注意,`vmap` 指定的是映射轴而不是应用轴

>>> jax.vmap(func1d, in_axes=1)(x)  # same as applying along axis 0
Array([17, 29, 45], dtype=int32)
>>> jax.vmap(func1d, in_axes=0)(x)  # same as applying along axis 1
Array([14, 77], dtype=int32)

对于三维输入,apply_along_axis() 等同于对两个维度进行映射

>>> x_3d = jnp.arange(24).reshape(2, 3, 4)
>>> jnp.apply_along_axis(func1d, 2, x_3d)
Array([[  14,  126,  366],
       [ 734, 1230, 1854]], dtype=int32)
>>> jax.vmap(jax.vmap(func1d))(x_3d)
Array([[  14,  126,  366],
       [ 734, 1230, 1854]], dtype=int32)

应用的函数还可以接受任意位置参数或关键字参数,这些参数应作为额外参数直接传递给 apply_along_axis()

>>> def func1d(x, exponent):
...   return jnp.sum(x ** exponent)
>>> jnp.apply_along_axis(func1d, 0, x, exponent=3)
Array([ 65, 133, 243], dtype=int32)