jax.default_matmul_precision#

jax.default_matmul_precision = <jax._src.config.State object>#

用于 jax_default_matmul_precision 配置选项的上下文管理器。

控制 32 位输入的默认矩阵乘法和卷积精度。

某些平台(如 TPU)为矩阵乘法和卷积计算提供可配置的精度级别,以精度换取速度。可以为每个操作控制精度;例如,请参阅 jax.lax.conv_general_dilated()jax.lax.dot() 文档字符串。但控制在未给定特定精度时操作获得的默认行为可能很有用。

此选项可用于控制 32 位输入上矩阵乘法和卷积中涉及的计算的默认精度级别。这些级别大致描述了计算标量乘积的精度。“bfloat16” 选项是最快且精度最低的;“float32” 类似于完整 float32 精度;“tensorfloat32” 是中间级别。

此参数还可用于为执行矩阵乘法的函数(如 jax.lax.dot())指定累积“算法”。要指定算法,请将此选项设置为 DotAlgorithmPreset 的名称。

参数:

new_val (Any)