飞桨高层API实现图像去雨

高层API实现图像去雨

原文链接:https://aistudio.baidu.com/aistudio/projectdetail/2340714

飞桨高层API实现图像去雨_第1张图片

1 简要介绍

去雨深度模型越来越复杂多样,难以分析不同网络模型的作用。通过参考原始论文[Progressive Image Deraining Networks: A Better and Simpler Baseline],我选择了一个简单的PRN模型对下雨图像进行深度学习和处理。

2 环境设置

本教程基于Paddle 2.1 编写,如果您的环境不是本版本,请先参考官网安装 Paddle 2.1 。

import os
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

import paddle
import paddle.nn as nn
import paddle.optimizer as optim
import paddle.nn.functional as F
from paddle.vision import transforms
from paddle.io import Dataset,DataLoader
paddle.__version__
'2.1.2'

3 数据集

3.1 数据集介绍

  • 该数据集包含 1,000 张干净的图像。每个干净的图像用于生成 14 个具有不同条纹方向和强度的雨天图像。

  • 随机选择 900 张图像进行“训练”,其余图像用于“测试”。

  • 图像是使用 PhotoShop 生成的雨天图像: http://www.photoshopessentials.com/photo-effects/rain/

  • 数据集结构

rainy_image_dataset/ # 图像去雨的根目录
|--training/ # 训练集的文件夹
|  |--ground_truth/ # 没下雨的图片用作标签
|  |  |--1.jpg
|  |  |--2.jpg
|  |  |--...
|  |--rainy_image/ # 生成的14个具有不同条纹方向的强度的雨天图像
|  |  |--1_1.jpg
|  |  |--1_2.jpg
|  |  |--...
|  |  |--1_14.jpg
|  |  |--...
|--testing/ # 测试集的文件夹
|  |--ground_truth/
|  |...
|  |--rainy_image/
|  |...
|

3.2 数据可视化

!unzip data/data107170/rainy_image_dataset.zip
from PIL import Image
import matplotlib.pyplot as plt
%matplotlib inline

plt.figure(figsize=(10, 10))
img_truth=Image.open('rainy_image_dataset/training/ground_truth/9.jpg')
plt.subplot(1,2, 1)
plt.title('truth')
plt.imshow(img_truth)

img_rainy = Image.open('rainy_image_dataset/training/rainy_image/9_8.jpg')
plt.subplot(1,2, 2)
plt.title('tainy')
plt.imshow(img_rainy)

plt.show()

飞桨高层API实现图像去雨_第2张图片

3.3 自定义数据集

使用飞桨框架高层API的 paddle.io.Dataset 自定义数据集类,具体可以参考官网文档 自定义数据集。自定义的数据集要重写 __init__ ,并实现 __getitem____len__。另外还进行了以下操作:

  • 对训练集图像进行中心裁剪64*64大小作为pacth
  • 将训练集读入的数据归一化到[0, 1]之间并变为张量类型
  • 将测试集转变为张量类型
class MyTrainDataset(Dataset): # 继承 Dataset 类
    def __init__(self, input_path, label_path):
        self.input_path = input_path # 受污染图片所在文件夹
        self.input_path_image = os.listdir(input_path) # 文件夹下的所有图片对象

        self.label_path = label_path # 干净图片所在文件夹
        # self.label_path_image = os.listdir(label_path)
		
		# 定义要对图片进行的变换
        self.transforms = paddle.vision.transforms.Compose([
       		 # 中心裁剪64*64大小作为pacth
            paddle.vision.transforms.CenterCrop([64, 64]), 
            
            # 将读入的数据归一化[0, 1]之间并变为张量类型
            paddle.vision.transforms.ToTensor(), 
            ])

    def __len__(self):
        return len(self.input_path_image) # 返回长度
 
    def __getitem__(self, index):
    	# index 索引对应的受污染图片完整路径
        input_image_path = os.path.join(self.input_path, self.input_path_image[index])
        # 利用PIL.Image 读入图片数据并转换通道结构
        input_image = Image.open(input_image_path).convert('RGB')
        str='.'
        temp=os.path.basename(input_image_path).split(str)
        temp[0]=temp[0].split('_')[0]
        label_path_image=str.join(temp)
        label_image_path = os.path.join(self.label_path, label_path_image)
        label_image = Image.open(label_image_path).convert('RGB')

		# 对读入的图片进行固定的变换
        input = self.transforms(input_image)
        label = self.transforms(label_image)

        return  (input, label) # 返回适合在网络中训练的图片数据

