图像去雨:超详细手把手写 pytorch 实现代码(带注释)

引导

  • 数据集准备
  • 训练数据集代码
  • 测试数据集代码
  • 网络模型代码
  • 训练代码
  • 测试代码
  • 参考文献
  • 其他

数据集准备

使用来自第一个参考文献的公开数据集Rain12600和Rain1400,下载链接。其中训练图像900张,测试图像100张,分别有14张不同的雨图,因此训练集共12600对,测试集共1400对。为方便理解提前对干净图像各自复制14张,并按照顺序训练集从00001-12600互相对应,测试集从0001-1400互相对应。
图像去雨:超详细手把手写 pytorch 实现代码(带注释)_第1张图片

训练数据集代码

DataTrain.py

import os
import torchvision
from torch.utils.data import  Dataset
from PIL import Image

class MyTrainDataset(Dataset):
    def __init__(self, input_path, label_path):
        self.input_path = input_path
        self.input_files = os.listdir(input_path)

        self.label_path = label_path
        self.label_files = os.listdir(label_path)

        self.transforms = torchvision.transforms.Compose([
            torchvision.transforms.CenterCrop([64, 64]),
            torchvision.transforms.ToTensor(),
            ])

    def __len__(self):
        return len(self.input_files)
    def __getitem__(self, index):
        input_image_path = os.path.join(self.input_path, self.input_files[index])
        input_image = Image.open(input_image_path).convert('RGB')

        label_image_path = os.path.join(self.label_path, self.label_files[index])
        label_image = Image.open(label_image_path).convert('RGB')

        input = self.transforms(input_image)
        label = self.transforms(label_image)

        return  (input, label)

测试数据集代码

DataTest.py

import os
from torchvision import transforms
from torch.utils.data import Dataset
from PIL import Image

class MyTestDataset(Dataset):
    def __init__(self, input_path): 
        super(MyTestDataset, self).__init__()
        self.input_path = input_path
        self.input_files = os.listdir(self.input_path)
        self.transforms = transforms.Compose([
            # transforms.CenterCrop([128, 128]),# 这行没有必要
            transforms.ToTensor(),
            ])

    def __len__(self):
        return len(self.input_files)

    def __getitem__(self, index):
        
        input_image_path = os.path.join(self.input_path, self.input_files[index])
        input_image = Image.open(input_image_path).convert('RGB')
        input = self.transforms(input_image)

        return input

网络模型代码

