TensorFlow 2.0: Implementing SRCNN in Ten Lines of Code

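Super-resolution (SR) means reconstructing a high-resolution image from an observed low-resolution image. It has important applications in surveillance, satellite imagery, medical imaging and other fields.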

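SRCNN (Image Super-Resolution Using Deep Convolutional Networks) is the pioneering work that applied deep learning to super-resolution reconstruction. Its network is very simple: it uses only three convolutional layers, as shown in the figure below.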


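SRCNN works in three stages:

1. It upscales the low-resolution image to the target size with bicubic interpolation.
2. It fits a nonlinear mapping with a three-layer convolutional network.
3. It outputs the high-resolution result.

In the paper, the authors interpret the three convolutional layers as three steps: patch extraction and representation, nonlinear mapping, and reconstruction.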
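The three steps can be made concrete with a small NumPy forward pass. This is a sketch under stated assumptions, not the paper's implementation:

- It uses random weights and does no training.
- It uses the 9-1-5 kernel sizes with 64 and 32 filters.
- It uses 'same' zero padding, matching the Keras model in this post.
- The input is assumed to be already upscaled; a random 16×16×3 array stands in for the bicubic result.

The names `conv2d_same` and `srcnn_forward` are hypothetical helpers.

```python
import numpy as np

def conv2d_same(x, w, b):
    """Naive stride-1 convolution with 'same' zero padding (assumption, as in the Keras model).

    Like Keras Conv2D this is a cross-correlation (the kernel is not flipped).
    x: (H, W, C_in), w: (k, k, C_in, C_out), b: (C_out,) -> (H, W, C_out)
    """
    k = w.shape[0]
    lo = (k - 1) // 2                      # zero padding before each spatial axis
    hi = k - 1 - lo                        # zero padding after each spatial axis
    xp = np.pad(x, ((lo, hi), (lo, hi), (0, 0)))
    H, W, _ = x.shape
    out = np.empty((H, W, w.shape[3]))
    for i in range(H):
        for j in range(W):
            patch = xp[i:i + k, j:j + k, :]                 # k x k x C_in window
            out[i, j] = np.tensordot(patch, w, axes=3) + b  # sum over k, k, C_in
    return out

def srcnn_forward(y, params):
    """9-1-5 SRCNN forward pass on an image y that has already been upscaled."""
    (w1, b1), (w2, b2), (w3, b3) = params
    f1 = np.maximum(conv2d_same(y, w1, b1), 0.0)   # step 1: patch extraction and representation (9x9, ReLU)
    f2 = np.maximum(conv2d_same(f1, w2, b2), 0.0)  # step 2: nonlinear mapping (1x1, per pixel, ReLU)
    return conv2d_same(f2, w3, b3)                 # step 3: reconstruction (5x5, no activation)

rng = np.random.default_rng(0)
params = [  # random, untrained weights: illustration only
    (rng.normal(0, 0.01, (9, 9, 3, 64)), np.zeros(64)),
    (rng.normal(0, 0.01, (1, 1, 64, 32)), np.zeros(32)),
    (rng.normal(0, 0.01, (5, 5, 32, 3)), np.zeros(3)),
]
y = rng.random((16, 16, 3))  # stand-in for a bicubic-upscaled image
print(srcnn_forward(y, params).shape)  # (16, 16, 3)
```

The 1×1 middle layer acts like a small fully connected network applied independently at every pixel. That is why it can be read as a "nonlinear mapping" between feature vectors.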

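The three convolutional layers use kernels of size 9×9, 1×1 and 5×5. The first two layers output 64 and 32 feature maps, respectively. The network was trained on the Timofte dataset (91 images) and on the large ImageNet dataset. Mean squared error (MSE) is used as the loss function. Because PSNR decreases monotonically as MSE increases, minimizing MSE favors a high PSNR.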
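The kernel sizes determine how much context each output pixel sees. For stacked stride-1 convolutions without dilation, each k×k layer widens the window by k - 1 pixels. The 9-1-5 stack therefore gives every output pixel a 13×13 receptive field on the upscaled input.

The sketch below uses hypothetical helper names. It also shows how much the image would shrink with 'valid' (no) padding instead of the 'same' padding used in the Keras code in this post.

```python
def receptive_field(kernel_sizes):
    """Receptive field (per side) of stacked stride-1, undilated convolutions."""
    rf = 1
    for k in kernel_sizes:
        rf += k - 1          # each k x k layer widens the window by k - 1 pixels
    return rf

def valid_output_size(n, kernel_sizes):
    """Spatial size after the layers when no padding ('valid') is used."""
    for k in kernel_sizes:
        n -= k - 1           # a 'valid' k x k convolution trims k - 1 pixels per axis
    return n

print(receptive_field([9, 1, 5]))         # 13
print(valid_output_size(128, [9, 1, 5]))  # 116
```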

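The above is quoted (with modifications) from: "From SRCNN to EDSR: A Summary of the Development of End-to-End Deep-Learning Super-Resolution Methods".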

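The SRCNN code below uses the CIFAR-10 dataset and upscales each 32×32×3 image to 128×128×3 (a 4× enlargement in each direction). All of the code below is available on GitHub.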
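For an integer factor such as 4, nearest-neighbour upscaling simply turns every source pixel into a 4×4 block of identical pixels. This is the kind of blocky input the model is trained on in this post.

The NumPy sketch below (hypothetical helper `upscale_nearest`) illustrates the idea with `np.repeat`. It is not the OpenCV call the post uses, and it assumes an exact integer factor.

```python
import numpy as np

def upscale_nearest(img, factor):
    """Nearest-neighbour upscaling by an integer factor (assumption: exact integer scale)."""
    # Repeat rows, then columns: pixel (r, c) fills the block [r*f:(r+1)*f, c*f:(c+1)*f].
    return np.repeat(np.repeat(img, factor, axis=0), factor, axis=1)

# Random 32x32 RGB image standing in for a CIFAR-10 sample.
small = np.random.default_rng(0).integers(0, 256, size=(32, 32, 3)).astype(np.uint8)
big = upscale_nearest(small, 4)
print(big.shape)  # (128, 128, 3)
```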

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

inputs = keras.Input(shape=(128, 128, 3), name='img')

x = layers.Conv2D(
    filters=64,             # number of filters (convolution kernels)
    kernel_size=9,          # kernel size (9x9 receptive field)
    padding='same',         # padding strategy ('valid' or 'same')
    activation=tf.nn.relu   # activation function
)(inputs)

x = layers.Conv2D(
    filters=32,
    kernel_size=1,
    padding='same',
    activation=tf.nn.relu
)(x)

outputs = layers.Conv2D(
    filters=3,
    kernel_size=5,
    padding='same'          # no activation: linear output for reconstruction
)(x)

model = keras.Model(inputs=inputs, outputs=outputs, name='SRCNN_model')  # build the model by specifying its inputs and outputs in the layer graph

model.summary()  # print a model summary; the model must already be built
Model: "SRCNN_model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
img (InputLayer)             [(None, 128, 128, 3)]     0         
_________________________________________________________________
conv2d (Conv2D)              (None, 128, 128, 64)      15616     
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 128, 128, 32)      2080      
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 128, 128, 3)       2403      
=================================================================
Total params: 20,099
Trainable params: 20,099
Non-trainable params: 0
_________________________________________________________________
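The parameter counts in the summary follow from a simple formula. A square Conv2D layer with bias has k·k·C_in·C_out weights plus C_out biases. Padding and input size do not change the count.

The check below uses a hypothetical helper to reproduce the three numbers from the summary.

```python
def conv2d_params(kernel_size, c_in, c_out, use_bias=True):
    """Trainable parameters of a square Conv2D layer."""
    weights = kernel_size * kernel_size * c_in * c_out
    return weights + (c_out if use_bias else 0)

# (kernel size, input channels, output channels) for the three layers of SRCNN_model
layers_cfg = [(9, 3, 64), (1, 64, 32), (5, 32, 3)]
counts = [conv2d_params(k, ci, co) for k, ci, co in layers_cfg]
print(counts, sum(counts))  # [15616, 2080, 2403] 20099
```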

Loading the data

import cv2 as cv
import numpy as np

'''
CIFAR-10 has no 128*128*3 images, so bicubic upscaling stands in for the
high-resolution target, and nearest-neighbour upscaling stands in for the
bicubic-interpolated low-resolution input.
X_: input upscaled with nearest-neighbour interpolation (low-resolution), train (10000, 128, 128, 3), test (1000, 128, 128, 3)
y_: target upscaled with bicubic interpolation (high-resolution), train (10000, 128, 128, 3), test (1000, 128, 128, 3)
'''

ishape = 128

# load data
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.cifar10.load_data()

# shrink the dataset to limit memory use (the settings below work on Windows 10 with 8 GB of RAM)
train_image = train_images[0:10000]
test_image  = test_images[0:1000]

# dtype=np.float32 halves memory compared with NumPy's default float64
X_train = np.array([cv.resize(i, (ishape, ishape), interpolation=cv.INTER_NEAREST) for i in train_image], dtype=np.float32) / 255.
X_test  = np.array([cv.resize(i, (ishape, ishape), interpolation=cv.INTER_NEAREST) for i in test_image], dtype=np.float32) / 255.

y_train = np.array([cv.resize(i, (ishape, ishape), interpolation=cv.INTER_CUBIC) for i in train_image], dtype=np.float32) / 255.
y_test  = np.array([cv.resize(i, (ishape, ishape), interpolation=cv.INTER_CUBIC) for i in test_image], dtype=np.float32) / 255.
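A quick size estimate explains why the dataset is cut down to 10,000 training images. Dividing a uint8 NumPy array by `255.` produces float64 by default, at 8 bytes per value. At that precision, each 10000×128×128×3 array needs roughly 3.66 GiB.

The hypothetical helper below does the arithmetic for a few dtypes.

```python
def array_bytes(n_images, height, width, channels, bytes_per_value):
    """Memory needed by a dense image array of the given shape and element size."""
    return n_images * height * width * channels * bytes_per_value

for name, nbytes in [("float64", 8), ("float32", 4), ("uint8", 1)]:
    size = array_bytes(10000, 128, 128, 3, nbytes)
    print(f"{name}: {size / 2**30:.2f} GiB")  # float64: 3.66, float32: 1.83, uint8: 0.46
```

With both `X_train` and `y_train` in float64, the training arrays alone would take over 7 GiB.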

Training

model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.01),
              loss='mse',
              metrics=['mae'])    # compile

history = model.fit(X_train, y_train,
                    batch_size=64,
                    epochs=3,
                    validation_split=0.2)  # train, holding out 20% of the training data for validation

test_scores = model.evaluate(X_test, y_test, verbose=2)  # evaluate

print('Test loss:', test_scores[0])
print('Test mae:', test_scores[1])

# Save entire model to a HDF5 file
model.save('SRCNN.h5')
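Because the pixels were scaled to [0, 1], the test MSE can be turned into PSNR with PSNR = 10·log10(MAX² / MSE), where MAX = 1. The function below is a hypothetical helper, not part of Keras. Note that converting the average MSE is not the same as averaging per-image PSNR values.

```python
import math

def psnr_from_mse(mse, max_value=1.0):
    """PSNR in dB from a mean squared error; max_value is the largest pixel value (1.0 here)."""
    if mse <= 0:
        return math.inf  # identical images
    return 10.0 * math.log10(max_value ** 2 / mse)

print(round(psnr_from_mse(0.001), 2))  # 30.0
```

For example, `psnr_from_mse(test_scores[0])` would report the test loss in dB.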

Applying the model

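from matplotlib import pyplot as plt

# in a new session, reload the trained model first: model = keras.models.load_model('SRCNN.h5')

ishape = 128

# load the image to upscale, enlarge it, and display it
img = cv.imread('automobile.png')

img = cv.cvtColor(img, cv.COLOR_BGR2RGB)  # OpenCV loads images in BGR order

img = cv.resize(img, (ishape, ishape), interpolation=cv.INTER_NEAREST)  # (36,36,3)->(128,128,3)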

plt.imshow(img)

plt.xticks([]), plt.yticks([])

plt.show()

Input image (upscaled to 128×128 with nearest-neighbour interpolation):

[Image 1: input image]

img = np.reshape(img,(1,ishape,ishape,3)) / 255.

# super-resolve the image
img_SR = model.predict(img)

plt.imshow(np.clip(img_SR[0], 0, 1))  # the output is unbounded, so clip to the valid [0, 1] range

plt.xticks([]), plt.yticks([])

plt.show()
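To save the result rather than only display it, the float RGB output needs converting back to 8-bit BGR, the channel order OpenCV uses. Below is a minimal sketch with a hypothetical helper name. It assumes the prediction is an RGB image scaled to [0, 1].

```python
import numpy as np

def to_uint8_bgr(img_rgb_float):
    """Convert an RGB float image in [0, 1] to a contiguous uint8 BGR image."""
    img = np.clip(img_rgb_float, 0.0, 1.0)        # the network output is not bounded to [0, 1]
    img = np.round(img * 255.0).astype(np.uint8)  # rescale to 0..255 and quantize
    return np.ascontiguousarray(img[..., ::-1])   # RGB -> BGR, the channel order OpenCV expects

demo = np.array([[[1.2, 0.5, -0.1]]])  # one RGB pixel with out-of-range values
print(to_uint8_bgr(demo).tolist())     # [[[0, 128, 255]]]
```

For example, `cv.imwrite('automobile_SR.png', to_uint8_bgr(img_SR[0]))` would write the result to disk; the file name is only illustrative.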

Image after SRCNN processing:

[Image 2: SRCNN output]
