jax.scipy.ndimage.map_coordinates#
- jax.scipy.ndimage.map_coordinates(input, coordinates, order, mode='constant', cval=0.0)[source]#
使用插值将输入数组映射到新坐标。
JAX 对
scipy.ndimage.map_coordinates()
的实现给定一个输入数组和一组坐标,此函数返回输入数组在这些坐标处的插值。
- 参数:
input (Array | ndarray | bool | number | bool | int | float | complex) – 用于插值的N维输入数组。
coordinates (Sequence[Array | ndarray | bool | number | bool | int | float | complex]) – 长度为 N 的数组序列,指定用于评估插值值的坐标。
order (int) –
插值阶数。JAX 支持以下类型:
0: 最近邻
1: 线性
mode (str) – 输入边界外的点根据给定模式填充。JAX 支持以下模式之一:
('constant', 'nearest', 'mirror', 'wrap', 'reflect')
。请注意,'wrap'
模式在 JAX 中的行为与 SciPy 中的'grid-wrap'
模式一致,并且'constant'
模式在 JAX 中的行为与 SciPy 中的'grid-constant'
模式一致。这种差异是由 SciPy 中这些模式的一个早期 bug 引起的 (scipy/scipy#2640),该 bug 最早由 JAX 通过改变现有模式的行为来修复,后来 SciPy 也通过添加新名称的模式而不是修复现有模式来修复,以实现向后兼容性。默认值为 'constant'。cval (Array | ndarray | bool | number | bool | int | float | complex) – 如果
mode='constant'
,则用于输入边界外的值。默认值为 0.0。
- 返回:
在指定坐标处的插值。
示例
>>> input = jnp.arange(12.0).reshape(3, 4) >>> input Array([[ 0., 1., 2., 3.], [ 4., 5., 6., 7.], [ 8., 9., 10., 11.]], dtype=float32) >>> coordinates = [jnp.array([0.5, 1.5]), ... jnp.array([1.5, 2.5])] >>> jax.scipy.ndimage.map_coordinates(input, coordinates, order=1) Array([3.5, 8.5], dtype=float32)
注意
边界附近的插值与 SciPy 函数不同,因为 JAX 修复了一个未解决的 bug;请参见 jax-ml/jax#11097。此函数将
mode
参数解释为 SciPy 文档中描述的那样,而不是 SciPy 实际实现的那样。