jax.ShapeDtypeStruct#
- class jax.ShapeDtypeStruct(shape, dtype, *, sharding=None, weak_type=False, vma=None, is_ref=False)[来源]#
一个用于存放数组的 shape、dtype 和其他静态属性的容器。
ShapeDtypeStruct
通常与jax.eval_shape()
结合使用。- 参数:
shape – 一个表示数组形状的整数序列
dtype – 一个类似 dtype 的对象
sharding – (可选) 一个
jax.Sharding
对象
方法
__init__
(shape, dtype, *[, sharding, ...])update
(**kwargs)属性
shape
dtype
weak_type
vma
is_ref
format
ndim
sharding
size