jax.numpy.multiply#

jax.numpy.multiply = <jnp.ufunc 'multiply'>#

将两个数组进行逐元素相乘。

JAX 对 numpy.multiply 的实现。这是一个通用函数,支持在 jax.numpy.ufunc 中描述的其他 API。此函数为 JAX 数组提供了 * 运算符的实现。

参数:
  • x – 要相乘的数组。必须能够广播到相同的形状。

  • y – 要相乘的数组。必须能够广播到相同的形状。

  • args (ArrayLike)

  • out (None)

  • where (None)

返回:

包含逐元素相乘结果的数组。

返回类型:

任意类型

示例

显式调用 multiply

>>> x = jnp.arange(4)
>>> jnp.multiply(x, 10)
Array([ 0, 10, 20, 30], dtype=int32)

通过 * 运算符调用 multiply

>>> x * 10
Array([ 0, 10, 20, 30], dtype=int32)