pytorch: 图像恢复问题的代码实现详解(derain,dehaze,deblur,denoise等通用)

文章目录

      • 前言
      • 数据集
        • 训练数据集
        • 评估数据集
        • 测试数据集
      • 网络模型
      • 自定义工具包
      • 网络训练和测试
      • 结语

前言

图像恢复是一类图形去噪问题的集合,在深度学习中可以理解为监督回归问题,主要包括图像去雨、图像去雾、图像去噪,图像去模糊和图像去马赛克等内容,但利用 pytorch 实现的代码类似,只是在具体网络结构上略有区别。

以图像去雨为例,之前写过一篇图像去雨的 pytorch 实现文章: https://blog.csdn.net/Wenyuanbo/article/details/116541682,但因当时能力和水平有限,实现逻辑存在问题,最近重新整理分享一下,希望能对大家有所帮助,工程文件如图所示,数据集路径根据自己情况设置。

pytorch: 图像恢复问题的代码实现详解(derain,dehaze,deblur,denoise等通用)_第1张图片

数据集

利用监督回归方法实现图像去雨时,一般数据集为有雨图和无雨图成对存在,首先我喜欢习惯性的将所有成对数据分别从 0 到结束对应重新排序(这个其实不影响,具体自己设计即可),诸如 001, 002, 003…。

MyDataset.py

import os
import random
import torchvision.transforms.functional as ttf
from torch.utils.data import Dataset
from PIL import Image

训练数据集

训练数据集是用来整合训练数据的,将有雨图和无雨图分别对应进行剪切,转张量等操作。

class MyTrainDataSet(Dataset):  # 训练数据集
    def __init__(self, inputPathTrain, targetPathTrain, patch_size=128):
        super(MyTrainDataSet, self).__init__()

        self.inputPath = inputPathTrain
        self.inputImages = os.listdir(inputPathTrain)  # 输入图片路径下的所有文件名列表

        self.targetPath = targetPathTrain
        self.targetImages = os.listdir(targetPathTrain)  # 目标图片路径下的所有文件名列表

        self.ps = patch_size

    def __len__(self):
        return len(self.targetImages)

    def __getitem__(self, index):

        ps = self.ps
        index = index % len(self.targetImages)

        inputImagePath = os.path.join(self.inputPath, self.inputImages[index])  # 图片完整路径
        inputImage = Image.open(inputImagePath).convert('RGB')  # 读取图片

        targetImagePath = os.path.join(self.targetPath, self.targetImages[index])
        targetImage = Image.open(targetImagePath).convert('RGB')

        inputImage = ttf.to_tensor(inputImage)  # 将图片转为张量
        targetImage = ttf.to_tensor(targetImage)

        hh, ww = targetImage.shape[1], targetImage.shape[2]  # 图片的高和宽

        rr = random.randint(0, hh-ps)  # 随机数: patch 左下角的坐标 (rr, cc)
        cc = random.randint(0, ww-ps)
        # aug = random.randint(0, 8)  # 随机数,对应对图片进行的操作

        input_ = inputImage[:, rr:rr+ps, cc:cc+ps]  # 裁剪 patch ,输入和目标 patch 要对应相同
        target = targetImage[:, rr:rr+ps, cc:cc+ps]

        return input_, target

评估数据集

在网络训练中,不一定最后一次训练的效果就是最好的。评估数据集是在每一个 epoch 训练结束后对网络训练的性能进行评估,目的在于将最好的一次训练结果保存。

class MyValueDataSet(Dataset):  # 评估数据集
    def __init__(self, inputPathTrain, targetPathTrain, patch_size=128):
        super(MyValueDataSet, self).__init__()

        self.inputPath = inputPathTrain
        self.inputImages = os.listdir(inputPathTrain)  # 输入图片路径下的所有文件名列表

        self.targetPath = targetPathTrain
        self.targetImages = os.listdir(targetPathTrain)  # 目标图片路径下的所有文件名列表

        self.ps = patch_size

    def __len__(self):
        return len(self.targetImages)

    def __getitem__(self, index):

        ps = self.ps
        index = index % len(self.targetImages)

        inputImagePath = os.path.join(self.inputPath, self.inputImages[index])  # 图片完整路径
        inputImage = Image.open(inputImagePath).convert('RGB')  # 读取图片,灰度图

        targetImagePath = os.path.join(self.targetPath, self.targetImages[index])
        targetImage = Image.open(targetImagePath).convert('RGB')

        inputImage = ttf.center_crop(inputImage, (ps, ps))
        targetImage = ttf.center_crop(targetImage, (ps, ps))

        input_ = ttf.to_tensor(inputImage)  # 将图片转为张量
        target = ttf.to_tensor(targetImage)

        return input_, target

