jax.scipy.ndimage.map_coordinates#
- jax.scipy.ndimage.map_coordinates(input, coordinates, order, mode='constant', cval=0.0)[源码]#
使用插值将输入数组映射到新坐标。
JAX 对
scipy.ndimage.map_coordinates()的实现。给定一个输入数组和一组坐标,此函数返回在这些坐标处输入数组的插值。值。
- 参数:
input (Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) – 从中进行值插值的 N 维输入数组。
coordinates (Sequence[Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray]) – 指定要评估插值值的坐标的长度为 N 的数组序列。
order (int) –
插值顺序。JAX 支持以下
0:最近邻
1:线性
mode (str) – 根据给定的模式填充输入边界外的点。JAX 支持以下选项之一:
('constant', 'nearest', 'mirror', 'wrap', 'reflect')。请注意,JAX 中的'wrap'模式的行为类似于 SciPy 中的'grid-wrap'模式,而 JAX 中的'constant'模式的行为类似于 SciPy 中的'grid-constant'模式。这种差异是由 SciPy 中这些模式的一个先前错误引起的(scipy/scipy#2640),该错误首先在 JAX 中通过更改现有模式的行为来修复,后来在 SciPy 中通过添加新名称的模式来修复,而不是为了向后兼容性而修复现有的模式。默认为 ‘constant’。cval (Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) – 如果
mode='constant',则用于输入边界外点的计算值。默认为 0.0。
- 返回:
在指定坐标处的插值。值。
示例
>>> input = jnp.arange(12.0).reshape(3, 4) >>> input Array([[ 0., 1., 2., 3.], [ 4., 5., 6., 7.], [ 8., 9., 10., 11.]], dtype=float32) >>> coordinates = [jnp.array([0.5, 1.5]), ... jnp.array([1.5, 2.5])] >>> jax.scipy.ndimage.map_coordinates(input, coordinates, order=1) Array([3.5, 8.5], dtype=float32)
注意
由于 JAX 修复了一个悬而未决的错误,边界附近的插值与 scipy 函数不同;请参阅 jax-ml/jax#11097。此函数解释
mode参数,其方式与 SciPy 的文档一致,但与 SciPy 的实现方式不同。