jax.lax.axis_size#
- jax.lax.axis_size(axis_name)[source]#
返回映射轴
axis_name
的大小。- 参数:
axis_name (AxisName) – 哈希 Python 对象,用于命名映射轴。
- 返回:
表示大小的整数。
- 返回类型:
例如,在有 8 个 XLA 设备可用的情况下
>>> from functools import partial >>> from jax.experimental.shard_map import shard_map >>> from jax.sharding import PartitionSpec as P >>> mesh = jax.make_mesh((8,), 'i') >>> @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P()) ... def f(_): ... return lax.axis_size('i') ... >>> f(jnp.zeros(16)) Array(8, dtype=int32, weak_type=True) >>> mesh = jax.make_mesh((4, 2), ('i', 'j')) >>> @partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P()) ... def f(_): ... return lax.axis_size(('i', 'j')) ... >>> f(jnp.zeros((16, 8))) Array(8, dtype=int32, weak_type=True)