长时序栅格数据中经常出现缺失值,给后续分析带来很大不便,因此需要借助插值算法对这些缺失数据进行填补,奇异谱分析(SSA)便是常用的一种插值方法。更多内容可见公众号 GeodataAnalysis。
在时间序列分析中,「奇异谱分析」(「SSA」)是一种非参数谱估计方法。它结合了经典时间序列分析、多元统计、多元几何、动力系统和信号处理的元素。
“奇异谱分析”这个名称涉及协方差矩阵的奇异值分解中的特征值谱,而不是直接涉及频域分解。
SSA 可以将时间序列分解为若干分量之和,每个分量都有有意义的解释。如下图所示,奇异谱分析分解出了趋势、变化和噪声三部分。
SSA 只考虑数据本身的特征,不考虑其他因素,特别适合于缺失值插补以及平稳时间序列的预测。
导入所需的第三方库
import os
import calendar
import numpy as np
import rasterio as rio
import pandas as pd
from mssa.mssa import mSSA
import matplotlib.pyplot as plt
生成测试数据
# Synthetic daily series: one annual sine cycle plus Gaussian noise (std 0.2).
x = np.arange(365)
y = np.sin(x * np.pi * 2 / 365) + np.random.randn(x.size) * 0.2
# Punch NaN gaps: the first 5 days, three 5-day holes, and the last 30 days.
for gap in (slice(None, 5), slice(90, 95), slice(200, 205), slice(300, 305), slice(-30, None)):
    y[gap] = np.nan
plt.plot(x, y);
SSA插补
# One mSSA pass over the test series (a single column named 'data').
df = pd.DataFrame(data=y, columns=['data'])
model = mSSA(rank=None, fill_in_missing=True)
model.update_model(df)
# Predict every index 0..n-1 so the result lines up with x and y.
df2 = model.predict('data', 0, df.shape[0] - 1)
# Per-index predictions as a NumPy array (previously `pred` was used below without being defined).
pred = df2.loc[:, 'Mean Predictions'].values
mask = np.isnan(y)
plt.plot(x, y)
# Red dots: predicted values at the originally missing positions.
plt.scatter(x[mask], pred[mask], c='r', s=1);
通过分析预测结果可知,若时间序列末尾的缺失值超过十个,最后五个缺失值便无法预测。这里采用的解决办法是:用已得到的预测值作为输入再次进行预测,并对仍然缺失的位置重复这一过程。具体代码如下:
def fill_value(data, *, max_iter=None):
    '''Fill NaN gaps in a 1-D series with mSSA, re-predicting from its own output.

    A single mSSA pass can leave some positions unpredicted (e.g. a long
    run of NaNs at the end of the series). Those positions are predicted
    again from the partially filled series. This repeats until none remain,
    until a pass fills nothing new, or until *max_iter* extra passes have run.

    Parameters
    ----------
    data : 1-D array-like
        The series, with NaN marking missing values.
    max_iter : int or None, keyword-only
        Upper bound on the number of extra passes. None (default) means no
        fixed bound; the loop still stops as soon as a pass makes no progress.

    Returns
    -------
    numpy.ndarray
        mSSA 'Mean Predictions' for every index, observed positions included.
        It may still contain NaN if the model stops making progress; a
        RuntimeWarning is issued in that case.
    '''
    import warnings

    def _fill_value(series):
        # One mSSA pass; returns predictions for indices 0..n-1.
        df = pd.DataFrame(data=series, columns=['data'])
        model = mSSA(rank=None, fill_in_missing=True)
        model.update_model(df)
        df2 = model.predict('data', 0, df.shape[0] - 1)
        # Explicit float copy: `.values` may be a read-only view under pandas
        # Copy-on-Write, and the loop below assigns into this array.
        return df2.loc[:, 'Mean Predictions'].to_numpy(dtype=float, copy=True)

    pred = _fill_value(data)
    n_missing = np.count_nonzero(np.isnan(pred))
    passes = 0
    while n_missing and (max_iter is None or passes < max_iter):
        mask = np.isnan(pred)
        pred2 = _fill_value(pred)
        # Only the still-missing positions take the new predictions.
        pred[mask] = pred2[mask]
        passes += 1
        remaining = np.count_nonzero(np.isnan(pred))
        if remaining == n_missing:
            # Nothing new was filled, so `pred` is unchanged. Re-running on the
            # same input would presumably repeat forever (assuming mSSA is
            # deterministic), which is how the old `while True` loop could hang.
            break
        n_missing = remaining

    if n_missing:
        warnings.warn(
            f'fill_value: {n_missing} value(s) could not be filled',
            RuntimeWarning,
            stacklevel=2,
        )
    return pred
