jax.linear_transpose#

jax.linear_transpose(fun, *primals, reduce_axes=())[source]#

对一个保证是线性的函数进行转置。

对于线性函数,此变换等同于 vjp(),但避免了计算正向传播的开销。

转置函数的输出将始终具有与 primals 完全相同的数据类型(dtypes),即使某些值被截断(例如,从复数到浮点数,或从 float64 到 float32)。为避免截断,请在 primals 中使用与转置函数所需输出的完整范围相匹配的数据类型。不支持整数数据类型。

参数:
  • fun (可调用对象) – 要转置的线性函数。

  • *primals – 用于评估 fun(*primals) 的形状/数据类型的,由数组、标量或(嵌套的)标准 Python 容器(元组、列表、字典、命名元组,即 pytrees)组成的定位参数元组。这些参数可以是实际的标量/N维数组,但并非必需:只访问 shapedtype 属性。请参阅下面的示例。(请注意,鸭子类型对象不能是命名元组,因为它们被视为标准 Python 容器。)

返回:

一个计算 fun 转置的可调用对象。此函数的有效输入必须具有与 fun(*primals) 结果相同的形状/数据类型/结构。输出将是一个元组,其形状/数据类型/结构与 primals 相同。

返回类型:

可调用对象

>>> import jax
>>>
>>> f = lambda x, y: 0.5 * x - 0.5 * y
>>> scalar = jax.ShapeDtypeStruct(shape=(), dtype=np.dtype(np.float32))
>>> f_transpose = jax.linear_transpose(f, scalar, scalar)
>>> f_transpose(1.0)
(Array(0.5, dtype=float32), Array(-0.5, dtype=float32))