jax.numpy.left_shift#
- jax.numpy.left_shift(x, y, /)[源码]#
将
x的位向左移动y指定的量,逐元素执行。JAX 对
numpy.left_shift的实现。- 参数:
x (ArrayLike) – 输入数组,必须是整数类型。
y (ArrayLike) – 将
x中的每个元素向左移动的位数,只接受整数子类型。x和y的形状必须相同或可广播兼容。
- 返回:
一个数组,包含
x按照y指定的量左移后的元素,其形状与x和y广播后的形状相同。- 返回类型:
注意
将
x左移y位,在涉及的数据类型范围内,等价于x * (2**y)。另请参阅
jax.numpy.right_shift():和jax.numpy.bitwise_right_shift():将x1的位向右移动x2指定的量,逐元素执行。jax.numpy.bitwise_left_shift():是jax.left_shift()的别名。
示例
>>> def print_binary(x): ... return [bin(int(val)) for val in x]
>>> x1 = jnp.arange(5) >>> x1 Array([0, 1, 2, 3, 4], dtype=int32) >>> print_binary(x1) ['0b0', '0b1', '0b10', '0b11', '0b100'] >>> x2 = 1 >>> result = jnp.left_shift(x1, x2) >>> result Array([0, 2, 4, 6, 8], dtype=int32) >>> print_binary(result) ['0b0', '0b10', '0b100', '0b110', '0b1000']
>>> x3 = 4 >>> print_binary([x3]) ['0b100'] >>> x4 = jnp.array([1, 2, 3, 4]) >>> result1 = jnp.left_shift(x3, x4) >>> result1 Array([ 8, 16, 32, 64], dtype=int32) >>> print_binary(result1) ['0b1000', '0b10000', '0b100000', '0b1000000']