class MyTestDataset(Dataset):
    def __init__(self, input_path): 
        super(MyTestDataset, self).__init__()
        self.input_path = input_path
        self.input_files = os.listdir(self.input_path)
        self.transforms = transforms.Compose([
            # transforms.CenterCrop([128, 128]),# 这行没有必要
            transforms.ToTensor(),
            ])

    def __len__(self):
        return len(self.input_files)

    def __getitem__(self, index):
        
        input_image_path = os.path.join(self.input_path, self.input_files[index])
        input_image = Image.open(input_image_path).convert('RGB')
        input = self.transforms(input_image)

        return input

3.4 数据定义

对训练集和测试集分别进行定义。

# 训练图像的路径
input_path = 'rainy_image_dataset/training/rainy_image'
label_path = 'rainy_image_dataset/training/ground_truth'
dataset_train = MyTrainDataset(input_path, label_path)

# 测试图像的路径
input_path_test = 'rainy_image_dataset/testing/rainy_image'
label_path_test = 'rainy_image_dataset/testing/ground_truth'
dataset_val = MyTrainDataset(input_path_test,label_path_test)
dataset_test = MyTestDataset(input_path_test)

4. 模型组网

PRN:Progressive Residual Network
将一个ResNet重复在 T T T个阶段上展开,网络参数在不同阶段重复使用

飞桨高层API实现图像去雨_第3张图片

每个阶段的网络具体包含以下部分:

  • f i n f_{in} fin:Conv+ReLU,接受上个阶段输出的图像和原始雨图的拼接作为输入
  • f r e s f_{res} fres:5个ResBlock,提取深度特征表示
  • f o u t f_{out} fout:Conv,输出去雨结果

飞桨高层API实现图像去雨_第4张图片

每个阶段 T T T的推断过程用以下公式描述:

x t − 0.5 = f i n ( x t − 1 , y ) x^{t-0.5}=f_{in}(x^{t-1},y) xt0.5=fin(xt1,y)

x t = f o u t ( f r e s ( x t − 0.5 ) ) x^{t}=f_{out}(f_{res}(x^{t-0.5})) xt=fout(fres(xt0.5))

