jax.Array.view#

abstract Array.view(dtype=None, type=None)[源代码]#

返回数组的按位副本,并将其视为新的数据类型。

这是 jax.lax.bitcast_convert_type() 的一个功能更全面的封装。

如果源数据类型和目标数据类型的位数相同,则结果的形状与输入数组相同。如果目标数据类型的位数与源数据类型不同,则结果的最后一个轴的大小将相应调整。

>>> jnp.zeros([1,2,3], dtype=jnp.int16).view(jnp.int8).shape
(1, 2, 6)
>>> jnp.zeros([1,2,4], dtype=jnp.int8).view(jnp.int16).shape
(1, 2, 2)

涉及布尔值的转换在所有情况下都不明确。关于结果形状如上所述,布尔值被视为具有 8 位宽度。但是,在转换为布尔数组时,输入应只包含 0 或 1 字节。否则,结果可能不可预测或可能因结果的使用方式而异。

此转换是保证且安全的

>>> jnp.array([1, 0, 1], dtype=jnp.int8).view(jnp.bool_)
Array([ True, False,  True], dtype=bool)

但是,不能保证涉及此类视图的任何表达式的结果: jnp.array([1, 2, 3], dtype=jnp.int8).view(jnp.bool_)。特别是,结果可能在 JAX 版本之间以及根据平台而变化。要安全地将此类数组转换为布尔数组,请将其与 0 进行比较。

>>> jnp.array([1, 2, 0], dtype=jnp.int8) != 0
Array([ True,  True, False], dtype=bool)
参数:
  • dtype (DTypeLike | None) – 可选的输出数据类型。如果未指定,则输出数据类型与输入数据类型相同。

  • type (None) – 未实现;接受是为了兼容 NumPy。

  • self (Array)

返回:

数组,被视为新的数据类型。与 NumPy 不同,该数组可能是输入数组的副本,也可能不是。

返回类型:

Array