NetModel.py
基于PRN网络做一个简单示意,网络模型可以根据需要改变。

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv0 = nn.Sequential(
            nn.Conv2d(6, 32, 3, 1, 1),
            nn.ReLU()
        )
        self.res_conv1 = nn.Sequential(
            nn.Conv2d(32, 32, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(32, 32, 3, 1, 1),
            nn.ReLU()
        )
        self.res_conv2 = nn.Sequential(
            nn.Conv2d(32, 32, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(32, 32, 3, 1, 1),
            nn.ReLU()
        )
        self.res_conv3 = nn.Sequential(
            nn.Conv2d(32, 32, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(32, 32, 3, 1, 1),
            nn.ReLU()
        )
        self.res_conv4 = nn.Sequential(
            nn.Conv2d(32, 32, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(32, 32, 3, 1, 1),
            nn.ReLU()
        )
        self.res_conv5 = nn.Sequential(
            nn.Conv2d(32, 32, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(32, 32, 3, 1, 1),
            nn.ReLU()
        )
        self.conv = nn.Sequential(
            nn.Conv2d(32, 3, 3, 1, 1),
        )
    def forward(self, input):
        
        x = input
        for i in range(6):# 迭代次数,不改变网络参数量
            
            x = torch.cat((input, x), 1)
            x = self.conv0(x)
            x = F.relu(self.res_conv1(x) + x)
            x = F.relu(self.res_conv2(x) + x)
            x = F.relu(self.res_conv3(x) + x)
            x = F.relu(self.res_conv4(x) + x)
            x = F.relu(self.res_conv5(x) + x)
            x = self.conv(x)
            x = x + input
            
        return x

训练代码

Train.py

import torch
import torch.optim as optim
from NetModel import Net
import torch.nn as nn
import os
from DataTrain import MyTrainDataset
from torch.utils.data import DataLoader
from torch.autograd import Variable
import matplotlib.pyplot as plt
## matplotlib显示图片中显示汉字
# plt.rcParams['font.sans-serif'] = ['SimSun'] 
# plt.rcParams['axes.unicode_minus'] = False

# 训练图像的路径
input_path = 'F://imagePreprocess/train/input/'
label_path = 'F://imagePreprocess/train/label/'
net = Net().cuda()

learning_rate = 1e-3
batch_size = 50# 分批训练数据,每批数据量
epoch = 100 # 训练次数
# Loss_list = [] # 简单的显示损失曲线列表,反注释后训练完显示曲线

optimizer = optim.Adam(net.parameters(), lr=learning_rate)
loss_f = nn.MSELoss()
net.train()

if os.path.exists('./model.pth'):# 判断模型有没有提前训练过
    print("继续训练!")
    net.load_state_dict(torch.load('./model.pth'))# 加载训练过的模型
else:
    print("从头训练!")

for i in range(epoch):
    dataset_train = MyTrainDataset(input_path, label_path)
    trainloader = DataLoader(dataset_train, batch_size=batch_size,shuffle=True)

    for j, (x, y) in enumerate(trainloader):# 加载训练数据
        input = Variable(x).cuda()
        label = Variable(y).cuda()

        net.zero_grad()
        optimizer.zero_grad()

        output = net(input)
        loss = loss_f(output, label)

        optimizer.zero_grad()
        loss.backward() # 反向传播
        optimizer.step()

        print("已完成第{}次训练的{:.3f}%,目前损失值为{:.6f}。".format(i+1, ((j+1)/252)*100, loss))

        # Loss_list.append(loss)

        if j%9 == 0:
            torch.save(net.state_dict(), 'model.pth') # 保存训练模型

# plt.figure(dpi=500)
# x = range(0, 2520*2)
# y = Loss_list
# plt.plot(x, y, 'r-')
# plt.ylabel('当前损失/1')
# plt.xlabel('批训练次数/次数')
# plt.savefig('F://loss.jpg')
# plt.show()

测试代码

Test.py

import torch
from NetModel import Net
from DataTest import MyTestDataset
from torch.utils.data import DataLoader
from torchvision.utils import save_image

# 测试图像的路径
input_path = 'F://imagePreprocess/test/input/'

net = Net().cuda()
net.load_state_dict(torch.load('./model.pth')) # 加载训练好的模型参数
net.eval()

cnt = 0

dataloader = DataLoader(MyTestDataset(input_path))
for input in dataloader:
    cnt += 1
    input = input.cuda()

    print('finished:{:.2f}%'.format(cnt*100/1400))

    with torch.no_grad():
        output_image = net(input) # 输出的是张量
        save_image(output_image, 'F://imagePreprocess/test/result'+str(cnt).zfill(4)+'.jpg') # 直接保存张量图片,自动转换

参考文献

  1. Xueyang Fu, Jiabin Huang, Delu Zeng, et al. Removing Rain From Single Images via a Deep Detail Network[C]. Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2017:3855-3863.
  2. Dongwei Ren, Wangmeng Zuo, Qinghua Hu, et al. Progressive Image Deraining Networks: A Better and Simpler Baseline[C]. Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2019:3937-3946.

其他

  1. 具体数据集的制作解释参考制作输入和标签都是图片的数据集
  2. 直接将代码放在一个工程路径下会有导入不成功问题,同上参考

你可能感兴趣的:(笔记,深度学习,python,opencv,人工智能,深度学习,计算机视觉)