# Iterative SSA fill, then overlay the raw series, the full reconstruction,
# and (red dots) the reconstructed values on the originally missing days.
pred = fill_value(y)
mask = np.isnan(y)
plt.plot(x, y, x, pred);
plt.scatter(x[mask], pred[mask], c='r', s=3);
数据保存格式如下:
data
20170101.tif
20170102.tif
··· ···
20171231.tif
# One expected file per calendar day, named <root_dir>/YYYYMMDD.tif.
root_dir = './data'
years = [2017]
months = [f'{m:02d}' for m in range(1, 13)]
paths = [
    f'{root_dir}/{year}{month}{day:02d}.tif'
    for year in years
    for month in months
    # monthrange(...)[1] is the number of days in that month (leap years included).
    for day in range(1, calendar.monthrange(year, int(month))[1] + 1)
]
若某个栅格数据不存在,则将其视为全部是缺失值。
def check_data(paths):
    '''Verify that every existing raster in *paths* shares one grid, and return it.

    The first path that exists on disk is the reference. Every other existing
    file must match its width, height, affine transform and CRS. Missing
    files are skipped, because the caller treats them as all-NoData.

    Returns
    -------
    dict
        'RasterXSize' (width, i.e. columns), 'RasterYSize' (height, i.e. rows),
        'shape' as (width, height) -- column count first, unlike a NumPy
        array shape -- plus 'gt' (affine transform) and 'proj' (CRS).

    Raises
    ------
    FileNotFoundError
        If none of the paths exist.
    ValueError
        If an existing file does not match the reference grid.
    '''
    existing = [path for path in paths if os.path.exists(path)]
    if not existing:
        raise FileNotFoundError('None of the given raster files exist.')

    # Context managers close each dataset instead of leaking one handle per file.
    with rio.open(existing[0]) as src:
        RasterXSize, RasterYSize = src.width, src.height
        gt = src.transform
        proj = src.crs

    for path in existing[1:]:
        with rio.open(path) as src:
            # Explicit raises rather than assert: asserts vanish under `python -O`.
            if (src.width, src.height) != (RasterXSize, RasterYSize):
                raise ValueError(
                    f'{path}: size {src.width}x{src.height} differs from '
                    f'reference {RasterXSize}x{RasterYSize}'
                )
            if not gt == src.transform:
                raise ValueError(f'{path}: transform differs from reference')
            if not proj == src.crs:
                raise ValueError(f'{path}: CRS differs from reference')

    info = {
        'RasterXSize': RasterXSize,
        'RasterYSize': RasterYSize,
        'shape': (RasterXSize, RasterYSize),
        'gt': gt,
        'proj': proj
    }
    return info
