jax.ShapeDtypeStruct#

class jax.ShapeDtypeStruct(shape, dtype, *, sharding=None, weak_type=False, vma=None, is_ref=False)[来源]#

一个用于存放数组的 shape、dtype 和其他静态属性的容器。

ShapeDtypeStruct 通常与 jax.eval_shape() 结合使用。

参数:
  • shape – 一个表示数组形状的整数序列

  • dtype – 一个类似 dtype 的对象

  • sharding – (可选) 一个 jax.Sharding 对象

__init__(shape, dtype, *, sharding=None, weak_type=False, vma=None, is_ref=False)[来源]#

方法

__init__(shape, dtype, *[, sharding, ...])

update(**kwargs)

属性

shape

dtype

weak_type

vma

is_ref

format

ndim

sharding

size