class Net(nn.Layer):
    def __init__(self):
        super(Net, self).__init__()
        self.conv0 = nn.Sequential(
            nn.Conv2D(6, 32, 3, 1, 1),
            nn.ReLU()
        )
        self.res_conv1 = nn.Sequential(
            nn.Conv2D(32, 32, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2D(32, 32, 3, 1, 1),
            nn.ReLU()
        )
        self.res_conv2 = nn.Sequential(
            nn.Conv2D(32, 32, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2D(32, 32, 3, 1, 1),
            nn.ReLU()
        )
        self.res_conv3 = nn.Sequential(
            nn.Conv2D(32, 32, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2D(32, 32, 3, 1, 1),
            nn.ReLU()
        )
        self.res_conv4 = nn.Sequential(
            nn.Conv2D(32, 32, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2D(32, 32, 3, 1, 1),
            nn.ReLU()
        )
        self.res_conv5 = nn.Sequential(
            nn.Conv2D(32, 32, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2D(32, 32, 3, 1, 1),
            nn.ReLU()
        )
        self.conv = nn.Sequential(
            nn.Conv2D(32, 3, 3, 1, 1),
        )
    def forward(self, input):
        
        x = input
        for i in range(6):# 迭代次数,不改变网络参数量
            
            x = paddle.concat([input, x], 1)
            x = self.conv0(x)
            x = F.relu(self.res_conv1(x) + x)
            x = F.relu(self.res_conv2(x) + x)
            x = F.relu(self.res_conv3(x) + x)
            x = F.relu(self.res_conv4(x) + x)
            x = F.relu(self.res_conv5(x) + x)
            x = self.conv(x)
            x = x + input
            
        return x

model = paddle.Model(Net())
model.summary((-1,3,64,64))
    ---------------------------------------------------------------------------
     Layer (type)       Input Shape          Output Shape         Param #    
    ===========================================================================
       Conv2D-1       [[1, 6, 64, 64]]     [1, 32, 64, 64]         1,760     
        ReLU-1       [[1, 32, 64, 64]]     [1, 32, 64, 64]           0       
       Conv2D-2      [[1, 32, 64, 64]]     [1, 32, 64, 64]         9,248     
        ReLU-2       [[1, 32, 64, 64]]     [1, 32, 64, 64]           0       
       Conv2D-3      [[1, 32, 64, 64]]     [1, 32, 64, 64]         9,248     
        ReLU-3       [[1, 32, 64, 64]]     [1, 32, 64, 64]           0       
       Conv2D-4      [[1, 32, 64, 64]]     [1, 32, 64, 64]         9,248     
        ReLU-4       [[1, 32, 64, 64]]     [1, 32, 64, 64]           0       
       Conv2D-5      [[1, 32, 64, 64]]     [1, 32, 64, 64]         9,248     
        ReLU-5       [[1, 32, 64, 64]]     [1, 32, 64, 64]           0       
       Conv2D-6      [[1, 32, 64, 64]]     [1, 32, 64, 64]         9,248     
        ReLU-6       [[1, 32, 64, 64]]     [1, 32, 64, 64]           0       
       Conv2D-7      [[1, 32, 64, 64]]     [1, 32, 64, 64]         9,248     
        ReLU-7       [[1, 32, 64, 64]]     [1, 32, 64, 64]           0       
       Conv2D-8      [[1, 32, 64, 64]]     [1, 32, 64, 64]         9,248     
        ReLU-8       [[1, 32, 64, 64]]     [1, 32, 64, 64]           0       
       Conv2D-9      [[1, 32, 64, 64]]     [1, 32, 64, 64]         9,248     
        ReLU-9       [[1, 32, 64, 64]]     [1, 32, 64, 64]           0       
       Conv2D-10     [[1, 32, 64, 64]]     [1, 32, 64, 64]         9,248     
        ReLU-10      [[1, 32, 64, 64]]     [1, 32, 64, 64]           0       
       Conv2D-11     [[1, 32, 64, 64]]     [1, 32, 64, 64]         9,248     
        ReLU-11      [[1, 32, 64, 64]]     [1, 32, 64, 64]           0       
       Conv2D-12     [[1, 32, 64, 64]]      [1, 3, 64, 64]          867      
    ===========================================================================
    Total params: 95,107
    Trainable params: 95,107
    Non-trainable params: 0
    ---------------------------------------------------------------------------
    Input size (MB): 0.05
    Forward/backward pass size (MB): 22.09
    Params size (MB): 0.36
    Estimated Total Size (MB): 22.50
    ---------------------------------------------------------------------------

    {'total_params': 95107, 'trainable_params': 95107}

5. 模型训练

去雨模型中很多都使用了混合损失函数(如MSE+SSIM)和对抗损失。在论文中作者指出,这些损失增加了调整超参的负担。由于渐进式网络结构的存在,单独的MSE或者负SSIM已经足够训练PRN和PReNet达到理想效果。本项目使用的是MSE损失和PRN实现的去雨,先埋一个坑下次我将使用PReNet和负SSIM来实现图片去雨。

大概训练100epoch就可以达到本项目的效果,直接运行下面的代码,会加载我训练好的模型继续训练。你也可以通过注释model.load("output_prn/final")从0开始训练。

learning_rate = 1e-3 
batch_size = 50 # 分批训练数据,每批数据量
epoch = 5 # 训练次数

model = paddle.Model(Net())
model.load("output_prn/final")
model.prepare(optimizer=paddle.optimizer.Adam(learning_rate = learning_rate,parameters=model.parameters()),
                loss = paddle.nn.MSELoss(),
                metrics = paddle.metric.Accuracy())

model.fit(dataset_train,
        epochs=epoch,
        batch_size=batch_size,
        save_dir='output_prn/',
        save_freq = 5,
        num_workers = 6,
        verbose=1)

6 模型测试

直接通过调用predict函数就可以实现对测试集的批量预测,本项目对预测效果进行了展示,从效果可以看出,图片的雨水基本去除,达到了可使用标准。

predict_result = model.predict(dataset_test)
img_test=dataset_test[10]
img_pre=np.array(predict_result[0][10][0])

plt.figure(figsize=(10, 10))
img_test=np.transpose(img_test, (1,2,0))
plt.subplot(1,2, 1)
plt.title('test')
plt.imshow(img_test)

img_pre=np.transpose(img_pre, (1,2,0))
plt.subplot(1,2, 2)
plt.title('predict')
plt.imshow(img_pre)

飞桨高层API实现图像去雨_第5张图片

7 个人介绍

CSDN地址:https://blog.csdn.net/weixin_43267897?spm=1001.2101.3001.5343

Github地址:https://github.com/KHB1698

我在AI Studio上获得黄金等级,点亮5个徽章,来互关呀~ https://aistudio.baidu.com/aistudio/personalcenter/thirdview/791590

你可能感兴趣的:(深度学习,paddle飞桨,深度学习,pytorch,神经网络)