jax.lax.axis_size#

jax.lax.axis_size(axis_name)[source]#

返回映射轴 axis_name 的大小。

参数:

axis_name (AxisName) – 哈希 Python 对象,用于命名映射轴。

返回:

表示大小的整数。

返回类型:

int

例如,在有 8 个 XLA 设备可用的情况下

>>> from functools import partial
>>> from jax.experimental.shard_map import shard_map
>>> from jax.sharding import PartitionSpec as P
>>> mesh = jax.make_mesh((8,), 'i')
>>> @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P())
... def f(_):
...   return lax.axis_size('i')
...
>>> f(jnp.zeros(16))
Array(8, dtype=int32, weak_type=True)
>>> mesh = jax.make_mesh((4, 2), ('i', 'j'))
>>> @partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P())
... def f(_):
...   return lax.axis_size(('i', 'j'))
...
>>> f(jnp.zeros((16, 8)))
Array(8, dtype=int32, weak_type=True)