公共API:jax 包#
子包#
jax.numpy模块jax.scipy模块jax.lax模块jax.random模块jax.sharding模块jax.debug模块jax.dlpack模块jax.distributed模块jax.dtypes模块jax.ffi模块jax.flatten_util模块jax.image模块jax.nn模块jax.ops模块jax.profiler模块jax.ref模块jax.stages模块jax.test_util模块jax.tree模块jax.tree_util模块jax.typing模块jax.export模块jax.extend模块jax.example_libraries模块jax.experimental模块
配置#
用于 jax_check_tracer_leaks 配置选项的上下文管理器。 |
|
用于 jax_check_tracer_leaks 配置选项的上下文管理器。 |
|
用于jax_debug_nans配置选项的上下文管理器。 |
|
用于jax_debug_infs配置选项的上下文管理器。 |
|
用于jax_default_device配置选项的上下文管理器。 |
|
用于jax_default_matmul_precision配置选项的上下文管理器。 |
|
用于jax_default_prng_impl配置选项的上下文管理器。 |
|
用于jax_enable_checks配置选项的上下文管理器。 |
|
用于jax_enable_custom_prng配置选项的上下文管理器(瞬时)。 |
|
用于jax_enable_custom_vjp_by_custom_transpose配置选项的上下文管理器(瞬时)。 |
|
用于jax_enable_x64配置选项的上下文管理器。 |
|
用于jax_log_compiles配置选项的上下文管理器。 |
|
用于jax_no_tracing配置选项的上下文管理器。 |
|
用于jax_numpy_rank_promotion配置选项的上下文管理器。 |
|
|
一个上下文管理器,用于控制所有传输的传输保护级别。 |
即时编译(jit)#
|
为XLA的即时编译设置 |
|
禁用其动态上下文下 |
确保在跟踪/编译时进行评估(或出错)的上下文管理器。 |
|
|
给定示例参数,创建一个返回 |
|
在不进行任何FLOPs的情况下计算 |
|
一个用于存储数组的形状、dtype和其他静态属性的容器。 |
|
将 |
|
将 |
返回默认XLA后端平台的名称。 |
|
|
在暂存JAX计算时,向函数添加用户指定的名称。 |
|
一个上下文管理器,将用户指定的名称添加到JAX名称堆栈中。 |
尝试调用pytree叶子上的 |
|
尝试调用pytree叶子上的 |
|
|
使用指定的形状和轴名称创建高效的网格。 |
自动微分#
|
创建一个评估 |
|
创建一个同时评估 |
|
别名为 |
|
使用前向模式AD逐列计算的 |
|
使用反向模式AD逐行计算的 |
|
作为密集数组的 |
|
计算 |
使用 |
|
|
转置一个保证是线性的函数。 |
|
计算 |
|
用于定义自定义VJP规则(也称为自定义梯度)的便利函数。 |
|
闭包转换实用工具,用于高阶自定义导数。 |
|
使 |
矢量化#
|
矢量化映射。 |
|
定义一个支持广播的向量化函数。 |
并行化#
|
使用设备网格对数据分片进行映射。 |
|
单轴分片映射,一次映射一个轴的函数f。 |
|
支持集体操作的并行映射。 |
|
返回给定后端的所有设备列表。 |
|
与 |
|
返回此进程的整数进程索引。 |
|
返回设备总数。 |
|
返回此进程可寻址的设备数量。 |
|
返回与后端关联的JAX进程数量。 |
|
返回与后端关联的所有JAX进程索引列表。 |
定制#
custom_jvp#
|
为自定义JVP规则定义设置一个JAX可转换函数。 |
|
为此实例表示的函数定义自定义 JVP 规则。 |
|
方便地为每个参数分别定义 JVP 的封装器。 |
custom_vjp#
|
为自定义VJP规则定义设置一个JAX可转换函数。 |
|
为此实例表示的函数定义自定义 VJP 规则。 |
custom_batching#
自定义JAX可转换函数的vmap行为。 |
|
|
为此 `custom_vmap` 函数定义 `vmap` 规则。 |
使用循环的 |
jax.Array(jax.Array)#
|
JAX的数组基类 |
|
通过从 |
|
从每个设备上的 |
|
使用进程中可用的数据创建分布式张量。 |
数组属性和方法#
可寻址分片列表。 |
|
|
测试沿给定轴的所有数组元素是否评估为 True。 |
|
测试给定轴上的任何数组元素是否评估为 True。 |
|
返回最大值的索引。 |
|
返回最小值的索引。 |
|
返回部分排序数组的索引。 |
|
返回对数组进行排序的索引。 |
|
复制数组并转换为指定的数据类型。 |
用于索引更新功能的辅助属性。 |
|
|
从多个数组的元素中选择构造一个数组。 |
|
返回一个值被限制在指定范围内的数组。 |
|
沿给定轴返回此数组的选定切片。 |
数组是否已提交。 |
|
返回数组的复共轭。 |
|
返回数组的复共轭。 |
|
返回数组的副本。 |
|
将 |
|
|
返回数组的累积乘积。 |
|
返回数组的累积和。 |
兼容 Array API 的设备属性。 |
|
|
从数组中返回指定的对角线。 |
|
计算两个数组的点积。 |
数组的数据类型( |
|
请使用 |
|
|
将数组展平为一维形状。 |
全局分片列表。 |
|
返回数组的虚部。 |
|
此 Array 是否完全可寻址? |
|
此 Array 是否完全复制? |
|
|
将数组的一个元素复制到标准 Python 标量并返回它。 |
一个数组元素的字节长度。 |
|
|
返回给定轴上数组元素的最大值。 |
|
返回给定轴上数组元素的平均值。 |
|
返回给定轴上数组元素的最小值。 |
数组元素占用的总字节数。 |
|
数组的维度数。 |
|
|
返回数组中非零元素的索引。 |
|
返回给定轴上数组元素的乘积。 |
|
返回沿给定轴的峰峰值范围。 |
|
将数组展平为一维形状。 |
返回数组的实部。 |
|
|
从重复元素构造数组。 |
|
返回一个包含相同数据但具有新形状的数组。 |
|
将数组元素四舍五入到给定的小数位数。 |
|
在已排序数组中执行二分查找。 |
数组的形状。 |
|
此数组的分片信息。 |
|
数组中的元素总数。 |
|
|
返回数组的排序副本。 |
|
从数组中移除一个或多个长度为 1 的轴。 |
|
计算沿给定轴的标准差。 |
|
在给定轴上对数组元素求和。 |
|
交换数组的两个轴。 |
|
从数组中选取元素。 |
|
在指定设备上返回数组的副本 |
|
返回沿对角线的和。 |
|
返回数组的一个副本,其中轴已转置。 |
|
计算给定轴上的方差。 |
|
返回数组的按位副本,并将其视为新的数据类型。 |
计算所有轴的数组转置。 |
|
计算(批处理)矩阵转置。 |
回调#
|
调用纯Python回调。 |
|
调用非纯Python回调。 |
|
调用一个可分阶段的 Python 回调。 |
|
打印值并在分阶段的 JAX 函数中工作。 |
杂项#
可用设备的描述符。 |
|
|
返回一个包含本地环境和JAX安装信息的字符串。 |
|
返回platform后端的所有活动数组。 |
清除所有编译和暂存缓存。 |
检查点策略#
默认策略,如同未使用 |
|
重新计算所有内容,如同未使用自定义策略一样。 |
|
|
|
|
|
这是Transformer的一个有用启发式方法。 |
|
这是Transformer的一个有用启发式方法。 |
|
仅保存命名值,即checkpoint_name的任何输出,排除给定的名称。 |
|
仅保存命名值,并且仅限于给定的名称。 |
|
与 |
|
|
与 |
给定策略的逻辑OR。 |