博主在处理图片尺度问题时,习惯使用 cv2.resize() 函数;但当图片数据需用显卡加速运算时,数据需要在 GPU 和 CPU 之间不断迁移,导致程序运行效率降低;
Pytorch 提供了一个类似于 cv2.resize() 的采样函数,即 torch.nn.functional.interpolate(),支持最近邻插值(nearest)和双线性插值(bilinear)等功能,通过设置合理的插值方式可以取得与 cv2.resize() 函数完全一样的效果。
① 最近邻方法('nearnest' 和 cv2.INTER_NEAREST):
import torch
import cv2
import torch.nn.functional as F
import numpy as np
input_data1 = torch.randint(low = 0, high = 255, size = [40, 40, 3])
input_data2 = np.array(input_data1, dtype = np.uint8)
input_data1 = input_data1.permute(2, 0, 1).unsqueeze(0).float() # [1, 3, 40, 40]
output_data1 = F.interpolate(input_data1, size = (224, 224), mode='nearest').float() # [1, 3, 224, 224]
output_data2 = cv2.resize(input_data2, dsize = (224, 224), interpolation=cv2.INTER_NEAREST) # [224, 224, 3]
data1 = np.array(output_data1.squeeze(0).permute(1, 2, 0), dtype=np.uint8)
data2 = np.array(output_data2, dtype=np.uint8)
print(data1 == data2)
print("All done !")
② 双线性插值方法('bilinear' 和 cv2.INTER_LINEAR):
import torch
import cv2
import torch.nn.functional as F
import numpy as np
input_data1 = torch.randint(low = 0, high = 255, size = [40, 40, 3])
input_data2 = np.array(input_data1, dtype = np.uint8)
input_data1 = input_data1.permute(2, 0, 1).unsqueeze(0).float() # [1, 3, 40, 40]
output_data1 = F.interpolate(input_data1, size = (224, 224), mode='bilinear').float() # [1, 3, 224, 224]
output_data2 = cv2.resize(input_data2, dsize = (224, 224), interpolation=cv2.INTER_LINEAR) # [224, 224, 3]
data1 = np.array(output_data1.squeeze(0).permute(1, 2, 0), dtype=np.uint8)
data2 = np.array(output_data2, dtype=np.uint8)
print(data1 == data2)
print("All done !")
上面两个测试代码的结果表明,在采取相同插值方式的前提下,torch.nn.functional.interpolate() 和 cv2.resize() 两个方法的功能是完全等价的,处理后的数据相同;
① 使用 torch.nn.functional.interpolate()的注意事项:
1. 插值方法(mode)与输入数据的维度(minibatch, channels, [optional depth], [optional height], width)密切相关,目前支持的数据维度有以下几种:
① 3D张量输入:minibatch, channels, width;
② 4D张量输入:minibatch, channels, height, width;
③ 5D张量输入:minibatch, channels, depth, height, width;
2. 插值方法和输入维度的关系如下: