图像恢复是一类图形去噪问题的集合,在深度学习中可以理解为监督回归问题,主要包括图像去雨、图像去雾、图像去噪,图像去模糊和图像去马赛克等内容,但利用 pytorch 实现的代码类似,只是在具体网络结构上略有区别。
以图像去雨为例,之前写过一篇图像去雨的 pytorch 实现文章: https://blog.csdn.net/Wenyuanbo/article/details/116541682,但因当时能力和水平有限,实现逻辑存在问题,最近重新整理分享一下,希望能对大家有所帮助,工程文件如图所示,数据集路径根据自己情况设置。
利用监督回归方法实现图像去雨时,一般数据集为有雨图和无雨图成对存在,首先我喜欢习惯性的将所有成对数据分别从 0 到结束对应重新排序(这个其实不影响,具体自己设计即可),诸如 001, 002, 003…。
MyDataset.py
import os
import random
import torchvision.transforms.functional as ttf
from torch.utils.data import Dataset
from PIL import Image
训练数据集是用来整合训练数据的,将有雨图和无雨图分别对应进行剪切,转张量等操作。
class MyTrainDataSet(Dataset): # 训练数据集
def __init__(self, inputPathTrain, targetPathTrain, patch_size=128):
super(MyTrainDataSet, self).__init__()
self.inputPath = inputPathTrain
self.inputImages = os.listdir(inputPathTrain) # 输入图片路径下的所有文件名列表
self.targetPath = targetPathTrain
self.targetImages = os.listdir(targetPathTrain) # 目标图片路径下的所有文件名列表
self.ps = patch_size
def __len__(self):
return len(self.targetImages)
def __getitem__(self, index):
ps = self.ps
index = index % len(self.targetImages)
inputImagePath = os.path.join(self.inputPath, self.inputImages[index]) # 图片完整路径
inputImage = Image.open(inputImagePath).convert('RGB') # 读取图片
targetImagePath = os.path.join(self.targetPath, self.targetImages[index])
targetImage = Image.open(targetImagePath).convert('RGB')
inputImage = ttf.to_tensor(inputImage) # 将图片转为张量
targetImage = ttf.to_tensor(targetImage)
hh, ww = targetImage.shape[1], targetImage.shape[2] # 图片的高和宽
rr = random.randint(0, hh-ps) # 随机数: patch 左下角的坐标 (rr, cc)
cc = random.randint(0, ww-ps)
# aug = random.randint(0, 8) # 随机数,对应对图片进行的操作
input_ = inputImage[:, rr:rr+ps, cc:cc+ps] # 裁剪 patch ,输入和目标 patch 要对应相同
target = targetImage[:, rr:rr+ps, cc:cc+ps]
return input_, target
在网络训练中,不一定最后一次训练的效果就是最好的。评估数据集是在每一个 epoch 训练结束后对网络训练的性能进行评估,目的在于将最好的一次训练结果保存。
class MyValueDataSet(Dataset): # 评估数据集
def __init__(self, inputPathTrain, targetPathTrain, patch_size=128):
super(MyValueDataSet, self).__init__()
self.inputPath = inputPathTrain
self.inputImages = os.listdir(inputPathTrain) # 输入图片路径下的所有文件名列表
self.targetPath = targetPathTrain
self.targetImages = os.listdir(targetPathTrain) # 目标图片路径下的所有文件名列表
self.ps = patch_size
def __len__(self):
return len(self.targetImages)
def __getitem__(self, index):
ps = self.ps
index = index % len(self.targetImages)
inputImagePath = os.path.join(self.inputPath, self.inputImages[index]) # 图片完整路径
inputImage = Image.open(inputImagePath).convert('RGB') # 读取图片,灰度图
targetImagePath = os.path.join(self.targetPath, self.targetImages[index])
targetImage = Image.open(targetImagePath).convert('RGB')
inputImage = ttf.center_crop(inputImage, (ps, ps))
targetImage = ttf.center_crop(targetImage, (ps, ps))
input_ = ttf.to_tensor(inputImage) # 将图片转为张量
target = ttf.to_tensor(targetImage)
return input_, target
测试数据集的目的是将输入有雨进行去雨得到去雨后的结果,注意输入一般是原图大小,不进行裁剪。
class MyTestDataSet(Dataset): # 测试数据集
def __init__(self, inputPathTest):
super(MyTestDataSet, self).__init__()
self.inputPath = inputPathTest
self.inputImages = os.listdir(inputPathTest) # 输入图片路径下的所有文件名列表
def __len__(self):
return len(self.inputImages) # 路径里的图片数量
def __getitem__(self, index):
index = index % len(self.inputImages)
inputImagePath = os.path.join(self.inputPath, self.inputImages[index]) # 图片完整路径
inputImage = Image.open(inputImagePath).convert('RGB') # 读取图片
input_ = ttf.to_tensor(inputImage) # 将图片转为张量
return input_
以一个 5 层简单卷积神经网络为例子,具体网络自己设定。
NetModel.py
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.inconv = nn.Sequential( # 输入层网络
nn.Conv2d(3, 32, 3, 1, 1),
nn.ReLU(inplace=True)
)
self.midconv = nn.Sequential( # 中间层网络
nn.Conv2d(3, 32, 3, 1, 1),
nn.ReLU(inplace=True),
nn.Conv2d(3, 32, 3, 1, 1),
nn.ReLU(inplace=True),
nn.Conv2d(3, 32, 3, 1, 1),
nn.ReLU(inplace=True),
)
self.outconv = nn.Sequential( # 输出层网络
nn.Conv2d(3, 32, 3, 1, 1),
)
def forward(self, x):
x = self.inconv(x)
x = self.midconv(x)
x = self.outconv(x)
return x
自定义工具包主要是一个计算峰值信噪比(PSNR)的方法用来对训练进行评估。
utils.py
import torch
def torchPSNR(tar_img, prd_img):
imdff = torch.clamp(prd_img,0,1) - torch.clamp(tar_img,0,1)
rmse = (imdff**2).mean().sqrt()
ps = 20*torch.log10(1/rmse)
return ps
main.py
import sys
import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from tqdm import tqdm, trange # 进度条
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from torch.autograd import Variable
from torch.optim.lr_scheduler import MultiStepLR
import utils
from NetModel import Net
from MyDataset import *
if __name__ == '__main__': # 只有在 main 中才能开多线程
EPOCH = 100 # 训练次数
BATCH_SIZE = 18 # 每批的训练数量
LEARNING_RATE = 1e-3 # 学习率
loss_list = [] # 损失存储数组
best_psnr = 0 # 训练最好的峰值信噪比
best_epoch = 0 # 峰值信噪比最好时的 epoch
inputPathTrain = 'E://Rain100H/inputTrain/' # 训练输入图片路径
targetPathTrain = 'E://Rain100H/targetTrain/' # 训练目标图片路径
inputPathTest = 'E://Rain100H/inputTest/' # 测试输入图片路径
resultPathTest = 'E://Rain100H/resultTest/' # 测试结果图片路径
targetPathTest = 'E://Rain100H/targetTest/' # 测试目标图片路径
myNet = Net() # 实例化网络
myNet = myNet.cuda() # 网络放入GPU中
criterion = nn.MSELoss().cuda()
optimizer = optim.Adam(myNet.parameters(), lr=LEARNING_RATE) # 网络参数优化算法
# 训练数据
datasetTrain = MyTrainDataSet(inputPathTrain, targetPathTrain) # 实例化训练数据集类
# 可迭代数据加载器加载训练数据
trainLoader = DataLoader(dataset=datasetTrain, batch_size=BATCH_SIZE, shuffle=True, drop_last=False, num_workers=6, pin_memory=True)
# 评估数据
datasetValue = MyValueDataSet(inputPathTest, targetPathTest) # 实例化评估数据集类
valueLoader = DataLoader(dataset=datasetValue, batch_size=16, shuffle=True, drop_last=False, num_workers=6, pin_memory=True)
# 测试数据
datasetTest = MyTestDataSet(inputPathTest) # 实例化测试数据集类
# 可迭代数据加载器加载测试数据
testLoader = DataLoader(dataset=datasetTest, batch_size=1, shuffle=False, drop_last=False, num_workers=6, pin_memory=True)
# 开始训练
print('-------------------------------------------------------------------------------------------------------')
if os.path.exists('./model_best.pth'): # 判断是否预训练
myNet.load_state_dict(torch.load('./model_best.pth'))
for epoch in range(EPOCH):
myNet.train() # 指定网络模型训练状态
iters = tqdm(trainLoader, file=sys.stdout) # 实例化 tqdm,自定义
epochLoss = 0 # 每次训练的损失
timeStart = time.time() # 每次训练开始时间
for index, (x, y) in enumerate(iters, 0):
myNet.zero_grad() # 模型参数梯度置0
optimizer.zero_grad() # 同上等效
input_train, target = Variable(x).cuda(), Variable(y).cuda() # 转为可求导变量并放入 GPU
output_train = myNet(input_train) # 输入网络,得到相应输出
loss = criterion(output_train, target) # 计算网络输出与目标输出的损失
loss.backward() # 反向传播
optimizer.step() # 更新网络参数
epochLoss += loss.item() # 累计一次训练的损失
# 自定义进度条前缀
iters.set_description('Training !!! Epoch %d / %d, Batch Loss %.6f' % (epoch+1, EPOCH, loss.item()))
# 评估
myNet.eval()
psnr_val_rgb = []
for index, (x, y) in enumerate(valueLoader, 0):
input_, target_value = x.cuda(), y.cuda()
with torch.no_grad():
output_value = myNet(input_)
for output_value, target_value in zip(output_value, target_value):
psnr_val_rgb.append(psnr(output_value, target_value))
psnr_val_rgb = torch.stack(psnr_val_rgb).mean().item()
if psnr_val_rgb > best_psnr:
best_psnr = psnr_val_rgb
best_epoch = epoch
torch.save(myNet.state_dict(), 'model_best.pth')
loss_list.append(epochLoss) # 插入每次训练的损失值
torch.save(myNet.state_dict(), 'model.pth') # 每次训练结束保存模型参数
timeEnd = time.time() # 每次训练结束时间
print("------------------------------------------------------------")
print("Epoch: {} Finished, Time: {:.4f} s, Loss: {:.6f}.".format(epoch+1, timeEnd-timeStart, epochLoss))
print('-------------------------------------------------------------------------------------------------------')
print("Training Process Finished ! Best Epoch : {} , Best PSNR : {:.2f}".format(best_epoch, best_psnr))
# 测试
print('--------------------------------------------------------------')
myNet.load_state_dict(torch.load('./model_best.pth')) # 加载已经训练好的模型参数
myNet.eval() # 指定网络模型测试状态
with torch.no_grad(): # 测试阶段不需要梯度
timeStart = time.time() # 测试开始时间
for index, x in enumerate(tqdm(testLoader, desc='Testing !!! ', file=sys.stdout), 0):
torch.cuda.empty_cache() # 释放显存
input_test = x.cuda() # 放入GPU
output_test = myNet(input_test) # 输入网络,得到输出
save_image(output_test, resultPathTest + str(index+1).zfill(3) + tail) # 保存网络输出结果
timeEnd = time.time() # 测试结束时间
print('---------------------------------------------------------')
print("Testing Process Finished !!! Time: {:.4f} s".format(timeEnd - timeStart))
# 绘制训练时损失曲线
plt.figure(1)
x = range(0, EPOCH)
plt.xlabel('epoch')
plt.ylabel('epoch loss')
plt.plot(x, loss_list, 'r-')
plt.show()
关于图像恢复特别是图像去雨问题欢迎一起交流学习。