jax.ShapeDtypeStruct#

jax.ShapeDtypeStruct(shape, dtype, *, sharding=None, weak_type=False, vma=None, is_ref=False)[源代码]#

一个用于存储数组的形状、数据类型和其他静态属性的容器。

ShapeDtypeStruct 通常与 jax.eval_shape() 结合使用。

参数:
  • shape – 表示数组形状的整数序列

  • dtype – 类似数据类型的对象

  • sharding – (可选)一个 jax.Sharding 对象

__init__(shape, dtype, *, sharding=None, weak_type=False, vma=None, is_ref=False)[源代码]#

方法

__init__(shape, dtype, *[, sharding, ...])

update(**kwargs)

属性

shape

dtype

sharding

weak_type

vma

is_ref

format

ndim

size