这份代码实现的是numpy.ndarray的快速最近邻插值(放缩)。而这个方法貌似并没有直接的API(不能使用图像放缩API,因为数值会放缩到0-255)。
我们由3*3的数组如下:
[[1 2 3]
[4 5 6]
[7 8 9]]
然后我们想要使用最近邻插值,放缩其成为一个8*10的矩阵,如下:
[[1 1 1 1 2 2 2 3 3 3]
[1 1 1 1 2 2 2 3 3 3]
[1 1 1 1 2 2 2 3 3 3]
[4 4 4 4 5 5 5 6 6 6]
[4 4 4 4 5 5 5 6 6 6]
[4 4 4 4 5 5 5 6 6 6]
[7 7 7 7 8 8 8 9 9 9]
[7 7 7 7 8 8 8 9 9 9]]
显然我们可以使用相当简单的两重for循环实现,但是这种实现极其耗时,尤其是数组特别大的时候。
使用图像处理的API进行插值,会使得所有矩阵数值量化到0-255。
scipy.ndimage
库提供了一个API叫 zoom
(在代码里面也提供了),但是会得到如下的结果,这显然不是我们想要的。
[[1 1 1 2 2 2 2 3 3 3]
[1 1 1 2 2 2 2 3 3 3]
[4 4 4 5 5 5 5 6 6 6]
[4 4 4 5 5 5 5 6 6 6]
[4 4 4 5 5 5 5 6 6 6]
[4 4 4 5 5 5 5 6 6 6]
[7 7 7 8 8 8 8 9 9 9]
[7 7 7 8 8 8 8 9 9 9]]
ndarray_nearest_neighbour_scaling
为本文实现的算法ndarray_zoom_scaling
为使用 scipy.ndimage.zoom
实现的算法import numpy as np
import scipy.ndimage
def ndarray_zoom_scaling(label, new_h, new_w):
"""
Implement scaling for ndarray with scipy.ndimage.zoom
:param label: [H, W] or [H, W, C]
:return: label_new: [new_h, new_w] or [new_h, new_w, C]
Examples
--------
ori_arr = np.array([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]], dtype=np.int32)
new_arr = ndarray_zoom_scaling(ori_arr, new_h=8, new_w=10)
>> print(new_arr)
[[1 1 1 2 2 2 2 3 3 3]
[1 1 1 2 2 2 2 3 3 3]
[4 4 4 5 5 5 5 6 6 6]
[4 4 4 5 5 5 5 6 6 6]
[4 4 4 5 5 5 5 6 6 6]
[4 4 4 5 5 5 5 6 6 6]
[7 7 7 8 8 8 8 9 9 9]
[7 7 7 8 8 8 8 9 9 9]]
"""
scale_h = new_h / label.shape[0]
scale_w = new_w / label.shape[1]
if len(label.shape) == 2:
label_new = scipy.ndimage.zoom(label, zoom=[scale_h, scale_w], order=0)
else:
label_new = scipy.ndimage.zoom(label, zoom=[scale_h, scale_w, 1], order=0)
return label_new
def ndarray_nearest_neighbour_scaling(label, new_h, new_w):
"""
Implement nearest neighbour scaling for ndarray
:param label: [H, W] or [H, W, C]
:return: label_new: [new_h, new_w] or [new_h, new_w, C]
Examples
--------
ori_arr = np.array([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]], dtype=np.int32)
new_arr = ndarray_nearest_neighbour_scaling(ori_arr, new_h=8, new_w=10)
>> print(new_arr)
[[1 1 1 1 2 2 2 3 3 3]
[1 1 1 1 2 2 2 3 3 3]
[1 1 1 1 2 2 2 3 3 3]
[4 4 4 4 5 5 5 6 6 6]
[4 4 4 4 5 5 5 6 6 6]
[4 4 4 4 5 5 5 6 6 6]
[7 7 7 7 8 8 8 9 9 9]
[7 7 7 7 8 8 8 9 9 9]]
"""
if len(label.shape) == 2:
label_new = np.zeros([new_h, new_w], dtype=label.dtype)
else:
label_new = np.zeros([new_h, new_w, label.shape[2]], dtype=label.dtype)
scale_h = new_h / label.shape[0]
scale_w = new_w / label.shape[1]
y_pos = np.arange(new_h)
x_pos = np.arange(new_w)
y_pos = np.floor(y_pos / scale_h).astype(np.int32)
x_pos = np.floor(x_pos / scale_w).astype(np.int32)
y_pos = y_pos.reshape(y_pos.shape[0], 1)
y_pos = np.tile(y_pos, (1, new_w))
x_pos = np.tile(x_pos, (new_h, 1))
assert y_pos.shape == x_pos.shape
label_new[:, :] = label[y_pos[:, :], x_pos[:, :]]
return label_new
比较了三个算法(两重for循环、scipy.ndimage.zoom
算法、本文实现的算法)在10,000次的插值操作后的总耗时,如下:
算法 | 总耗时 | 是我们方法耗时的倍数 |
---|---|---|
本文的方法 | 0.360s | / |
scipy.ndimage.zoom 算法 |
0.436s | 1.21倍 |
两重for循环 | 1.523s | 4.23倍 |
显然本文的算法在速度上比两重for循环快很多。