jax.profiler.StepTraceAnnotation#
- class jax.profiler.StepTraceAnnotation(name, **kwargs)[源代码]#
用于在 profiler 中生成步骤跟踪事件的上下文管理器。
步骤跟踪事件跨越了上下文中包含的代码的持续时间。Profiler 将为每个步骤跟踪事件提供性能分析。
例如,它可用于标记训练步骤,并使 profiler 能够按步骤提供性能分析。
>>> while global_step < NUM_STEPS: ... with jax.profiler.StepTraceAnnotation("train", step_num=global_step): ... train_step() ... global_step += 1
如果进程正在被 TensorBoard 跟踪,这将导致“train xx”事件显示在跟踪时间线上。此外,如果使用加速器,设备跟踪时间线也将显示“train xx”事件。请注意,“step_num”可以作为关键字参数设置,将全局步数传递给 profiler。
- 参数:
name (str)
方法
__init__(self, arg0, /, **kwargs)属性
is_enabledset_metadata