jax.Array.view#
- abstract Array.view(dtype=None, type=None)[源]#
返回数组的按位副本,并将其视为新的数据类型。
这是
jax.lax.bitcast_convert_type()
的功能更丰富的包装器。如果源和目标数据类型具有相同的位宽,则结果与输入数组具有相同的形状。如果目标数据类型的位宽与源不同,则结果的最后一个轴的大小将相应调整。
>>> jnp.zeros([1,2,3], dtype=jnp.int16).view(jnp.int8).shape (1, 2, 6) >>> jnp.zeros([1,2,4], dtype=jnp.int8).view(jnp.int16).shape (1, 2, 2)
涉及布尔值的转换并非在所有情况下都明确定义。关于如上所述的结果形状,布尔值被视为具有 8 位宽。然而,在转换为布尔数组时,输入应只包含 0 或 1 字节。否则,结果可能不可预测,或者可能根据结果的使用方式而改变。
此转换是保证且安全的
>>> jnp.array([1, 0, 1], dtype=jnp.int8).view(jnp.bool_) Array([ True, False, True], dtype=bool)
然而,对于涉及此类视图的任何表达式(例如:jnp.array([1, 2, 3], dtype=jnp.int8).view(jnp.bool_))的结果,不作任何保证。特别是,结果可能会在 JAX 版本之间以及根据平台而改变。要安全地将此类数组转换为布尔数组,请将其与 0 进行比较
>>> jnp.array([1, 2, 0], dtype=jnp.int8) != 0 Array([ True, True, False], dtype=bool)