jax.lax.bitcast_convert_type#
- jax.lax.bitcast_convert_type(operand, new_dtype)[源代码]#
逐元素的位转换。
此函数直接降级到 stablehlo.bitcast_convert 操作。
输出形状取决于输入和输出 dtypes 的大小,逻辑如下
if new_dtype.itemsize == operand.dtype.itemsize: output_shape = operand.shape if new_dtype.itemsize < operand.dtype.itemsize: output_shape = (*operand.shape, operand.dtype.itemsize // new_dtype.itemsize) if new_dtype.itemsize > operand.dtype.itemsize: assert operand.shape[-1] * operand.dtype.itemsize == new_dtype.itemsize output_shape = operand.shape[:-1]
- 参数:
operand (ArrayLike) – 要转换的数组或标量值
new_dtype (DTypeLike) – 新类型。 应该是 NumPy 类型。
- 返回:
形状为 output_shape(见上文)且类型为 new_dtype 的数组,由与 operand 相同的位构成。
- 返回类型:
另请参阅
jax.lax.convert_element_type()
:值保持 dtype 转换。jax.Array.view()
:用于位转换类型转换的 NumPy 风格的 API。