jax.lax.convert_element_type#

jax.lax.convert_element_type(operand, new_dtype)[source]#

逐元素类型转换。

此函数直接降低为 stablehlo.convert 操作,该操作执行从一种类型到另一种类型的逐元素转换,类似于 C++ static_cast

参数:
  • operand (ArrayLike) – 要转换的数组或标量值。

  • new_dtype (DTypeLike | dtypes.ExtendedDType) – 类似 dtype 的对象(例如 numpy.dtype、标量类型或表示目标 dtype 的有效 dtype 名称)。

返回值:

operand 形状相同,逐元素转换为 new_dtype 的数组。

返回类型:

Array

注意

如果 new_dtype 是 64 位类型且未启用 x64 模式,则将使用适当的 32 位类型代替。

如果输入是 JAX 数组,并且输入 dtype 和输出 dtype 匹配,则将返回未修改的输入数组。

另请参阅