使用第一篇参考文献中公开的数据集 Rain12600 和 Rain1400(下载链接)。其中干净图像训练集 900 张、测试集 100 张,每张干净图像对应 14 张不同的雨图,因此训练集共 12600 对,测试集共 1400 对。为便于一一对应,预先将每张干净图像复制为 14 张,并按顺序编号:训练集的输入与标签均为 00001–12600,测试集的输入与标签均为 0001–1400。由于文件名采用补零编号,按文件名排序后输入与标签即可逐一对应。
DataTrain.py
import os
import torchvision
from torch.utils.data import Dataset
from PIL import Image
class MyTrainDataset(Dataset):
    """Paired (rainy input, clean label) training dataset.

    Input and label images are matched by position after sorting the file
    names of each directory, so both directories must contain files whose
    sorted orders correspond one-to-one (e.g. zero-padded 00001 ... 12600).
    Both images of a pair are center-cropped to the same size and converted
    to tensors with ``ToTensor``.

    Args:
        input_path: directory holding the rainy input images.
        label_path: directory holding the clean label images.
        crop_size: size passed to ``CenterCrop``; defaults to 64x64.

    Raises:
        ValueError: if the two directories hold a different number of files.
    """

    def __init__(self, input_path, label_path, crop_size=(64, 64)):
        self.input_path = input_path
        # os.listdir() returns entries in arbitrary order; sort both lists so
        # that index i refers to the same pair in each directory.
        self.input_files = sorted(os.listdir(input_path))
        self.label_path = label_path
        self.label_files = sorted(os.listdir(label_path))
        if len(self.input_files) != len(self.label_files):
            raise ValueError(
                "input/label count mismatch: {} vs {}".format(
                    len(self.input_files), len(self.label_files)))
        self.transforms = torchvision.transforms.Compose([
            torchvision.transforms.CenterCrop(list(crop_size)),
            torchvision.transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.input_files)

    @staticmethod
    def _load_rgb(path):
        """Open an image file, convert it to RGB and close the file handle."""
        # convert() returns a new, fully loaded image, so the source file can
        # be closed immediately instead of waiting for garbage collection.
        with Image.open(path) as image:
            return image.convert('RGB')

    def __getitem__(self, index):
        """Return the ``(input, label)`` tensor pair at ``index``."""
        input_image = self._load_rgb(
            os.path.join(self.input_path, self.input_files[index]))
        label_image = self._load_rgb(
            os.path.join(self.label_path, self.label_files[index]))
        input_tensor = self.transforms(input_image)
        label_tensor = self.transforms(label_image)
        return (input_tensor, label_tensor)
DataTest.py
import os
from torchvision import transforms
from torch.utils.data import Dataset
from PIL import Image
class MyTestDataset(Dataset):
    """Unlabeled test dataset yielding one RGB image tensor per file.

    Files are returned in sorted file-name order. Callers that number their
    outputs by iteration count (e.g. result0001.jpg) therefore get a stable
    mapping back to the inputs. Images are not cropped.

    Args:
        input_path: directory holding the rainy test images.
    """

    def __init__(self, input_path):
        super(MyTestDataset, self).__init__()
        self.input_path = input_path
        # os.listdir() order is arbitrary; sort for a deterministic order.
        self.input_files = sorted(os.listdir(self.input_path))
        self.transforms = transforms.Compose([
            # No CenterCrop here: the whole test image is processed.
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.input_files)

    def __getitem__(self, index):
        """Return the image at ``index`` as a tensor."""
        input_image_path = os.path.join(self.input_path, self.input_files[index])
        # Close the file handle right away; convert() returns a loaded copy.
        with Image.open(input_image_path) as image:
            input_image = image.convert('RGB')
        return self.transforms(input_image)
NetModel.py
基于PRN网络做一个简单示意,网络模型可以根据需要改变。
import torch
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
    """Simple PRN-style progressive residual deraining network.

    One stage (conv0 -> five residual blocks -> conv) is applied
    ``recurrent_iter`` times with shared weights. At each stage the current
    estimate is concatenated with the original input along dim 1
    (3 + 3 = 6 channels, matching conv0's 6 input channels). The stage output
    (3 channels) is added back onto the original input, so each stage predicts
    a residual. The number of stages does not change the parameter count.

    Args:
        recurrent_iter: number of unrolled stages (default 6, the value that
            used to be hard-coded).
    """

    def __init__(self, recurrent_iter=6):
        super(Net, self).__init__()
        self.recurrent_iter = recurrent_iter
        self.conv0 = nn.Sequential(
            nn.Conv2d(6, 32, 3, 1, 1),
            nn.ReLU(),
        )
        # Attribute names, creation order and Sequential layouts are identical
        # to the previous hand-written version, so existing state_dict
        # checkpoints (e.g. model.pth) keep loading.
        self.res_conv1 = self._res_block()
        self.res_conv2 = self._res_block()
        self.res_conv3 = self._res_block()
        self.res_conv4 = self._res_block()
        self.res_conv5 = self._res_block()
        self.conv = nn.Sequential(
            nn.Conv2d(32, 3, 3, 1, 1),
        )

    @staticmethod
    def _res_block(channels=32):
        """Build one residual branch: two (3x3 conv, ReLU) pairs."""
        return nn.Sequential(
            nn.Conv2d(channels, channels, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(channels, channels, 3, 1, 1),
            nn.ReLU(),
        )

    def forward(self, input):
        """Run ``recurrent_iter`` stages and return the final estimate.

        ``input`` must have 3 channels in dim 1. The output has the same shape.
        """
        x = input
        res_blocks = (self.res_conv1, self.res_conv2, self.res_conv3,
                      self.res_conv4, self.res_conv5)
        for _ in range(self.recurrent_iter):
            x = torch.cat((input, x), 1)
            x = self.conv0(x)
            for block in res_blocks:
                x = F.relu(block(x) + x)  # residual connection
            x = self.conv(x)
            x = x + input
        return x
Train.py
import torch
import torch.optim as optim
from NetModel import Net
import torch.nn as nn
import os
from DataTrain import MyTrainDataset
from torch.utils.data import DataLoader
from torch.autograd import Variable
import matplotlib.pyplot as plt
## Optional: let matplotlib render Chinese characters in figures
# plt.rcParams['font.sans-serif'] = ['SimSun']
# plt.rcParams['axes.unicode_minus'] = False

# Paths of the training images
input_path = 'F://imagePreprocess/train/input/'
label_path = 'F://imagePreprocess/train/label/'
model_path = './model.pth'  # checkpoint file: loaded to resume, saved during training

net = Net().cuda()
learning_rate = 1e-3
batch_size = 50  # samples per mini-batch
epoch = 100  # number of training epochs
save_every = 9  # save a checkpoint every this many batches (and at epoch end)
# Loss_list = []  # uncomment together with the plotting code below to draw the loss curve

optimizer = optim.Adam(net.parameters(), lr=learning_rate)
loss_f = nn.MSELoss()
net.train()

if os.path.exists(model_path):  # resume if a previous run left a checkpoint
    print("继续训练!")
    net.load_state_dict(torch.load(model_path))
else:
    print("从头训练!")

# The dataset and loader do not change between epochs, so build them once.
dataset_train = MyTrainDataset(input_path, label_path)
trainloader = DataLoader(dataset_train, batch_size=batch_size, shuffle=True)
num_batches = len(trainloader)  # previously hard-coded as 252 (= 12600 / 50)

for i in range(epoch):
    for j, (x, y) in enumerate(trainloader):
        inputs = x.cuda()
        label = y.cuda()
        # The optimizer owns all net parameters, so one reset per step suffices.
        optimizer.zero_grad()
        output = net(inputs)
        loss = loss_f(output, label)
        loss.backward()  # backpropagation
        optimizer.step()
        print("已完成第{}次训练的{:.3f}%,目前损失值为{:.6f}。".format(
            i + 1, (j + 1) / num_batches * 100, loss.item()))
        # Loss_list.append(loss.item())  # .item() avoids keeping autograd graphs alive
        if j % save_every == 0:
            torch.save(net.state_dict(), model_path)
    # Also save at the end of every epoch; otherwise the batches after the last
    # multiple of save_every (including the very end of training) would be lost.
    torch.save(net.state_dict(), model_path)

# plt.figure(dpi=500)
# plt.plot(range(len(Loss_list)), Loss_list, 'r-')
# plt.ylabel('当前损失/1')
# plt.xlabel('批训练次数/次数')
# plt.savefig('F://loss.jpg')
# plt.show()
Test.py
import torch
from NetModel import Net
from DataTest import MyTestDataset
from torch.utils.data import DataLoader
from torchvision.utils import save_image
# Path of the test images
input_path = 'F://imagePreprocess/test/input/'
net = Net().cuda()
net.load_state_dict(torch.load('./model.pth'))  # load the trained parameters
net.eval()

dataset = MyTestDataset(input_path)
dataloader = DataLoader(dataset)  # default batch_size=1: one image per step
total = len(dataset)  # previously hard-coded as 1400

# Inference only: disable autograd for the whole loop.
with torch.no_grad():
    for cnt, inputs in enumerate(dataloader, start=1):
        inputs = inputs.cuda()
        print('finished:{:.2f}%'.format(cnt * 100 / total))
        output_image = net(inputs)  # output is a tensor
        # NOTE(review): there is no separator after 'result', so files are
        # written as .../test/result0001.jpg. If a result/ sub-directory was
        # intended, add '/' (and create the directory first).
        save_image(output_image,
                   'F://imagePreprocess/test/result' + str(cnt).zfill(4) + '.jpg')