jax.profiler 模块#

跟踪和时间分析#

计算分析 介绍了如何利用 JAX 的跟踪和时间分析功能。

start_server(port)

在端口 port 上启动性能分析器服务器。

start_trace(log_dir[, create_perfetto_link, ...])

启动性能分析器跟踪。

stop_trace()

停止当前正在运行的性能分析器跟踪。

trace(log_dir[, create_perfetto_link, ...])

用于获取性能分析器跟踪的上下文管理器。

annotate_function(func[, name])

一个装饰器,用于为函数执行生成跟踪事件。

TraceAnnotation(*args, **kwargs)

一个上下文管理器,用于在性能分析器中生成跟踪事件。

StepTraceAnnotation(name, **kwargs)

一个上下文管理器,用于在性能分析器中生成步骤跟踪事件。

设备内存分析#

请参阅 设备内存分析,了解 JAX 设备内存分析功能的介绍。

device_memory_profile([backend])

将 JAX 设备内存配置文件捕获为 pprof-格式协议缓冲区。

save_device_memory_profile(filename[, backend])

收集设备内存配置文件并将其写入文件。