jax.lax.convert_element_type#
- jax.lax.convert_element_type(operand, new_dtype)[source]#
逐元素类型转换。
此函数直接降低为 stablehlo.convert 操作,该操作执行从一种类型到另一种类型的逐元素转换,类似于 C++
static_cast
。- 参数:
operand (ArrayLike) – 要转换的数组或标量值。
new_dtype (DTypeLike | dtypes.ExtendedDType) – 类似 dtype 的对象(例如
numpy.dtype
、标量类型或表示目标 dtype 的有效 dtype 名称)。
- 返回值:
与
operand
形状相同,逐元素转换为new_dtype
的数组。- 返回类型:
注意
如果
new_dtype
是 64 位类型且未启用 x64 模式,则将使用适当的 32 位类型代替。如果输入是 JAX 数组,并且输入 dtype 和输出 dtype 匹配,则将返回未修改的输入数组。
另请参阅
jax.numpy.astype()
: NumPy 风格的 dtype 转换 API。jax.Array.astype()
: 作为数组方法的 dtype 转换。jax.lax.bitcast_convert_type()
: 将位直接转换为新的 dtype。