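I recently needed multidimensional interpolation. I was going to write it by hand, but it turns out SciPy already ships a function for it, so I used that instead. The parameter descriptions in the source are not very clear, so I am writing them down again here with my own notes.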
def interpn(points, values, xi, method="linear", bounds_error=True,
            fill_value=np.nan):
"""
Multidimensional interpolation on regular or rectilinear grids.
Strictly speaking, not all regular grids are supported - this function
works on *rectilinear* grids, that is, a rectangular grid with even or
uneven spacing.
Parameters
----------
points : tuple of ndarray of float, with shapes (m1, ), ..., (mn, )
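    The points defining the regular grid in n dimensions. The points in
    each dimension (i.e. every elements of the points tuple) must be
    strictly ascending or descending.
    Note: this is where you pass the independent-variable coordinates along each dimension of the known data (the values argument below). For example, with two dimensions x and y, where x is [1,2,3] and y is [10,20,30,40], points is ([1,2,3],[10,20,30,40]).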
values : array_like, shape (m1, ..., mn, ...)
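    The data on the regular grid in n dimensions. Complex data can be
    acceptable.
    Note: values holds the function value at every grid point. In the example above that is f(x,y), so values should be a NumPy array of shape (3, 4), with values[i, j] = f(x[i], y[j]).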
xi : ndarray of shape (..., ndim)
    The coordinates to sample the gridded data at
    Note: xi is the point we want to evaluate; in the example above, [x_new,y_new].
method : str, optional
    The method of interpolation to perform. Supported are "linear" and
    "nearest", and "splinef2d". "splinef2d" is only supported for
    2-dimensional data.
bounds_error : bool, optional
    If True, when interpolated values are requested outside of the
    domain of the input data, a ValueError is raised.
    If False, then `fill_value` is used.
fill_value : number, optional
    If provided, the value to use for points outside of the
    interpolation domain. If None, values outside
    the domain are extrapolated. Extrapolation is not supported by method
    "splinef2d".
Returns
-------
values_x : ndarray, shape xi.shape[:-1] + values.shape[ndim:]
    Interpolated values at input coordinates.
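To make the notes on points, values and xi concrete, here is a small sketch of the two-dimensional case from the annotations. The function f is made up for illustration (the notes only call it f(x,y)), and the names inside, outside, filled and extrapolated are mine. It also tries the bounds_error and fill_value options on a point outside the grid.

```python
import numpy as np
from scipy.interpolate import interpn

# Grid coordinates along each dimension (the x and y from the notes).
x = np.array([1.0, 2.0, 3.0])
y = np.array([10.0, 20.0, 30.0, 40.0])


# Hypothetical test function; any f(x, y) sampled on the grid would do.
def f(x, y):
    return x + 2.0 * y


# indexing="ij" makes values[i, j] == f(x[i], y[j]), i.e. shape (3, 4).
xx, yy = np.meshgrid(x, y, indexing="ij")
values = f(xx, yy)

# One query point [x_new, y_new] inside the grid.
inside = interpn((x, y), values, np.array([2.5, 25.0]))

# A point outside the grid: with bounds_error=False the default fill_value
# (NaN) is returned, while fill_value=None asks for extrapolation.
outside = np.array([4.0, 50.0])
filled = interpn((x, y), values, outside, bounds_error=False)
extrapolated = interpn((x, y), values, outside,
                       bounds_error=False, fill_value=None)

# With the default bounds_error=True the same query raises ValueError.
try:
    interpn((x, y), values, outside)
    raised = False
except ValueError:
    raised = True

print(values.shape, inside, filled, extrapolated, raised)
```

Because this f is linear in x and in y, linear interpolation should reproduce it exactly, which makes the result easy to check by hand. Real data will generally not have that property.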
A simple example
import numpy as np
from scipy.interpolate import interpn
class Hk:
    def __init__(self):
        # Coordinates of the three grid axes (10, 3 and 4 points).
        self.aa = np.arange(0, 181, 20)
        self.alt = np.array([5000, 10000, 15000])
        self.mach = np.array([0.4, 0.8, 1.2, 1.6])
        # rMax[i, j, k] is the value at (aa[i], alt[j], mach[k]); shape (10, 3, 4).
        self.rMax = np.array([
            [[5801, 5363, 4706, 3893], [10085, 9335, 8209, 6770], [15840, 14652, 13432, 12243]],
            [[5841, 5443, 4760, 3917], [10164, 9427, 8342, 6865], [15959, 14854, 13652, 12414]],
            [[5976, 5648, 5003, 4106], [10368, 9771, 8714, 7172], [16297, 15459, 14382, 13090]],
            [[6205, 6076, 5438, 4458], [10730, 10418, 9385, 7781], [16825, 16526, 15747, 14421]],
            [[6540, 6686, 6142, 4975], [11236, 11393, 10466, 8659], [17603, 18063, 17844, 16640]],
            [[6971, 7613, 7379, 6157], [11948, 12845, 12317, 10117], [18577, 20042, 20824, 18580]],
            [[7477, 8788, 9253, 8796], [12757, 14680, 15409, 13940], [19649, 22240, 24379, 25333]],
            [[8033, 10030, 11692, 12788], [13589, 16664, 19047, 20512], [20660, 24301, 27761, 30947]],
            [[8428, 11007, 13549, 16058], [14245, 18212, 22110, 25897], [21407, 25822, 30158, 34456]],
            [[8616, 11385, 14299, 17337], [14558, 18840, 23320, 27987], [21751, 26421, 31088, 35786]]
        ], dtype=np.float64)

    def get_rMax(self, aa, alt, mach):
        # Linear interpolation (the default method) at a single point.
        # The result comes back as an array, not a plain float.
        ret = interpn((self.aa, self.alt, self.mach), self.rMax, np.array([aa, alt, mach]))
        return ret
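A short sketch of how a three-dimensional lookup like this behaves when called. To keep it self-contained it reuses the same axis layout but fills the table with a synthetic linear function, which is my assumption for illustration and not the rMax data above. The names table, single, queries and batch are likewise made up. It shows the return shape for one point and for several points passed at once.

```python
import numpy as np
from scipy.interpolate import interpn

# Same axis layout as the example above.
aa = np.arange(0, 181, 20)              # 10 points
alt = np.array([5000, 10000, 15000])    # 3 points
mach = np.array([0.4, 0.8, 1.2, 1.6])   # 4 points

# Synthetic table (NOT the rMax data): linear in every axis, so linear
# interpolation should reproduce it and the results can be checked by hand.
A, H, M = np.meshgrid(aa, alt, mach, indexing="ij")
table = 10.0 * A + 0.5 * H + 1000.0 * M  # shape (10, 3, 4)

# A single point of shape (3,) comes back as an array; take the first
# element to get a plain float.
single = interpn((aa, alt, mach), table, np.array([30.0, 7500.0, 1.0]))
single_value = float(np.ravel(single)[0])

# Several points at once: xi has shape (N, 3), the result has shape (N,).
queries = np.array([
    [30.0, 7500.0, 1.0],
    [0.0, 5000.0, 0.4],      # exactly on a grid node
    [180.0, 15000.0, 1.6],   # the opposite corner of the grid
])
batch = interpn((aa, alt, mach), table, queries)

print(single_value, batch)
```

Passing all query points in one call is usually preferable to calling the function in a Python loop, since the grid is only validated once per call.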