jax.lax.convert_element_type#

jax.lax.convert_element_type(operand, new_dtype)[source]#

逐元素转换。

此函数直接降级到 stablehlo.convert 操作,该操作执行从一种类型到另一种类型的逐元素转换,类似于 C++ static_cast

参数:
  • operand (ArrayLike) – 要转换的数组或标量值。

  • new_dtype (DTypeLike | dtypes.ExtendedDType) – 一个类似 dtype 的对象(例如,numpy.dtype、标量类型或有效的 dtype 名称),表示目标 dtype。

返回:

operand 具有相同形状的数组,逐元素转换为 new_dtype

返回类型:

Array

注意

如果 new_dtype 是 64 位类型并且未启用 x64 模式,则将使用适当的 32 位类型代替。

如果输入是 JAX 数组并且输入 dtype 和输出 dtype 匹配,则将返回未经修改的输入数组。

另请参阅