跳到主要内容
Ctrl+K
JAX  documentation - Home

开始使用

  • 安装
  • 快速入门
  • 教程
    • 快速入门
    • 核心概念
    • 即时编译
    • 自动向量化
    • 自动微分
    • 调试入门
    • 伪随机数
    • 使用 pytree
    • 并行编程入门
    • 有状态计算
    • 使用 JIT 的控制流和逻辑运算符
    • 高级自动微分
    • 外部回调
    • 使用 jax.checkpoint (jax.remat) 进行梯度检查点设置
    • JAX 内部原理:原语
    • JAX 内部原理:jaxpr 语言
  • 🔪 JAX - 锋利之处 🔪
  • 常见问题 (FAQ)

更多指南/资源

  • 用户指南
    • 如何在 JAX 中思考
    • 性能分析计算
    • 性能分析设备内存
    • 调试运行时值
      • 编译后的打印和断点
      • checkify 转换
      • JAX 调试标志
    • GPU 性能提示
    • 持久编译缓存
    • Pytree
    • 错误
    • 提前降低和编译
    • 导出和序列化
      • 导出和序列化分段计算
      • 形状多态性
      • 与 TensorFlow 互操作
    • 传输保护
    • Pallas:JAX 内核语言
      • Pallas 快速入门
      • 软件流水线
      • 网格和 BlockSpec
      • Pallas TPU
        • 使用 Pallas 编写 TPU 内核
        • 流水线
        • 矩阵乘法
        • 标量预取和块稀疏计算
        • Pallas 中面向 TPU 的分布式计算
      • Pallas:Mosaic GPU
        • 使用 Pallas 编写 Mosaic GPU 内核
      • Pallas 设计说明
        • Pallas 设计
        • Pallas 异步操作
      • Pallas 更新日志
    • 外部函数接口 (FFI)
    • 训练一个简单的神经网络,使用 tensorflow/datasets 数据加载
    • 训练一个简单的神经网络,使用 PyTorch 数据加载
    • 贝叶斯推断的自动批处理
  • 高级指南
    • 分布式数组和自动并行化
    • 显式分片(又名“类型分片”)
    • 使用 shard_map 的手动并行
    • 多主机和多进程环境
    • 分布式数据加载
    • 自动微分食谱
    • 自定义导数规则
    • 使用 jax.checkpoint (又名 jax.remat) 控制自动微分的保存值
    • JAX 中的广义卷积
    • XLA 编译器标志列表
  • 开发者 notes
    • 为 JAX 做贡献
    • 从源码构建
    • 调查回归
    • Autodidax:从零开始的 JAX 核心
    • Autodidax2,第 1 部分:再次从零开始的 JAX
    • JAX 增强提案 (JEP)
      • 263: JAX PRNG 设计
      • 2026: JAX 可转换函数的自定义 JVP/VJP 规则
      • 4008: 自定义 VJP 和 `nondiff_argnums` 更新
      • 4410: 全局暂存
      • 9263: 类型化键和可插拔 RNG
      • 9407: JAX 类型提升语义设计
      • 9419: Jax 和 Jaxlib 版本控制
      • 10657: JAX 中的序列副作用
      • 11830: jax.remat / jax.checkpoint 新实现
      • 12049: JAX 类型注解路线图
      • 14273: shard_map (shmap) 用于简单的按设备代码
      • 15856: jax.extend,一个扩展模块
      • 17111: shard_map(和其他映射)的有效转置
      • 18137: JAX NumPy 和 SciPy 包装器的范围
      • 25516: 基于努力的版本控制
  • 扩展指南
    • 在 JAX 中编写自定义 Jaxpr 解释器
    • jax.extend 模块
      • jax.extend.core 模块
      • jax.extend.linear_util 模块
      • jax.extend.mlir 模块
      • jax.extend.random 模块
    • 构建于 JAX 之上
  • 注释
    • API 兼容性
    • Python 和 NumPy 版本支持策略
    • 异步调度
    • 并发
    • GPU 内存分配
    • 秩提升警告
    • 类型提升语义
    • 默认数据类型和 X64 标志
  • 公共 API:jax 包
    • jax.numpy 模块
      • jax.numpy.fft.fft
      • jax.numpy.fft.fft2
      • jax.numpy.fft.fftfreq
      • jax.numpy.fft.fftn
      • jax.numpy.fft.fftshift
      • jax.numpy.fft.hfft
      • jax.numpy.fft.ifft
      • jax.numpy.fft.ifft2
      • jax.numpy.fft.ifftn
      • jax.numpy.fft.ifftshift
      • jax.numpy.fft.ihfft
      • jax.numpy.fft.irfft
      • jax.numpy.fft.irfft2
      • jax.numpy.fft.irfftn
      • jax.numpy.fft.rfft
      • jax.numpy.fft.rfft2
      • jax.numpy.fft.rfftfreq
      • jax.numpy.fft.rfftn
    • jax.scipy 模块
      • jax.scipy.stats.bernoulli.logpmf
      • jax.scipy.stats.bernoulli.pmf
      • jax.scipy.stats.bernoulli.cdf
      • jax.scipy.stats.bernoulli.ppf
    • jax.lax 模块
    • jax.random 模块
    • jax.sharding 模块
    • jax.debug 模块
    • jax.dlpack 模块
    • jax.distributed 模块
    • jax.dtypes 模块
    • jax.ffi 模块
    • jax.flatten_util 模块
    • jax.image 模块
    • jax.nn 模块
      • jax.nn.initializers 模块
    • jax.ops 模块
    • jax.profiler 模块
    • jax.stages 模块
    • jax.test_util 模块
    • jax.tree 模块
    • jax.tree_util 模块
    • jax.typing 模块
    • jax.export 模块
    • jax.extend 模块
      • jax.extend.core 模块
      • jax.extend.linear_util 模块
      • jax.extend.mlir 模块
      • jax.extend.random 模块
    • jax.example_libraries 模块
      • jax.example_libraries.optimizers 模块
      • jax.example_libraries.stax 模块
    • jax.experimental 模块
      • jax.experimental.checkify 模块
      • jax.experimental.compilation_cache 模块
      • jax.experimental.custom_dce 模块
      • jax.experimental.custom_partitioning 模块
      • jax.experimental.jet 模块
      • jax.experimental.key_reuse 模块
      • jax.experimental.mesh_utils 模块
      • jax.experimental.multihost_utils 模块
      • jax.experimental.pallas 模块
        • jax.experimental.pallas.mosaic_gpu 模块
        • jax.experimental.pallas.triton 模块
        • jax.experimental.pallas.tpu 模块
      • jax.experimental.pjit 模块
      • jax.experimental.serialize_executable 模块
      • jax.experimental.shard_map 模块
      • jax.experimental.sparse 模块
        • jax.experimental.sparse.BCOO
        • jax.experimental.sparse.bcoo_broadcast_in_dim
        • jax.experimental.sparse.bcoo_concatenate
        • jax.experimental.sparse.bcoo_dot_general
        • jax.experimental.sparse.bcoo_dot_general_sampled
        • jax.experimental.sparse.bcoo_dynamic_slice
        • jax.experimental.sparse.bcoo_extract
        • jax.experimental.sparse.bcoo_fromdense
        • jax.experimental.sparse.bcoo_gather
        • jax.experimental.sparse.bcoo_multiply_dense
        • jax.experimental.sparse.bcoo_multiply_sparse
        • jax.experimental.sparse.bcoo_update_layout
        • jax.experimental.sparse.bcoo_reduce_sum
        • jax.experimental.sparse.bcoo_reshape
        • jax.experimental.sparse.bcoo_slice
        • jax.experimental.sparse.bcoo_sort_indices
        • jax.experimental.sparse.bcoo_squeeze
        • jax.experimental.sparse.bcoo_sum_duplicates
        • jax.experimental.sparse.bcoo_todense
        • jax.experimental.sparse.bcoo_transpose
    • jax.lib 模块
    • jax.Array.addressable_shards
    • jax.Array.all
    • jax.Array.any
    • jax.Array.argmax
    • jax.Array.argmin
    • jax.Array.argpartition
    • jax.Array.argsort
    • jax.Array.astype
    • jax.Array.at
    • jax.Array.choose
    • jax.Array.clip
    • jax.Array.compress
    • jax.Array.committed
    • jax.Array.conj
    • jax.Array.conjugate
    • jax.Array.copy
    • jax.Array.copy_to_host_async
    • jax.Array.cumprod
    • jax.Array.cumsum
    • jax.Array.device
    • jax.Array.diagonal
    • jax.Array.dot
    • jax.Array.dtype
    • jax.Array.flat
    • jax.Array.flatten
    • jax.Array.global_shards
    • jax.Array.imag
    • jax.Array.is_fully_addressable
    • jax.Array.is_fully_replicated
    • jax.Array.item
    • jax.Array.itemsize
    • jax.Array.max
    • jax.Array.mean
    • jax.Array.min
    • jax.Array.nbytes
    • jax.Array.ndim
    • jax.Array.nonzero
    • jax.Array.prod
    • jax.Array.ptp
    • jax.Array.ravel
    • jax.Array.real
    • jax.Array.repeat
    • jax.Array.reshape
    • jax.Array.round
    • jax.Array.searchsorted
    • jax.Array.shape
    • jax.Array.sharding
    • jax.Array.size
    • jax.Array.sort
    • jax.Array.squeeze
    • jax.Array.std
    • jax.Array.sum
    • jax.Array.swapaxes
    • jax.Array.take
    • jax.Array.to_device
    • jax.Array.trace
    • jax.Array.transpose
    • jax.Array.var
    • jax.Array.view
    • jax.Array.T
    • jax.Array.mT
  • 关于项目
  • 更新日志
  • 术语表
  • 配置选项
  • 扩展指南
  • jax.extend.core 模块
  • jax.extend.c...
  • .rst

jax.extend.core.Token

目录

  • Token
    • Token.__init__()

jax.extend.core.Token#

class jax.extend.core.Token(buf)[source]#
__init__(buf)[source]#

方法

__init__(buf)

block_until_ready()

previous

jax.extend.core.Primitive

next

jax.extend.core.Var

目录
  • Token
    • Token.__init__()

作者:JAX 作者

© 版权所有 2024, The JAX Authors.