def rasters_to_array(paths, band_num=1, win=None):
    '''Read one band from every raster in *paths* and stack the layers into a 3-D array.

    The result has shape (len(paths), rows, cols); layer k holds paths[k].
    NoData pixels become NaN, so each layer's dtype is first promoted with
    ``np.promote_types(dtype, np.float32)``. Integer rasters therefore become
    float, while float32/float64 data keep their dtype. A path that does not
    exist is represented by an all-NaN layer.

    Parameters
    ----------
    paths : list of str
        One raster file per time step.
    band_num : int
        Band index passed to ``src.read`` (rasterio bands start at 1).
    win : optional
        Passed to ``src.read(window=...)``; None reads the full raster.

    Returns
    -------
    (numpy.ndarray, dict)
        The stacked array and the grid info returned by ``check_data``.

    Raises
    ------
    FileNotFoundError
        If none of the paths exist (unless ``check_data`` has already raised).
    '''
    data_info = check_data(paths)

    layers = [None] * len(paths)
    template = None  # first successfully read layer; shapes the fill for missing files
    for k, path in enumerate(paths):
        if not os.path.exists(path):
            continue
        # The context manager closes the dataset after reading.
        with rio.open(path) as src:
            layer = src.read(band_num, window=win)
            nodata = src.nodata
        # NaN needs a float dtype; assigning NaN into an integer array raises.
        layer = layer.astype(np.promote_types(layer.dtype, np.float32), copy=False)
        if nodata is not None:
            layer[layer == nodata] = np.nan
        layers[k] = layer
        if template is None:
            template = layer

    if template is None:
        raise FileNotFoundError('None of the given raster files exist.')

    # Missing dates become all-NaN layers shaped like a layer that was actually
    # read, so they match `win` and the data dtype. A single np.stack avoids the
    # quadratic copying of growing the array with repeated concatenation.
    out_array = np.stack(
        [np.full_like(template, np.nan) if layer is None else layer for layer in layers],
        axis=0,
    )
    return out_array, data_info
# Load the whole daily stack once; data_info carries the grid/georeferencing
# reused later when writing the filled rasters.
array, data_info = rasters_to_array(paths=paths)
def array_to_tif(out_path, arr, crs, transform, nodata=None):
    """Write a 2-D or 3-D NumPy array to a GeoTIFF at *out_path*.

    A 2-D array (rows, cols) is written as a single band. A 3-D array
    (bands, rows, cols) is written with one band per leading index: band i
    of the file holds arr[i - 1].

    Parameters
    ----------
    out_path : str
        Destination file path; its directory must already exist.
    arr : numpy.ndarray
        2-D or 3-D array; its dtype becomes the file's dtype.
    crs, transform :
        Georeferencing passed straight to rasterio (e.g. the 'proj' and 'gt'
        entries returned by check_data).
    nodata : optional
        NoData value recorded in the file metadata.

    Raises
    ------
    ValueError
        If *arr* is not 2-D or 3-D.
    """
    if arr.ndim not in (2, 3):
        raise ValueError(
            f'arr must be 2-D (rows, cols) or 3-D (bands, rows, cols), got ndim={arr.ndim}'
        )
    # Treat a single 2-D layer as a one-band stack so both cases share one write.
    bands = arr[np.newaxis, ...] if arr.ndim == 2 else arr
    count, height, width = bands.shape
    with rio.open(out_path, 'w',
                  driver='GTiff',
                  height=height, width=width,
                  count=count,
                  dtype=bands.dtype,
                  crs=crs,
                  transform=transform,
                  nodata=nodata) as dst:
        # With no band indexes given, rasterio writes a 3-D array to bands 1..count in order.
        dst.write(bands)
# Gap-filled copy of the (time, rows, cols) stack from rasters_to_array.
# Observed values are kept; only the NaN gaps are replaced.
fill_array = array.copy()
# Maximum number of missing time steps a pixel may have and still be filled.
# Pixels with more gaps are left unchanged (NaN) in fill_array.
thre = 65
for i in range(array.shape[1]):
    for j in range(array.shape[2]):
        data = array[:, i, j]
        mask = np.isnan(data)
        n_missing = np.count_nonzero(mask)
        # Skip: nothing to fill, no observations to learn from, or too many gaps.
        if n_missing == 0 or np.all(mask) or n_missing > thre:
            continue
        pred = fill_value(data)
        # Overwrite only the gaps. Writing all of `pred` would replace real
        # observations with the SSA reconstruction.
        fill_array[mask, i, j] = pred[mask]
out_dir = './results'
# GDAL/rasterio will not create missing directories when opening a file for writing.
os.makedirs(out_dir, exist_ok=True)
# One output GeoTIFF per input date, reusing the input file name and grid.
for i, path in enumerate(paths):
    file_name = os.path.basename(path)
    out_path = os.path.join(out_dir, file_name)
    array_to_tif(out_path, fill_array[i, ...],
                 data_info['proj'], data_info['gt'],
                 nodata=np.nan)