jax.hessian#

jax.hessian(fun, argnums=0, has_aux=False, holomorphic=False)[source]#

fun 函数的海森矩阵(以稠密数组形式)。

参数:
  • fun (Callable) – 要计算其海森矩阵的函数。其由 argnums 指定位置的参数应为数组、标量或它们的标准 Python 容器。它应返回数组、标量或它们的标准 Python 容器。

  • argnums (int | Sequence[int]) – 可选参数,整数或整数序列。指定要对其进行微分的位置参数(默认为 0)。

  • has_aux (bool) – 可选参数,布尔值。指示 fun 是否返回一对值,其中第一个元素被认为是待微分数学函数的输出,第二个元素是辅助数据。默认为 False。

  • holomorphic (bool) – 可选参数,布尔值。指示 fun 是否承诺是全纯函数。默认为 False。

返回:

一个与 fun 具有相同参数的函数,用于计算 fun 的海森矩阵。

返回类型:

可调用对象

>>> import jax
>>>
>>> g = lambda x: x[0]**3 - 2*x[0]*x[1] - x[1]**6
>>> print(jax.hessian(g)(jax.numpy.array([1., 2.])))
[[   6.   -2.]
 [  -2. -480.]]

hessian() 是海森矩阵常见定义的一种推广,它支持嵌套的 Python 容器(即 pytree)作为输入和输出。jax.hessian(fun)(x) 的树状结构是通过将 fun(x) 的结构与 x 结构的两个副本的树状乘积结合而形成的。两个树状结构的树状乘积是通过将第一个树的每个叶子替换为第二个树的副本而形成的。例如

>>> import jax.numpy as jnp
>>> f = lambda dct: {"c": jnp.power(dct["a"], dct["b"])}
>>> print(jax.hessian(f)({"a": jnp.arange(2.) + 1., "b": jnp.arange(2.) + 2.}))
{'c': {'a': {'a': Array([[[ 2.,  0.], [ 0.,  0.]],
                         [[ 0.,  0.], [ 0., 12.]]], dtype=float32),
             'b': Array([[[ 1.      ,  0.      ], [ 0.      ,  0.      ]],
                         [[ 0.      ,  0.      ], [ 0.      , 12.317766]]], dtype=float32)},
       'b': {'a': Array([[[ 1.      ,  0.      ], [ 0.      ,  0.      ]],
                         [[ 0.      ,  0.      ], [ 0.      , 12.317766]]], dtype=float32),
             'b': Array([[[0.      , 0.      ], [0.      , 0.      ]],
                         [[0.      , 0.      ], [0.      , 3.843624]]], dtype=float32)}}}

因此,jax.hessian(fun)(x) 树状结构中的每个叶子都对应于 fun(x) 的一个叶子以及 x 的一对叶子。对于 jax.hessian(fun)(x) 中的每个叶子,如果 fun(x) 对应的数组叶子形状为 (out_1, out_2, ...),并且 x 对应的数组叶子形状分别为 (in_1_1, in_1_2, ...)(in_2_1, in_2_2, ...),那么海森矩阵的叶子形状为 (out_1, out_2, ..., in_1_1, in_1_2, ..., in_2_1, in_2_2, ...)。换句话说,Python 树状结构表示海森矩阵的块结构,其中块由输入和输出 pytree 确定。

特别是,当函数输入 x 和输出 fun(x) 各自是单个数组时(如上例 g 所示),将生成一个数组(不涉及任何 pytree)。如果 fun(x) 的形状为 (out1, out2, ...),并且 x 的形状为 (in1, in2, ...),那么 jax.hessian(fun)(x) 的形状为 (out1, out2, ..., in1, in2, ..., in1, in2, ...)。要将 pytree 展平为一维向量,可以考虑使用 jax.flatten_util.flatten_pytree()