资源和高级指南# 本节包含有关更高级主题的示例和教程,例如多核计算、自动微分和自定义操作。 并行计算 分布式数组和自动并行化 显式分片(亦称“类型内分片”) 使用 shard_map 进行手动并行化 设备本地数组布局控制 JAX 内存和主机卸载 优化器状态卸载 多控制器 JAX 简介(即多进程/多主机 JAX) 分布式数据加载 同地 Python 机器学习 训练食谱 自动微分 自动微分秘籍 自定义导数规则 使用 jax.checkpoint(亦称 jax.remat)控制自动微分的保存值 高级自动微分 错误和调试 错误 调试简介 调试运行时值 JAX 调试标志 传输守卫 Pytrees Pytrees 性能优化 持久化编译缓存 GPU 性能提示 性能基准测试和剖析 计算性能分析 设备内存性能分析 非函数式编程 Ref:用于数据管道和内存控制的可变数组 外部回调 外部回调 FFI 外部函数接口 (FFI) 建模工作流 使用 jax.checkpoint (jax.remat) 进行梯度检查点 提前降低和编译 导出和序列化 Pallas Pallas:一种 JAX 内核语言 示例应用程序 使用 tensorflow/datasets 数据加载训练简单神经网络 使用 PyTorch 数据加载训练简单神经网络 贝叶斯推理的自动批处理 深入研究 JAX 中的广义卷积 XLA 编译器标志 JAX 内部:原语 JAX 内部:jaxpr 语言