jax.numpy.astype#
- jax.numpy.astype(x, dtype, /, *, copy=False, device=None)[源代码]#
将数组转换为指定的 dtype。
JAX 对
numpy.astype()
的实现。这是通过
jax.lax.convert_element_type()
实现的,在某些情况下,其行为可能与numpy.astype()
略有不同。特别是,浮点数到整数和整数到浮点数的转换细节取决于具体的实现。- 参数:
- 返回:
一个与
x
形状相同的数组,包含指定 dtype 的值。- 返回类型:
另请参阅
jax.lax.convert_element_type()
:用于 XLA 样式 dtype 转换的较低级别函数。
示例
>>> x = jnp.array([0, 1, 2, 3]) >>> x Array([0, 1, 2, 3], dtype=int32) >>> x.astype('float32') Array([0.0, 1.0, 2.0, 3.0], dtype=float32)
>>> y = jnp.array([0.0, 0.5, 1.0]) >>> y.astype(int) # truncates fractional values Array([0, 0, 1], dtype=int32)