jax.lax.broadcast_to_rank#

jax.lax.broadcast_to_rank(x, rank)[source]#

添加值为 1 的前导维度,使 x 的秩为 rank

参数:
  • x (ArrayLike)

  • rank (int)

返回类型:

Array