jax.numpy.mgrid#
- jax.numpy.mgrid = <jax._src.numpy.index_tricks._Mgrid object>#
返回密集的(dense)多维“meshgrid”。
LAX 后端实现的
numpy.mgrid
。这是一个便捷的封装,用于实现jax.numpy.meshgrid()
提供的功能,并指定参数sparse=False
。另请参阅
jnp.ogrid: 是 jnp.mgrid 的开放/稀疏版本
示例
传入
[start:stop:step]
可生成与jax.numpy.arange()
类似的值>>> jnp.mgrid[0:4:1] Array([0, 1, 2, 3], dtype=int32)
传入虚数步长可生成与
jax.numpy.linspace()
类似的值>>> jnp.mgrid[0:1:4j] Array([0. , 0.33333334, 0.6666667 , 1. ], dtype=float32)
可以使用多个切片来创建广播的索引网格
>>> jnp.mgrid[:2, :3] Array([[[0, 0, 0], [1, 1, 1]], [[0, 1, 2], [0, 1, 2]]], dtype=int32)