jax.numpy.mgrid#

jax.numpy.mgrid = <jax._src.numpy.index_tricks._Mgrid object>#

返回密集的(dense)多维“meshgrid”。

LAX 后端实现的 numpy.mgrid。这是一个便捷的封装,用于实现 jax.numpy.meshgrid() 提供的功能,并指定参数 sparse=False

另请参阅

jnp.ogrid: 是 jnp.mgrid 的开放/稀疏版本

示例

传入 [start:stop:step] 可生成与 jax.numpy.arange() 类似的值

>>> jnp.mgrid[0:4:1]
Array([0, 1, 2, 3], dtype=int32)

传入虚数步长可生成与 jax.numpy.linspace() 类似的值

>>> jnp.mgrid[0:1:4j]
Array([0.        , 0.33333334, 0.6666667 , 1.        ], dtype=float32)

可以使用多个切片来创建广播的索引网格

>>> jnp.mgrid[:2, :3]
Array([[[0, 0, 0],
        [1, 1, 1]],
       [[0, 1, 2],
        [0, 1, 2]]], dtype=int32)