测试数据集

测试数据集的目的是将输入有雨进行去雨得到去雨后的结果,注意输入一般是原图大小,不进行裁剪。

class MyTestDataSet(Dataset):  # 测试数据集
    def __init__(self, inputPathTest):
        super(MyTestDataSet, self).__init__()

        self.inputPath = inputPathTest
        self.inputImages = os.listdir(inputPathTest)  # 输入图片路径下的所有文件名列表

    def __len__(self):
        return len(self.inputImages)  # 路径里的图片数量

    def __getitem__(self, index):
        index = index % len(self.inputImages)

        inputImagePath = os.path.join(self.inputPath, self.inputImages[index])  # 图片完整路径
        inputImage = Image.open(inputImagePath).convert('RGB')  # 读取图片

        input_ = ttf.to_tensor(inputImage)  # 将图片转为张量

        return input_

网络模型

以一个 5 层简单卷积神经网络为例子,具体网络自己设定。
NetModel.py

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.inconv = nn.Sequential(  # 输入层网络
            nn.Conv2d(3, 32, 3, 1, 1),
            nn.ReLU(inplace=True)
        )
        self.midconv = nn.Sequential(  # 中间层网络
            nn.Conv2d(3, 32, 3, 1, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(3, 32, 3, 1, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(3, 32, 3, 1, 1),
            nn.ReLU(inplace=True),
        )
        self.outconv = nn.Sequential(  # 输出层网络
            nn.Conv2d(3, 32, 3, 1, 1),
        )
        
    def forward(self, x):

        x = self.inconv(x)
        x = self.midconv(x)
        x = self.outconv(x)
        
        return x

自定义工具包

自定义工具包主要是一个计算峰值信噪比(PSNR)的方法用来对训练进行评估。

utils.py

import torch

def torchPSNR(tar_img, prd_img):
    imdff = torch.clamp(prd_img,0,1) - torch.clamp(tar_img,0,1)
    rmse = (imdff**2).mean().sqrt()
    ps = 20*torch.log10(1/rmse)
    return ps

网络训练和测试

main.py

import sys
import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from tqdm import tqdm, trange  # 进度条
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from torch.autograd import Variable
from torch.optim.lr_scheduler import MultiStepLR
import utils
from NetModel import Net
from MyDataset import *

if __name__ == '__main__':  # 只有在 main 中才能开多线程
	EPOCH = 100  # 训练次数
    BATCH_SIZE = 18  # 每批的训练数量
    LEARNING_RATE = 1e-3  # 学习率
    loss_list = []  # 损失存储数组
	best_psnr = 0  # 训练最好的峰值信噪比
    best_epoch = 0  # 峰值信噪比最好时的 epoch
	
	inputPathTrain = 'E://Rain100H/inputTrain/'  # 训练输入图片路径
    targetPathTrain = 'E://Rain100H/targetTrain/'  # 训练目标图片路径
    inputPathTest = 'E://Rain100H/inputTest/'  # 测试输入图片路径
    resultPathTest = 'E://Rain100H/resultTest/'  # 测试结果图片路径
    targetPathTest = 'E://Rain100H/targetTest/'  # 测试目标图片路径

	myNet = Net()  # 实例化网络
    myNet = myNet.cuda()  # 网络放入GPU中
	criterion = nn.MSELoss().cuda()
	
    optimizer = optim.Adam(myNet.parameters(), lr=LEARNING_RATE)  # 网络参数优化算法

    # 训练数据
    datasetTrain = MyTrainDataSet(inputPathTrain, targetPathTrain)  # 实例化训练数据集类
    # 可迭代数据加载器加载训练数据
    trainLoader = DataLoader(dataset=datasetTrain, batch_size=BATCH_SIZE, shuffle=True, drop_last=False, num_workers=6, pin_memory=True)

    # 评估数据
    datasetValue = MyValueDataSet(inputPathTest, targetPathTest)  # 实例化评估数据集类
    valueLoader = DataLoader(dataset=datasetValue, batch_size=16, shuffle=True, drop_last=False, num_workers=6, pin_memory=True)

	# 测试数据
    datasetTest = MyTestDataSet(inputPathTest)  # 实例化测试数据集类
    # 可迭代数据加载器加载测试数据
    testLoader = DataLoader(dataset=datasetTest, batch_size=1, shuffle=False, drop_last=False, num_workers=6, pin_memory=True)

	# 开始训练
    print('-------------------------------------------------------------------------------------------------------')
    if os.path.exists('./model_best.pth'):  # 判断是否预训练
        myNet.load_state_dict(torch.load('./model_best.pth'))

    for epoch in range(EPOCH):
        myNet.train()  # 指定网络模型训练状态
        iters = tqdm(trainLoader, file=sys.stdout)  # 实例化 tqdm,自定义
        epochLoss = 0  # 每次训练的损失
        timeStart = time.time()  # 每次训练开始时间
        for index, (x, y) in enumerate(iters, 0):

            myNet.zero_grad()  # 模型参数梯度置0
            optimizer.zero_grad()  # 同上等效

            input_train, target = Variable(x).cuda(), Variable(y).cuda()  # 转为可求导变量并放入 GPU
            output_train = myNet(input_train)  # 输入网络,得到相应输出

            loss = criterion(output_train, target) # 计算网络输出与目标输出的损失

            loss.backward()  # 反向传播
            optimizer.step()  # 更新网络参数
            epochLoss += loss.item()  # 累计一次训练的损失

            # 自定义进度条前缀
            iters.set_description('Training !!!  Epoch %d / %d,  Batch Loss %.6f' % (epoch+1, EPOCH, loss.item()))

        # 评估
        myNet.eval()
        psnr_val_rgb = []
        for index, (x, y) in enumerate(valueLoader, 0):
            input_, target_value = x.cuda(), y.cuda()
            with torch.no_grad():
                output_value = myNet(input_)
            for output_value, target_value in zip(output_value, target_value):
                psnr_val_rgb.append(psnr(output_value, target_value))

        psnr_val_rgb = torch.stack(psnr_val_rgb).mean().item()

        if psnr_val_rgb > best_psnr:
            best_psnr = psnr_val_rgb
            best_epoch = epoch
            torch.save(myNet.state_dict(), 'model_best.pth')

        loss_list.append(epochLoss)  # 插入每次训练的损失值
        torch.save(myNet.state_dict(), 'model.pth')  # 每次训练结束保存模型参数
        timeEnd = time.time()  # 每次训练结束时间
        print("------------------------------------------------------------")
        print("Epoch:  {}  Finished,  Time:  {:.4f} s,  Loss:  {:.6f}.".format(epoch+1, timeEnd-timeStart, epochLoss))
        print('-------------------------------------------------------------------------------------------------------')
    print("Training Process Finished ! Best Epoch : {} , Best PSNR : {:.2f}".format(best_epoch, best_psnr))

	# 测试
    print('--------------------------------------------------------------')
    myNet.load_state_dict(torch.load('./model_best.pth'))  # 加载已经训练好的模型参数
    myNet.eval()  # 指定网络模型测试状态

    with torch.no_grad():  # 测试阶段不需要梯度
        timeStart = time.time()  # 测试开始时间
        for index, x in enumerate(tqdm(testLoader, desc='Testing !!! ', file=sys.stdout), 0):
            torch.cuda.empty_cache()  # 释放显存
            input_test = x.cuda()  # 放入GPU
            output_test = myNet(input_test)  # 输入网络,得到输出
            save_image(output_test, resultPathTest + str(index+1).zfill(3) + tail)  # 保存网络输出结果
        timeEnd = time.time()  # 测试结束时间
        print('---------------------------------------------------------')
        print("Testing Process Finished !!! Time: {:.4f} s".format(timeEnd - timeStart))

	# 绘制训练时损失曲线
    plt.figure(1)
    x = range(0, EPOCH)
    plt.xlabel('epoch')
    plt.ylabel('epoch loss')
    plt.plot(x, loss_list, 'r-')
    plt.show()

结语

关于图像恢复特别是图像去雨问题欢迎一起交流学习。

你可能感兴趣的:(pytorch,深度学习,pytorch,深度学习,python,人工智能,神经网络)