jax.default_matmul_precision#

jax.default_matmul_precision = <jax._src.config.State object>#

jax_default_matmul_precision 配置选项的上下文管理器。

控制32位输入的默认矩阵乘法和卷积精度。

某些平台(例如TPU)为矩阵乘法和卷积计算提供可配置的精度级别,以速度换取精度。每个操作的精度都可以控制;例如,请参阅jax.lax.conv_general_dilated()jax.lax.dot()的文档字符串。但是,当未指定操作的精度时,控制其默认行为会很有用。

此选项可用于控制涉及32位输入矩阵乘法和卷积计算的默认精度级别。这些级别大致描述了计算标量积时的精度。“bfloat16”选项最快且精度最低;“float32”类似于完整的float32精度;“tensorfloat32”介于两者之间。

此参数还可用于为执行矩阵乘法的函数(例如jax.lax.dot())指定累加“算法”。要指定算法,请将此选项设置为DotAlgorithmPreset的名称。

参数:

new_val (任意类型)