jax.lax.convert_element_type#
- jax.lax.convert_element_type(operand, new_dtype)[source]#
逐元素转换。
此函数直接降级到 stablehlo.convert 操作,该操作执行从一种类型到另一种类型的逐元素转换,类似于 C++
static_cast
。- 参数:
operand (ArrayLike) – 要转换的数组或标量值。
new_dtype (DTypeLike | dtypes.ExtendedDType) – 一个类似 dtype 的对象(例如,
numpy.dtype
、标量类型或有效的 dtype 名称),表示目标 dtype。
- 返回:
与
operand
具有相同形状的数组,逐元素转换为new_dtype
。- 返回类型:
注意
如果
new_dtype
是 64 位类型并且未启用 x64 模式,则将使用适当的 32 位类型代替。如果输入是 JAX 数组并且输入 dtype 和输出 dtype 匹配,则将返回未经修改的输入数组。
另请参阅
jax.numpy.astype()
: NumPy 风格的 dtype 转换 API。jax.Array.astype()
: 作为数组方法的 dtype 转换。jax.lax.bitcast_convert_type()
: 将位直接转换为新的 dtype。