jax.numpy.meshgrid#
- jax.numpy.meshgrid(*xi, copy=True, sparse=False, indexing='xy')[源代码]#
根据 N 个一维向量构建 N 维网格数组。
JAX 实现
numpy.meshgrid()。- 参数:
- 返回:
长度为 N 的网格数组列表。
- 返回类型:
另请参阅
jax.numpy.indices(): 生成索引网格。jax.numpy.mgrid: 使用索引语法创建 meshgrid。jax.numpy.ogrid: 使用索引语法创建 open meshgrid。
示例
对于以下示例,我们将使用这些一维数组作为输入
>>> x = jnp.array([1, 2]) >>> y = jnp.array([10, 20, 30])
二维笛卡尔网格
>>> x_grid, y_grid = jnp.meshgrid(x, y) >>> print(x_grid) [[1 2] [1 2] [1 2]] >>> print(y_grid) [[10 10] [20 20] [30 30]]
二维稀疏笛卡尔网格
>>> x_grid, y_grid = jnp.meshgrid(x, y, sparse=True) >>> print(x_grid) [[1 2]] >>> print(y_grid) [[10] [20] [30]]
二维矩阵索引网格
>>> x_grid, y_grid = jnp.meshgrid(x, y, indexing='ij') >>> print(x_grid) [[1 1 1] [2 2 2]] >>> print(y_grid) [[10 20 30] [10 20 30]]