jax.numpy.left_shift#

jax.numpy.left_shift(x, y, /)[source]#

逐元素将 x 的位向左移动 y 指定的位数。

JAX 对 numpy.left_shift 的实现。

参数:
  • x (ArrayLike) – 输入数组,必须为整数类型。

  • y (ArrayLike) – 将 x 中每个元素向左移动的位数,只接受整数子类型。xy 必须形状相同或兼容广播。

返回:

一个数组,包含将 x 中元素按 y 指定的位数左移后的结果,其形状与 xy 广播后的形状相同。

返回类型:

数组

注意

在涉及的数据类型范围内,将 x 左移 y 位等效于 x * (2**y)

另请参阅

示例

>>> def print_binary(x):
...   return [bin(int(val)) for val in x]
>>> x1 = jnp.arange(5)
>>> x1
Array([0, 1, 2, 3, 4], dtype=int32)
>>> print_binary(x1)
['0b0', '0b1', '0b10', '0b11', '0b100']
>>> x2 = 1
>>> result = jnp.left_shift(x1, x2)
>>> result
Array([0, 2, 4, 6, 8], dtype=int32)
>>> print_binary(result)
['0b0', '0b10', '0b100', '0b110', '0b1000']
>>> x3 = 4
>>> print_binary([x3])
['0b100']
>>> x4 = jnp.array([1, 2, 3, 4])
>>> result1 = jnp.left_shift(x3, x4)
>>> result1
Array([ 8, 16, 32, 64], dtype=int32)
>>> print_binary(result1)
['0b1000', '0b10000', '0b100000', '0b1000000']