SRCNN Image Super-Resolution Reconstruction (tf2)

Contents

  • Preface
  • I. SRCNN
  • II. SRCNN implementation
    • 1. Building the model
    • 2. Generating the training data
    • 3. Training
    • 4. Testing
  • Summary


Preface

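Image super-resolution turns a low-resolution image (for example, one that lost detail by being shrunk and then enlarged) into a high-resolution image. The core of the task is filling in new pixels while the image is reconstructed. SRCNN was the pioneering work that applied deep learning to image reconstruction. Its network structure is very simple, so I decided to reproduce it.
Code: https://github.com/jiantenggei/SRCNN-Keras (includes all resources)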

I. SRCNN

[Figure 1: SRCNN network structure]
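SRCNN's network structure is very simple. The input is a low-resolution image that has first been enlarged to the target size with bicubic interpolation. This image passes through three convolutional layers, which restore it to a high-quality image. In this implementation the convolutions use 'same' padding, so every feature map keeps the same spatial size as the reconstructed image. The network has no fully connected layers; it is made of just three convolutional layers.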
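For reference, the three layers can be written out using the notation of the original SRCNN paper. Here Y is the bicubic-upscaled input, * denotes convolution, and W_i and B_i are the filters and biases of layer i. The filter counts and sizes below are the 9-1-5 configuration that built_model uses for a single-channel (grayscale) input.

```latex
\begin{aligned}
F_1(Y) &= \max\left(0,\; W_1 * Y + B_1\right), && W_1:\ 64 \text{ filters of size } 9 \times 9 \times 1 \\
F_2(Y) &= \max\left(0,\; W_2 * F_1(Y) + B_2\right), && W_2:\ 32 \text{ filters of size } 1 \times 1 \times 64 \\
F(Y)   &= W_3 * F_2(Y) + B_3, && W_3:\ 1 \text{ filter of size } 5 \times 5 \times 32
\end{aligned}
```

The first two layers end in a ReLU; the reconstruction layer is linear.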

II. SRCNN implementation

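First, a brief outline of how SRCNN is trained and tested:
1. Downscale each image and then upscale it back, producing blurry images that serve as training inputs.
2. Use the unprocessed original images as the training labels.
3. Feed the inputs and labels into the network for training.
4. To test the model, feed a blurry image into the trained network and compute the peak signal-to-noise ratio (PSNR) between the generated image and the original sharp image.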
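Steps 1 and 2 can be sketched without OpenCV as shown below. This is a hypothetical helper (make_training_pair is not part of the repository), and it swaps in simpler resampling than the real code. Downscaling uses block averaging and upscaling uses pixel repetition (nearest neighbour), instead of the cv2.INTER_CUBIC interpolation used later in load_train. The cropping to a multiple of scale mirrors the real code.

```python
import numpy as np


def make_training_pair(img: np.ndarray, scale: int = 3):
    """Return (blurry_input, label) for one grayscale image.

    Hypothetical stand-in for the cv2.resize(..., INTER_CUBIC) round trip:
    box-filter downscaling followed by nearest-neighbour upscaling.
    """
    # Crop so that height and width are multiples of `scale`
    h = img.shape[0] - img.shape[0] % scale
    w = img.shape[1] - img.shape[1] % scale
    label = img[:h, :w].astype(np.float32)

    # Downscale: average every scale x scale block into one pixel
    small = label.reshape(h // scale, scale, w // scale, scale).mean(axis=(1, 3))

    # Upscale back to the label size by repeating each pixel
    blurry = np.repeat(np.repeat(small, scale, axis=0), scale, axis=1)
    return blurry, label


rng = np.random.default_rng(0)
hr = rng.integers(0, 256, size=(100, 101)).astype(np.uint8)
x, y = make_training_pair(hr, scale=3)
print(x.shape, y.shape)  # (99, 99) (99, 99)
```

The input and label have the same size and differ only in detail, which is exactly the relationship the network is trained to undo.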

1. Building the model

The code is as follows (model.py, which the training and test scripts import):

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, Activation


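def built_model(input_shape=(33, 33, 1)):
    model = Sequential()
    # Layer 1: patch extraction and representation, 64 filters of 9x9
    model.add(Conv2D(filters=64, kernel_size=9,
                     padding='same', input_shape=input_shape))
    model.add(Activation('relu'))
    # Layer 2: non-linear mapping, 32 filters of 1x1
    model.add(Conv2D(32, 1, padding='same'))
    model.add(Activation('relu'))
    # Layer 3: reconstruction back to the input channel count, 5x5 kernel, no activation
    model.add(Conv2D(input_shape[2], 5, padding='same'))
    return model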

if __name__ == '__main__':
    model = built_model()
    model.summary()

The last layer maps the multi-channel feature maps back to an image with the same number of channels as the input.
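As a sanity check on model.summary(), the number of trainable parameters for input_shape=(33, 33, 1) can be worked out by hand. A Conv2D layer with bias has k·k·C_in·filters weights plus one bias per filter, and the Activation layers have none. The helper conv2d_params is hypothetical and exists only for this calculation.

```python
def conv2d_params(kernel_size: int, in_channels: int, filters: int) -> int:
    """Weights plus one bias per filter, as counted for a Conv2D layer with use_bias=True."""
    return kernel_size * kernel_size * in_channels * filters + filters


layers = [
    conv2d_params(9, 1, 64),   # 9x9 kernel, 1 -> 64 channels
    conv2d_params(1, 64, 32),  # 1x1 kernel, 64 -> 32 channels
    conv2d_params(5, 32, 1),   # 5x5 kernel, 32 -> 1 channel
]
print(layers, sum(layers))  # [5248, 2080, 801] 8129
```

The total does not depend on the spatial size of the input, which is why the same weights can later be loaded into a model built with input_shape=(None, None, 1).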

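2. Generating the training data

1. Downscale each image and then upscale it back, producing blurry images that serve as training inputs.
2. Use the unprocessed original images as the training labels.
The code is as follows (utils.py):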

import os

import cv2
import numpy as np


def load_train(image_size=33, stride=33, scale=3, dirname=os.path.join('dataset', 'train')):
    dir_list = os.listdir(dirname)
    images = [cv2.cvtColor(cv2.imread(os.path.join(dirname,img)),cv2.COLOR_BGR2GRAY) for img in dir_list]
    # ==========================
    # Crop each image so its height and width are divisible by scale
    # ==========================
    images = [img[0:img.shape[0]-np.remainder(img.shape[0],scale),0:img.shape[1]-np.remainder(img.shape[1],scale)] for img in images]

    trains = images.copy()
    labels = images.copy()
    # ========================================
    # Downscale and then upscale the training images to produce blurry inputs
    # ========================================
    trains = [cv2.resize(img, None, fx=1/scale, fy=1/scale, interpolation=cv2.INTER_CUBIC) for img in trains]
    trains = [cv2.resize(img, None, fx=scale/1, fy=scale/1, interpolation=cv2.INTER_CUBIC) for img in trains]

    sub_trains = []
    sub_labels = []
    
    # ========================================
    # Sample patches to build the labels and training data:
    # each image is cut into many image_size x image_size patches, enlarging the sample set
    # ========================================
    for train, label in zip(trains, labels):
        v, h = train.shape
        print(train.shape)
        for x in range(0,v-image_size+1,stride):
            for y in range(0,h-image_size+1,stride):
                sub_train = train[x:x+image_size,y:y+image_size]
                sub_label = label[x:x+image_size,y:y+image_size]
                sub_train = sub_train.reshape(image_size,image_size,1)
                sub_label = sub_label.reshape(image_size,image_size,1)
                sub_trains.append(sub_train)
                sub_labels.append(sub_label)
    # ========================================
    # Convert to numpy arrays (float32 for training)
    # ========================================
    sub_trains = np.array(sub_trains, dtype=np.float32)
    sub_labels = np.array(sub_labels, dtype=np.float32)
    return sub_trains, sub_labels

def load_test(scale=3, dirname=os.path.join('dataset', 'test')):
    # ========================================
    # Test data is generated the same way as the training data.
    # pre_tests keeps the downscaled (low-resolution) images.
    # ========================================
    dir_list = os.listdir(dirname)
    images = [cv2.cvtColor(cv2.imread(os.path.join(dirname,img)),cv2.COLOR_BGR2GRAY) for img in dir_list]
    images = [img[0:img.shape[0]-np.remainder(img.shape[0],scale),0:img.shape[1]-np.remainder(img.shape[1],scale)] for img in images]

    tests = images.copy()
    labels = images.copy()
    
    pre_tests = [cv2.resize(img, None, fx=1/scale, fy=1/scale, interpolation=cv2.INTER_CUBIC) for img in tests]
    tests = [cv2.resize(img, None, fx=scale/1, fy=scale/1, interpolation=cv2.INTER_CUBIC) for img in pre_tests]
    
    pre_tests = [img.reshape(img.shape[0],img.shape[1],1) for img in pre_tests]
    tests = [img.reshape(img.shape[0],img.shape[1],1) for img in tests]
    labels = [img.reshape(img.shape[0],img.shape[1],1) for img in labels]

    return pre_tests, tests, labels

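Note: the sampling step (the nested for loops in load_train) cuts each image into many small patches. One image therefore yields many training samples, which compensates for having only a few training images.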
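The nested loops can also be written with NumPy's sliding_window_view, which is available from NumPy 1.20. This sketch is an assumption, not the repository's code. extract_patches is a hypothetical helper, and extract_patches_loop reproduces the loops above so the two can be compared. Along each axis, an image of size H yields (H - image_size) // stride + 1 patches. A 255 x 255 image with image_size=33 and stride=14 therefore gives 16 x 16 = 256 patches.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def extract_patches(img, image_size=33, stride=14):
    # Every image_size x image_size window, keeping every `stride`-th one along each axis
    windows = sliding_window_view(img, (image_size, image_size))[::stride, ::stride]
    # (n_rows, n_cols, size, size) -> (n_patches, size, size, 1), same order as the loops
    return windows.reshape(-1, image_size, image_size, 1)


def extract_patches_loop(img, image_size=33, stride=14):
    # The nested loops from load_train, kept for comparison
    patches = []
    v, h = img.shape
    for x in range(0, v - image_size + 1, stride):
        for y in range(0, h - image_size + 1, stride):
            patches.append(img[x:x + image_size, y:y + image_size].reshape(image_size, image_size, 1))
    return np.array(patches)


img = np.arange(255 * 255, dtype=np.float32).reshape(255, 255)
p = extract_patches(img)
print(p.shape)  # (256, 33, 33, 1)
```

A smaller stride than image_size makes neighbouring patches overlap, so stride=14 in train() produces far more samples than the default stride=33.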

3. Training

The code is as follows:

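import os

from tensorflow.keras.optimizers import Adam

from model import built_model
from utils import load_train

def train():
    # ==========================
    # input_shape: size of the input patches
    # stride: sampling interval when cutting patches from the original images
    # batch_size, epochs, learning_rate: training hyperparameters
    # ==========================
    input_shape = (33, 33, 1)
    stride = 14
    batch_size = 64
    epochs = 100
    learning_rate = 0.001

    # Define the model
    srcnn_model = built_model(input_shape=input_shape)
    weight_path = os.path.join('model', 'srcnn_weight.hdf5')
    # Resume from existing weights if the file is present
    if os.path.exists(weight_path):
        srcnn_model.load_weights(weight_path)
    srcnn_model.summary()

    # Load the data
    X_train, Y_train = load_train(image_size=input_shape[0], stride=stride)
    print(X_train.shape, Y_train.shape)
    # `lr` is deprecated in TF2; use `learning_rate`
    optimizer = Adam(learning_rate=learning_rate)
    # 'accuracy' is meaningless for pixel regression, so track MSE instead
    srcnn_model.compile(optimizer=optimizer, loss='mean_squared_error', metrics=['mse'])
    srcnn_model.fit(X_train, Y_train, epochs=epochs, batch_size=batch_size)
    srcnn_model.save(os.path.join('model', 'srcnn.h5'))
    # Also save the weights to the file that test() loads
    srcnn_model.save_weights(weight_path)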

if __name__ == '__main__':
    train()

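Mean squared error is used as the loss because the network output and the label are images of the same size that differ only in sharpness, so they can be compared pixel by pixel.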
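Written out, the training objective and the test metric are as follows. X_i is the i-th label patch, Y_i the matching blurry input, F(·; Θ) the network with parameters Θ, n the number of patches, and P the number of pixels per patch. The images are assumed to be 8-bit, so the peak value is 255. The notation is chosen here for illustration.

```latex
L(\Theta) = \frac{1}{n}\sum_{i=1}^{n} \frac{1}{P}\left\lVert F(Y_i;\Theta) - X_i \right\rVert_2^2,
\qquad
\mathrm{PSNR} = 10 \log_{10} \frac{255^2}{\mathrm{MSE}}
```

PSNR is a decreasing function of MSE, so lowering the MSE loss directly raises PSNR. This is why MSE is a natural loss when PSNR is the evaluation metric.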

4. Testing

The code is as follows:

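import os

import cv2
import numpy as np

from model import built_model
from utils import load_test, psnr


def test():
    # The network is fully convolutional, so (None, None, 1) accepts whole images of any size
    input_shape = (None, None, 1)
    scale = 3
    srcnn_model = built_model(input_shape=input_shape)
    srcnn_model.load_weights(os.path.join('model', 'srcnn_weight.hdf5'))

    X_pre_test, X_test, Y_test = load_test(scale=scale)

    predicted_list = []

    for img in X_test:
        img = img.reshape(1, img.shape[0], img.shape[1], 1).astype(np.float32)
        predicted = srcnn_model.predict(img)
        predicted_list.append(predicted.reshape(predicted.shape[1], predicted.shape[2], 1))
    n_img = len(predicted_list)
    dirname = './result'
    os.makedirs(dirname, exist_ok=True)
    for i in range(n_img):
        imgname = 'image{:02}'.format(i)
        # Clip to the valid pixel range and convert to 8-bit before saving
        predicted_img = np.clip(predicted_list[i], 0, 255).astype(np.uint8)
        cv2.imwrite(os.path.join(dirname, imgname + '_original.bmp'), X_pre_test[i])  # downscaled low-res image
        cv2.imwrite(os.path.join(dirname, imgname + '_input.bmp'), X_test[i])         # blurry network input
        cv2.imwrite(os.path.join(dirname, imgname + '_answer.bmp'), Y_test[i])        # ground truth
        cv2.imwrite(os.path.join(dirname, imgname + '_predicted.bmp'), predicted_img)
        # PSNR against the ground truth: bicubic baseline vs. SRCNN output
        bicubic_psnr = psnr(Y_test[i].astype(np.float64), X_test[i].astype(np.float64))
        srcnn_psnr = psnr(Y_test[i].astype(np.float64), predicted_img.astype(np.float64))
        print(imgname + "_psnr: bicubic", bicubic_psnr, "srcnn", srcnn_psnr)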

if __name__ == '__main__':
    test()

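X_test holds the blurry (bicubic-upscaled) inputs, and Y_test holds the original images. Reconstruction quality should be measured by the PSNR between the prediction and Y_test. The PSNR between X_test and Y_test is the bicubic baseline that the network is expected to improve on.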
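The post imports psnr from utils but does not show it. A minimal NumPy sketch of the standard definition for 8-bit images, 10·log10(255²/MSE), might look like the code below. The function body and the max_val parameter are assumptions, not necessarily what the repository's utils.psnr does.

```python
import numpy as np


def psnr(reference, estimate, max_val=255.0):
    """Peak signal-to-noise ratio in dB between two images of the same shape."""
    # Cast to float so that uint8 subtraction cannot wrap around
    reference = np.asarray(reference, dtype=np.float64)
    estimate = np.asarray(estimate, dtype=np.float64)
    mse = np.mean((reference - estimate) ** 2)
    if mse == 0:
        return float('inf')  # identical images
    return 10.0 * np.log10(max_val ** 2 / mse)


a = np.zeros((4, 4), dtype=np.uint8)
b = np.ones((4, 4), dtype=np.uint8)
print(round(psnr(a, b), 2))  # 48.13
```

In test(), psnr(Y_test[i], predicted) scores the SRCNN output, while psnr(Y_test[i], X_test[i]) gives the bicubic baseline. PSNR is symmetric in its two arguments, but passing the ground truth first keeps the intent clear.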

Summary

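Different networks serve different purposes; what mostly changes between them is how the data is fed in and how the loss is computed.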
