Graduation Project: GAN-Based Image Deblurring

Table of Contents

  • Preface
  • 1. What is a GAN?
  • 2. How DeblurGAN Works
  • 3. Code References
  • 4. Implementation
  • 5. Experiment Notes
  • Summary


Preface

  • In recent years, GANs have attracted researchers' attention for their superior ability to generate fine image detail, and they have been applied to image deblurring; how to raise the quality of restored images to the level practical applications need is a current research focus. This project trains the high-performing DeblurGAN network on the GOPRO dataset to carry out image deblurring.

1. What is a GAN?

  • A Generative Adversarial Network (GAN) is a deep learning model and one of the most promising recent approaches to unsupervised learning over complex distributions. The framework contains (at least) two modules, a generative model (G) and a discriminative model (D), which learn by playing a game against each other and can thereby produce remarkably good outputs. The original GAN theory does not require G and D to be neural networks, only functions able to generate and to discriminate; in practice, however, deep neural networks are used for both. A successful GAN application also needs a careful training procedure, otherwise the freedom of the neural network models can lead to poor outputs.
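The game between G and D is usually written as the standard minimax objective, where D(x) is the probability that x is real and G(z) maps noise to samples:

```latex
\min_G \max_D \; V(D,G) =
  \mathbb{E}_{x \sim p_{\mathrm{data}}}\big[\log D(x)\big] +
  \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]
```

D is trained to push V up, G to push it down; at the (ideal) equilibrium, G's samples are indistinguishable from real data.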

2. How DeblurGAN Works

  • DeblurGAN is a blind motion-deblurring method proposed by Orest Kupyn et al. of the Ukrainian Catholic University.
  • Inspired by the success of SRGAN and conditional GANs, it treats blur removal as a special image-to-image translation task. DeblurGAN is trained with a WGAN adversarial loss plus a content loss, and achieved state-of-the-art results in SSIM and visual quality.
  • Main contributions:
    a loss and architecture that achieve SOTA motion-deblurring performance;
    a method for synthesizing motion-blurred training data from random trajectories;
    a new dataset and evaluation protocol (based on the improvement in object-detection results on restored images).
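In symbols, with I_B the blurred input, I_S the sharp ground truth, and φ a VGG feature map, the losses look as follows (the weights 100 and 10 match the implementation later in this post):

```latex
L_G = -\,\mathbb{E}\big[D(G(I_B))\big] + 100 \cdot L_{\mathrm{content}},
\qquad
L_{\mathrm{content}} = \mathbb{E}\big[\lVert \phi(I_S) - \phi(G(I_B)) \rVert_2^2\big]
```

```latex
L_D = \mathbb{E}\big[D(G(I_B))\big] - \mathbb{E}\big[D(I_S)\big]
  + 10 \cdot \mathbb{E}\big[\big(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\big)^2\big],
\qquad
\hat{x} = \epsilon\, I_S + (1-\epsilon)\, G(I_B)
```

The last term of L_D is the WGAN-GP gradient penalty, evaluated on random interpolations between real and generated images.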

3. Code References

A few reference implementations of DeblurGAN:

[1]https://github.com/dongheehand/DeblurGAN-tf

[2]https://github.com/LeeDoYup/DeblurGAN-tf

[3]https://github.com/KupynOrest/DeblurGAN

4. Implementation

The code below uses TensorFlow 1.x APIs.

main.py

import tensorflow as tf
from DeblurGAN import DeblurGAN
from mode import *
import argparse

parser = argparse.ArgumentParser()


def str2bool(v):
    # Compare against a tuple, not a bare string: `v.lower() in ('true')`
    # would do a substring check and accept inputs like "t" or "ru".
    return v.lower() in ('true',)


## Model specification
parser.add_argument("--n_feats", type=int, default=64)
parser.add_argument("--num_of_down_scale", type=int, default=2)
parser.add_argument("--gen_resblocks", type=int, default=9)
parser.add_argument("--discrim_blocks", type=int, default=3)

## Data specification
parser.add_argument("--train_Sharp_path", type=str, default="./data/train/sharp/")
parser.add_argument("--train_Blur_path", type=str, default="./data/train/blur")
parser.add_argument("--test_Sharp_path", type=str, default="./data/val/val_sharp")
parser.add_argument("--test_Blur_path", type=str, default="./data/val/val_blur")
parser.add_argument("--vgg_path", type=str, default="./vgg19.npy")
parser.add_argument("--patch_size", type=int, default=256)
parser.add_argument("--result_path", type=str, default="./result")
parser.add_argument("--model_path", type=str, default="./model")

## Optimization
parser.add_argument("--batch_size", type=int, default=1)
parser.add_argument("--max_epoch", type=int, default=200)
parser.add_argument("--learning_rate", type=float, default=1e-4)
parser.add_argument("--decay_step", type=int, default=150)
parser.add_argument("--test_with_train", type=str2bool, default=True)
parser.add_argument("--save_test_result", type=str2bool, default=True)

## Training or test specification
parser.add_argument("--mode", type=str, default="test")
parser.add_argument("--critic_updates", type=int, default=5)
parser.add_argument("--augmentation", type=str2bool, default=False)
parser.add_argument("--load_X", type=int, default=640)
parser.add_argument("--load_Y", type=int, default=360)
parser.add_argument("--fine_tuning", type=str2bool, default=False)
parser.add_argument("--log_freq", type=int, default=1)
parser.add_argument("--model_save_freq", type=int, default=20)
parser.add_argument("--pre_trained_model", type=str, default="./model/")
parser.add_argument("--test_batch", type=int, default=5)
args = parser.parse_args()

model = DeblurGAN(args)
model.build_graph()

# Minimal session setup and train/test dispatch (a sketch; the reference
# repos may configure the session and saver differently).
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver()

if args.mode == 'train':
    train(args, model, sess, saver)
elif args.mode == 'test':
    # test() with the default step=-1 writes the metrics and closes the file.
    test(args, model, sess, saver, open('test_logs.txt', 'w'), loading=True)

data_loader.py

import tensorflow as tf
import os
 
class dataloader():
    
    def __init__(self, args):
 
        self.channel = 3
 
        self.mode = args.mode
        self.patch_size = args.patch_size
        self.batch_size = args.batch_size
        self.train_Sharp_path = args.train_Sharp_path
        self.train_Blur_path = args.train_Blur_path
        self.test_Sharp_path = args.test_Sharp_path
        self.test_Blur_path = args.test_Blur_path
        self.test_with_train = args.test_with_train
        self.test_batch = args.test_batch
        self.load_X = args.load_X
        self.load_Y = args.load_Y
        self.augmentation = args.augmentation
        
    def build_loader(self):
        
        if self.mode == 'train':
        
            tr_sharp_imgs = sorted(os.listdir(self.train_Sharp_path))
            tr_blur_imgs = sorted(os.listdir(self.train_Blur_path))
            tr_sharp_imgs = [os.path.join(self.train_Sharp_path, ele) for ele in tr_sharp_imgs]
            tr_blur_imgs = [os.path.join(self.train_Blur_path, ele) for ele in tr_blur_imgs]
            train_list = (tr_blur_imgs, tr_sharp_imgs)
            
            self.tr_dataset = tf.data.Dataset.from_tensor_slices(train_list)
            self.tr_dataset = self.tr_dataset.map(self._parse, num_parallel_calls = 4).prefetch(32)
            self.tr_dataset = self.tr_dataset.map(self._resize, num_parallel_calls = 4).prefetch(32)
            self.tr_dataset = self.tr_dataset.map(self._get_patch, num_parallel_calls = 4).prefetch(32)
            if self.augmentation:
                self.tr_dataset = self.tr_dataset.map(self._data_augmentation, num_parallel_calls = 4).prefetch(32)
            self.tr_dataset = self.tr_dataset.shuffle(32)
            self.tr_dataset = self.tr_dataset.repeat()
            self.tr_dataset = self.tr_dataset.batch(self.batch_size)
            
            if self.test_with_train:
            
                val_sharp_imgs = sorted(os.listdir(self.test_Sharp_path))
                val_blur_imgs = sorted(os.listdir(self.test_Blur_path))
                val_sharp_imgs = [os.path.join(self.test_Sharp_path, ele) for ele in val_sharp_imgs]
                val_blur_imgs = [os.path.join(self.test_Blur_path, ele) for ele in val_blur_imgs]
                valid_list = (val_blur_imgs, val_sharp_imgs)
 
                self.val_dataset = tf.data.Dataset.from_tensor_slices(valid_list)
                self.val_dataset = self.val_dataset.map(self._parse, num_parallel_calls=4).prefetch(32)
                self.val_dataset = self.val_dataset.batch(self.test_batch)
 
            iterator = tf.data.Iterator.from_structure(self.tr_dataset.output_types, self.tr_dataset.output_shapes)
            self.next_batch = iterator.get_next()
            self.init_op = {}
            self.init_op['tr_init'] = iterator.make_initializer(self.tr_dataset)
            
            if self.test_with_train:
                self.init_op['val_init'] = iterator.make_initializer(self.val_dataset)
            
        elif self.mode == 'test':
            
            val_sharp_imgs = sorted(os.listdir(self.test_Sharp_path))
            val_blur_imgs = sorted(os.listdir(self.test_Blur_path))
            val_sharp_imgs = [os.path.join(self.test_Sharp_path, ele) for ele in val_sharp_imgs]
            val_blur_imgs = [os.path.join(self.test_Blur_path, ele) for ele in val_blur_imgs]
            valid_list = (val_blur_imgs, val_sharp_imgs)
            
            self.val_dataset = tf.data.Dataset.from_tensor_slices(valid_list)
            self.val_dataset = self.val_dataset.map(self._parse, num_parallel_calls=4).prefetch(32)
            self.val_dataset = self.val_dataset.batch(1)
            
            iterator = tf.data.Iterator.from_structure(self.val_dataset.output_types, self.val_dataset.output_shapes)
            self.next_batch = iterator.get_next()
            self.init_op = {}
            self.init_op['val_init'] = iterator.make_initializer(self.val_dataset)
             
    def _parse(self, image_blur, image_sharp):
        
        image_blur = tf.read_file(image_blur)
        image_sharp = tf.read_file(image_sharp)
        
        image_blur = tf.image.decode_image(image_blur, channels=self.channel)
        image_sharp = tf.image.decode_image(image_sharp, channels=self.channel)
        
        image_blur = tf.cast(image_blur, tf.float32)
        image_sharp = tf.cast(image_sharp, tf.float32)
        
        return image_blur, image_sharp
    
    def _resize(self, image_blur, image_sharp):
        
        image_blur = tf.image.resize_images(image_blur, (self.load_Y, self.load_X), tf.image.ResizeMethod.BICUBIC)
        image_sharp = tf.image.resize_images(image_sharp, (self.load_Y, self.load_X), tf.image.ResizeMethod.BICUBIC)
        
        return image_blur, image_sharp
 
    def _parse_Blur_only(self, image_blur):
        
        image_blur = tf.read_file(image_blur)
        image_blur = tf.image.decode_image(image_blur, channels=self.channel)
        image_blur = tf.cast(image_blur, tf.float32)
        
        return image_blur
        
    def _get_patch(self, image_blur, image_sharp):
        
        shape = tf.shape(image_blur)
        ih = shape[0]
        iw = shape[1]
        
        ix = tf.random_uniform(shape=[1], minval=0, maxval=iw - self.patch_size + 1, dtype=tf.int32)[0]
        iy = tf.random_uniform(shape=[1], minval=0, maxval=ih - self.patch_size + 1, dtype=tf.int32)[0]
        
        img_sharp_in = image_sharp[iy:iy + self.patch_size, ix:ix + self.patch_size]        
        img_blur_in = image_blur[iy:iy + self.patch_size, ix:ix + self.patch_size]
        
        return img_blur_in, img_sharp_in
    
    def _data_augmentation(self, image_blur, image_sharp):
        
        rot = tf.random_uniform(shape=[1], minval=0, maxval=3, dtype=tf.int32)[0]
        flip_rl = tf.random_uniform(shape=[1], minval=0, maxval=3, dtype=tf.int32)[0]
        flip_updown = tf.random_uniform(shape=[1], minval=0, maxval=3, dtype=tf.int32)[0]
        
        image_blur = tf.image.rot90(image_blur, rot)
        image_sharp = tf.image.rot90(image_sharp, rot)
        
        rl = tf.equal(tf.mod(flip_rl, 2), 0)
        ud = tf.equal(tf.mod(flip_updown, 2), 0)
        
        image_blur = tf.cond(rl, true_fn=lambda: tf.image.flip_left_right(image_blur),
                             false_fn=lambda: image_blur)
        image_sharp = tf.cond(rl, true_fn=lambda: tf.image.flip_left_right(image_sharp),
                              false_fn=lambda: image_sharp)
        
        image_blur = tf.cond(ud, true_fn=lambda: tf.image.flip_up_down(image_blur),
                             false_fn=lambda: image_blur)
        image_sharp = tf.cond(ud, true_fn=lambda: tf.image.flip_up_down(image_sharp),
                              false_fn=lambda: image_sharp)
        
        return image_blur, image_sharp
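The _get_patch logic above, cropping the same random window from an aligned blur/sharp pair, can be sketched in plain NumPy (a hypothetical standalone helper, not part of the repo):

```python
import numpy as np

def random_patch_pair(blur, sharp, patch_size, rng=None):
    # Pick one top-left corner and crop the SAME window from both
    # images, so the blur/sharp pair stays pixel-aligned.
    if rng is None:
        rng = np.random.default_rng(0)
    ih, iw = blur.shape[:2]
    iy = int(rng.integers(0, ih - patch_size + 1))
    ix = int(rng.integers(0, iw - patch_size + 1))
    window = (slice(iy, iy + patch_size), slice(ix, ix + patch_size))
    return blur[window], sharp[window]
```

Cropping both images with identical coordinates is what makes the supervised pixel-level losses meaningful.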

vgg19.py

import tensorflow as tf
import numpy as np
import time
 
 
VGG_MEAN = [103.939, 116.779, 123.68]
 
 
class Vgg19:
 
    def __init__(self, vgg19_npy_path):
        # allow_pickle=True is required on NumPy >= 1.16.3 to load the pickled dict
        self.data_dict = np.load(vgg19_npy_path, encoding='latin1', allow_pickle=True).item()
        print("npy file loaded")
 
    def build(self, rgb):
        """
        load variable from npy to build the VGG
        :param rgb: rgb image [batch, height, width, 3] values scaled [-1, 1]
        """
 
        start_time = time.time()
        print("build vgg19 model started")
        rgb_scaled = ((rgb + 1) * 255.0) / 2.0
 
        # Convert RGB to BGR
        red, green, blue = tf.split(axis=3, num_or_size_splits=3, value=rgb_scaled)
        bgr = tf.concat(axis=3, values=[blue - VGG_MEAN[0], green - VGG_MEAN[1], red - VGG_MEAN[2]])
 
        self.conv1_1 = self.conv_layer(bgr, "conv1_1")
        self.relu1_1 = self.relu_layer(self.conv1_1, "relu1_1")
        self.conv1_2 = self.conv_layer(self.relu1_1, "conv1_2")
        self.relu1_2 = self.relu_layer(self.conv1_2, "relu1_2")
        self.pool1 = self.max_pool(self.relu1_2, 'pool1')
 
        self.conv2_1 = self.conv_layer(self.pool1, "conv2_1")
        self.relu2_1 = self.relu_layer(self.conv2_1, "relu2_1")
        self.conv2_2 = self.conv_layer(self.relu2_1, "conv2_2")
        self.relu2_2 = self.relu_layer(self.conv2_2, "relu2_2")
        self.pool2 = self.max_pool(self.relu2_2, 'pool2')
 
        self.conv3_1 = self.conv_layer(self.pool2, "conv3_1")
        self.relu3_1 = self.relu_layer(self.conv3_1, "relu3_1")
        self.conv3_2 = self.conv_layer(self.relu3_1, "conv3_2")
        self.relu3_2 = self.relu_layer(self.conv3_2, "relu3_2")
        self.conv3_3 = self.conv_layer(self.relu3_2, "conv3_3")
        self.relu3_3 = self.relu_layer(self.conv3_3, "relu3_3")
        self.conv3_4 = self.conv_layer(self.relu3_3, "conv3_4")
        self.relu3_4 = self.relu_layer(self.conv3_4, "relu3_4")
        self.pool3 = self.max_pool(self.relu3_4, 'pool3')
 
        self.conv4_1 = self.conv_layer(self.pool3, "conv4_1")
        self.relu4_1 = self.relu_layer(self.conv4_1, "relu4_1")
        self.conv4_2 = self.conv_layer(self.relu4_1, "conv4_2")
        self.relu4_2 = self.relu_layer(self.conv4_2, "relu4_2")
        self.conv4_3 = self.conv_layer(self.relu4_2, "conv4_3")
        self.relu4_3 = self.relu_layer(self.conv4_3, "relu4_3")
        self.conv4_4 = self.conv_layer(self.relu4_3, "conv4_4")
        self.relu4_4 = self.relu_layer(self.conv4_4, "relu4_4")
        self.pool4 = self.max_pool(self.relu4_4, 'pool4')
 
        self.conv5_1 = self.conv_layer(self.pool4, "conv5_1")
        self.relu5_1 = self.relu_layer(self.conv5_1, "relu5_1")
        self.conv5_2 = self.conv_layer(self.relu5_1, "conv5_2")
        self.relu5_2 = self.relu_layer(self.conv5_2, "relu5_2")
        self.conv5_3 = self.conv_layer(self.relu5_2, "conv5_3")
        self.relu5_3 = self.relu_layer(self.conv5_3, "relu5_3")
        self.conv5_4 = self.conv_layer(self.relu5_3, "conv5_4")
        self.relu5_4 = self.relu_layer(self.conv5_4, "relu5_4")
        self.pool5 = self.max_pool(self.relu5_4, 'pool5')  # was conv5_4; pool after the ReLU like the other blocks
 
        self.data_dict = None
        print(("build vgg19 model finished: %ds" % (time.time() - start_time)))
 
    def max_pool(self, bottom, name):
        return tf.nn.max_pool(bottom, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name=name)
    
    def relu_layer(self, bottom, name):
        return tf.nn.relu(bottom, name=name)
 
    def conv_layer(self, bottom, name):
        with tf.variable_scope(name):
            filt = self.get_conv_filter(name)
 
            conv = tf.nn.conv2d(bottom, filt, [1, 1, 1, 1], padding='SAME')
 
            conv_biases = self.get_bias(name)
            bias = tf.nn.bias_add(conv, conv_biases)
                   
            return bias
 
    def get_conv_filter(self, name):
        return tf.constant(self.data_dict[name][0], name="filter")
 
    def get_bias(self, name):
        return tf.constant(self.data_dict[name][1], name="biases")
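The relu3_3 activations of this network are what DeblurGAN compares in its content loss. Numerically the reduction is just the following (a hypothetical NumPy sketch mirroring `reduce_mean(reduce_sum(square(diff), axis=3))`):

```python
import numpy as np

def content_loss_np(feat_restored, feat_sharp):
    # Sum squared feature differences over the channel axis, then
    # average over batch and spatial positions.
    diff = feat_restored - feat_sharp
    return float(np.mean(np.sum(diff ** 2, axis=3)))
```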

util.py

from PIL import Image
import numpy as np
import random
import os
 
def image_loader(image_path, load_x, load_y, is_train = True):
    
    imgs = sorted(os.listdir(image_path))
    img_list = []
    for ele in imgs:
        img = Image.open(os.path.join(image_path, ele))
        if is_train:
            img = img.resize((load_x, load_y), Image.BICUBIC)
        img_list.append(np.array(img))
    
    return img_list
 
def data_augument(lr_img, hr_img, aug):
    
    if aug < 4:
        lr_img = np.rot90(lr_img, aug)
        hr_img = np.rot90(hr_img, aug)
    
    elif aug == 4:
        lr_img = np.fliplr(lr_img)
        hr_img = np.fliplr(hr_img)
        
    elif aug == 5:
        lr_img = np.flipud(lr_img)
        hr_img = np.flipud(hr_img)
        
    elif aug == 6:
        lr_img = np.rot90(np.fliplr(lr_img))
        hr_img = np.rot90(np.fliplr(hr_img))
        
    elif aug == 7:
        lr_img = np.rot90(np.flipud(lr_img))
        hr_img = np.rot90(np.flipud(hr_img))
        
    return lr_img, hr_img
 
def batch_gen(blur_imgs, sharp_imgs, patch_size, batch_size, random_index, step, augment=False):
    
    img_index = random_index[step * batch_size: (step + 1) * batch_size]
    
    all_img_blur = []
    all_img_sharp = []
    
    for _index in img_index:
        all_img_blur.append(blur_imgs[_index])
        all_img_sharp.append(sharp_imgs[_index])
    
    blur_batch = []
    sharp_batch = []
    
    for i in range(len(all_img_blur)):
        
        ih, iw, _ = all_img_blur[i].shape
        ix = random.randrange(0, iw - patch_size + 1)
        iy = random.randrange(0, ih - patch_size + 1)
        
        img_blur_in = all_img_blur[i][iy:iy + patch_size, ix:ix + patch_size]
        img_sharp_in = all_img_sharp[i][iy:iy + patch_size, ix:ix + patch_size]        
        
        if augment:
            
            aug = random.randrange(0, 8)
            img_blur_in, img_sharp_in = data_augument(img_blur_in, img_sharp_in, aug)
 
        blur_batch.append(img_blur_in)
        sharp_batch.append(img_sharp_in)
        
    blur_batch = np.array(blur_batch)
    sharp_batch = np.array(sharp_batch)
    
    return blur_batch, sharp_batch

layer.py

import tensorflow as tf
import numpy as np
 
 
def Conv(name, x, filter_size, in_filters, out_filters, strides, padding):
    with tf.variable_scope(name):
        kernel = tf.get_variable('filter', [filter_size, filter_size, in_filters, out_filters], tf.float32,
                                 initializer=tf.random_normal_initializer(stddev=0.01))
        bias = tf.get_variable('bias', [out_filters], tf.float32, initializer=tf.zeros_initializer())
        
        return tf.nn.conv2d(x, kernel, [1, strides, strides, 1], padding=padding) + bias
    
 
def Conv_transpose(name, x, filter_size, in_filters, out_filters, fraction=2, padding="SAME"):
    with tf.variable_scope(name):
        n = filter_size * filter_size * out_filters
        kernel = tf.get_variable('filter', [filter_size, filter_size, out_filters, in_filters], tf.float32,
                                 initializer=tf.random_normal_initializer(stddev=np.sqrt(2.0/n)))
        size = tf.shape(x)
        output_shape = tf.stack([size[0], size[1] * fraction, size[2] * fraction, out_filters])
        x = tf.nn.conv2d_transpose(x, kernel, output_shape, [1, fraction, fraction, 1], padding)
        
        return x
 
 
def instance_norm(x, BN_epsilon=1e-3):
    # keep_dims=True gives [N, 1, 1, C] moments so the normalization
    # broadcasts correctly for batch sizes > 1 (without it, the
    # subtraction only lines up by accident when batch_size == 1).
    mean, variance = tf.nn.moments(x, axes=[1, 2], keep_dims=True)
    x = (x - mean) / ((variance + BN_epsilon) ** 0.5)
    return x
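For intuition, the same per-sample, per-channel normalization over the spatial axes looks like this in NumPy (a hypothetical helper; like the TF version above, it has no learned scale/shift parameters):

```python
import numpy as np

def instance_norm_np(x, eps=1e-3):
    # x: [N, H, W, C]; normalize each (sample, channel) pair over H and W.
    mean = x.mean(axis=(1, 2), keepdims=True)
    var = x.var(axis=(1, 2), keepdims=True)
    return (x - mean) / np.sqrt(var + eps)
```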

DeblurGAN.py

from layer import *
from data_loader import dataloader
from vgg19 import Vgg19
 
 
class DeblurGAN():
    
    def __init__(self, args):
        
        self.data_loader = dataloader(args)
        print("data has been loaded")
 
        self.channel = 3
 
        self.n_feats = args.n_feats
        self.mode = args.mode
        self.batch_size = args.batch_size      
        self.num_of_down_scale = args.num_of_down_scale
        self.gen_resblocks = args.gen_resblocks
        self.discrim_blocks = args.discrim_blocks
        self.vgg_path = args.vgg_path
        
        self.learning_rate = args.learning_rate
        self.decay_step = args.decay_step
        
    def down_scaling_feature(self, name, x, n_feats):
        x = Conv(name=name + 'conv', x=x, filter_size=3, in_filters=n_feats,
                 out_filters=n_feats * 2, strides=2, padding='SAME')
        x = instance_norm(x)
        x = tf.nn.relu(x)
        
        return x
    
    def up_scaling_feature(self, name, x, n_feats):
        x = Conv_transpose(name=name + 'deconv', x=x, filter_size=3, in_filters=n_feats,
                           out_filters=n_feats // 2, fraction=2, padding='SAME')
        x = instance_norm(x)
        x = tf.nn.relu(x)
        
        return x
    
    def res_block(self, name, x, n_feats):
        
        _res = x
        
        x = tf.pad(x, [[0, 0], [1, 1], [1, 1], [0, 0]], mode='REFLECT')
        x = Conv(name=name + 'conv1', x=x, filter_size=3, in_filters=n_feats,
                 out_filters=n_feats, strides=1, padding='VALID')
        x = instance_norm(x)
        x = tf.nn.relu(x)
        
        x = tf.pad(x, [[0, 0], [1, 1], [1, 1], [0, 0]], mode='REFLECT')
        x = Conv(name=name + 'conv2', x=x, filter_size=3, in_filters=n_feats,
                 out_filters=n_feats, strides=1, padding='VALID')
        x = instance_norm(x)
        
        x = x + _res
        
        return x
    
    def generator(self, x, reuse=False, name='generator'):
        
        with tf.variable_scope(name_or_scope=name, reuse=reuse):
            _res = x
            x = tf.pad(x, [[0, 0], [3, 3], [3, 3], [0, 0]], mode='REFLECT')
            x = Conv(name='conv1', x=x, filter_size=7, in_filters=self.channel,
                     out_filters=self.n_feats, strides=1, padding='VALID')
            # x = instance_norm(name = 'inst_norm1', x = x, dim = self.n_feats)
            x = instance_norm(x)
            x = tf.nn.relu(x)
            
            for i in range(self.num_of_down_scale):
                x = self.down_scaling_feature(name='down_%02d' % i, x=x, n_feats=self.n_feats * (i + 1))
 
            for i in range(self.gen_resblocks):
                x = self.res_block(name='res_%02d' % i, x=x, n_feats=self.n_feats * (2 ** self.num_of_down_scale))
 
            for i in range(self.num_of_down_scale):
                x = self.up_scaling_feature(name='up_%02d' % i, x=x,
                                            n_feats=self.n_feats * (2 ** (self.num_of_down_scale - i)))
 
            x = tf.pad(x, [[0, 0], [3, 3], [3, 3], [0, 0]], mode='REFLECT')
            x = Conv(name='conv_last', x=x, filter_size=7, in_filters=self.n_feats,
                     out_filters=self.channel, strides=1, padding='VALID')
            x = tf.nn.tanh(x)
            x = x + _res
            x = tf.clip_by_value(x, -1.0, 1.0)
            
            return x
    
    def discriminator(self, x, reuse=False, name='discriminator'):
        
        with tf.variable_scope(name_or_scope=name, reuse=reuse):
            x = Conv(name='conv1', x=x, filter_size=4, in_filters=self.channel,
                     out_filters=self.n_feats, strides=2, padding="SAME")
            x = instance_norm(x)
            x = tf.nn.leaky_relu(x)
            
            n = 1
            
            for i in range(self.discrim_blocks):
                prev = n
                n = min(2 ** (i+1), 8)
                x = Conv(name='conv%02d' % i, x=x, filter_size=4, in_filters=self.n_feats * prev,
                         out_filters=self.n_feats * n, strides=2, padding="SAME")
                x = instance_norm(x)
                x = tf.nn.leaky_relu(x)
                
            prev = n
            n = min(2 ** self.discrim_blocks, 8)
            x = Conv(name='conv_d1', x=x, filter_size=4, in_filters=self.n_feats * prev,
                     out_filters=self.n_feats * n, strides=1, padding="SAME")
            # x = instance_norm(name = 'instance_norm_d1', x = x, dim = self.n_feats * n)
            x = instance_norm(x)
            x = tf.nn.leaky_relu(x)
            
            x = Conv(name='conv_d2', x=x, filter_size=4, in_filters=self.n_feats * n,
                     out_filters=1, strides=1, padding="SAME")
            x = tf.nn.sigmoid(x)
            
            return x
    
        
    def build_graph(self):
        
        # if self.in_memory:
        self.blur = tf.placeholder(name="blur", shape=[None, None, None, self.channel], dtype=tf.float32)
        self.sharp = tf.placeholder(name="sharp", shape=[None, None, None, self.channel], dtype=tf.float32)
            
        x = self.blur
        label = self.sharp
        
        self.epoch = tf.placeholder(name='train_step', shape=None, dtype=tf.int32)
        
        x = (2.0 * x / 255.0) - 1.0
        label = (2.0 * label / 255.0) - 1.0
        
        self.gene_img = self.generator(x, reuse=False)
        self.real_prob = self.discriminator(label, reuse=False)
        self.fake_prob = self.discriminator(self.gene_img, reuse=True)
        
        epsilon = tf.random_uniform(shape=[self.batch_size, 1, 1, 1], minval=0.0, maxval=1.0)
        
        interpolated_input = epsilon * label + (1 - epsilon) * self.gene_img
        gradient = tf.gradients(self.discriminator(interpolated_input, reuse=True), [interpolated_input])[0]
        GP_loss = tf.reduce_mean(tf.square(tf.sqrt(tf.reduce_mean(tf.square(gradient), axis=[1, 2, 3])) - 1))
        
        d_loss_real = - tf.reduce_mean(self.real_prob)
        d_loss_fake = tf.reduce_mean(self.fake_prob)
        
        self.vgg_net = Vgg19(self.vgg_path)
        self.vgg_net.build(tf.concat([label, self.gene_img], axis=0))
        self.content_loss = tf.reduce_mean(tf.reduce_sum(tf.square(
            self.vgg_net.relu3_3[self.batch_size:] - self.vgg_net.relu3_3[:self.batch_size]), axis=3))
        
        self.D_loss = d_loss_real + d_loss_fake + 10.0 * GP_loss
        self.G_loss = - d_loss_fake + 100.0 * self.content_loss
        
        t_vars = tf.trainable_variables()
        G_vars = [var for var in t_vars if 'generator' in var.name]
        D_vars = [var for var in t_vars if 'discriminator' in var.name]
        
        lr = tf.minimum(self.learning_rate, tf.abs(2 * self.learning_rate - (
                self.learning_rate * tf.cast(self.epoch, tf.float32) / self.decay_step)))
        self.D_train = tf.train.AdamOptimizer(learning_rate=lr).minimize(self.D_loss, var_list=D_vars)
        self.G_train = tf.train.AdamOptimizer(learning_rate=lr).minimize(self.G_loss, var_list=G_vars)
        
        self.PSNR = tf.reduce_mean(tf.image.psnr(((self.gene_img + 1.0) / 2.0), ((label + 1.0) / 2.0), max_val=1.0))
        self.ssim = tf.reduce_mean(tf.image.ssim(((self.gene_img + 1.0) / 2.0), ((label + 1.0) / 2.0), max_val=1.0))
        
        logging_D_loss = tf.summary.scalar(name='D_loss', tensor=self.D_loss)
        logging_G_loss = tf.summary.scalar(name='G_loss', tensor=self.G_loss)
        logging_PSNR = tf.summary.scalar(name='PSNR', tensor=self.PSNR)
        logging_ssim = tf.summary.scalar(name='ssim', tensor=self.ssim)
 
        self.output = (self.gene_img + 1.0) * 255.0 / 2.0
        self.output = tf.round(self.output)
        self.output = tf.cast(self.output, tf.uint8)
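The tf.minimum/tf.abs expression for lr above implements a piecewise-linear schedule: constant for the first decay_step epochs, then decaying linearly to zero at 2 * decay_step. In plain Python (a hypothetical helper for illustration):

```python
def deblurgan_lr(epoch, base_lr=1e-4, decay_step=150):
    # min(lr0, |2*lr0 - lr0*epoch/decay_step|):
    #   epoch <= decay_step            -> base_lr (constant)
    #   decay_step < epoch <= 2*decay_step -> linear decay toward 0
    return min(base_lr, abs(2 * base_lr - base_lr * epoch / decay_step))
```

With the defaults (max_epoch=200, decay_step=150), training never reaches 2 * decay_step, so the |…| never ramps back up.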

mode.py

import os
import tensorflow as tf
from PIL import Image
import numpy as np
import time
import util
 
 
def train(args, model, sess, saver):
    
    if args.fine_tuning:
        saver.restore(sess, args.pre_trained_model)
        print("saved model is loaded for fine-tuning!")
        print("model path is %s" % args.pre_trained_model)
        
    num_imgs = len(os.listdir(args.train_Sharp_path))
    
    merged = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter('./logs', sess.graph)
    if args.test_with_train:
        f = open("valid_logs.txt", 'w')
    
    epoch = 0
    step = num_imgs // args.batch_size
           
    blur_imgs = util.image_loader(args.train_Blur_path, args.load_X, args.load_Y)
    sharp_imgs = util.image_loader(args.train_Sharp_path, args.load_X, args.load_Y)
        
    while epoch < args.max_epoch:
        random_index = np.random.permutation(len(blur_imgs))
        for k in range(step):
            s_time = time.time()
            blur_batch, sharp_batch = util.batch_gen(blur_imgs, sharp_imgs, args.patch_size,
                                                     args.batch_size, random_index, k)
                
            for t in range(args.critic_updates):
                _, D_loss = sess.run([model.D_train, model.D_loss],
                                     feed_dict={model.blur: blur_batch, model.sharp: sharp_batch, model.epoch: epoch})
                    
            _, G_loss = sess.run([model.G_train, model.G_loss],
                                 feed_dict={model.blur: blur_batch, model.sharp: sharp_batch, model.epoch: epoch})
                             
            e_time = time.time()
            
        if epoch % args.log_freq == 0:
            summary = sess.run(merged, feed_dict={model.blur: blur_batch, model.sharp: sharp_batch})
            train_writer.add_summary(summary, epoch)
            if args.test_with_train:
                test(args, model, sess, saver, f, epoch, loading=False)
            print("%d training epoch completed" % epoch)
            print("D_loss : {}, \t G_loss : {}".format(D_loss, G_loss))
            print("Elpased time : %0.4f" % (e_time - s_time))
            # print("D_loss : %0.4f, \t G_loss : %0.4f" % (D_loss, G_loss))
            # print("Elpased time : %0.4f" % (e_time - s_time))
        if (epoch) % args.model_save_freq == 0:
            saver.save(sess, './model/DeblurrGAN', global_step=epoch, write_meta_graph=False)
            
        epoch += 1
 
    saver.save(sess, './model/DeblurrGAN_last', write_meta_graph=False)
    
    if args.test_with_train:
        f.close()
        
        
def test(args, model, sess, saver, file, step=-1, loading=False):
        
    if loading:
 
        import re
        print(" [*] Reading checkpoints...")
        ckpt = tf.train.get_checkpoint_state(args.pre_trained_model)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            saver.restore(sess, os.path.join(args.pre_trained_model, ckpt_name))
            print(" [*] Success to read {}".format(ckpt_name))
        else:
            print(" [*] Failed to find a checkpoint")
     
    blur_img_name = sorted(os.listdir(args.test_Blur_path))
    sharp_img_name = sorted(os.listdir(args.test_Sharp_path))
    
    PSNR_list = []
    ssim_list = []
        
    blur_imgs = util.image_loader(args.test_Blur_path, args.load_X, args.load_Y, is_train=False)
    sharp_imgs = util.image_loader(args.test_Sharp_path, args.load_X, args.load_Y, is_train=False)
 
    if not os.path.exists('./result/'):
        os.makedirs('./result/')
 
    for i, ele in enumerate(blur_imgs):
        blur = np.expand_dims(ele, axis = 0)
        sharp = np.expand_dims(sharp_imgs[i], axis = 0)
        output, psnr, ssim = sess.run([model.output, model.PSNR, model.ssim], feed_dict = {model.blur : blur, model.sharp : sharp})
        if args.save_test_result:
            output = Image.fromarray(output[0])
            split_name = blur_img_name[i].split('.')
            output.save(os.path.join(args.result_path, '%s_sharp.png'%(''.join(map(str, split_name[:-1])))))
 
        PSNR_list.append(psnr)
        ssim_list.append(ssim)
            
    length = len(PSNR_list)
    
    mean_PSNR = sum(PSNR_list) / length
    mean_ssim = sum(ssim_list) / length
    
    if step == -1:
        file.write('PSNR : {} SSIM : {}' .format(mean_PSNR, mean_ssim))
        file.close()
        
    else:
        file.write("{}d-epoch step PSNR : {} SSIM : {} \n".format(step, mean_PSNR, mean_ssim))

5. Experiment Notes

  1. Install Anaconda and set up a virtual environment. Do not train on the CPU; use a GPU if at all possible, or you will end up with the same ghastly filter-looking output that I did.
  2. The stack is TensorFlow and Python. If you are a beginner like me, follow the tutorial step by step honestly; trying to fly before you can walk will not end well.
  3. When installing libraries, pay close attention to versions. If you can spare the effort, give every project its own fresh virtual environment. It feels tedious, but different library versions can be badly incompatible.
  4. Make sure you find code that actually runs. Some repositories are bizarrely broken, especially around image-processing libraries such as scipy, and you never know what the author changed without documenting it, so be ready to improvise.
  5. Once a piece of code finally runs, do not touch it again; like a bug, nobody quite knows why it works.

Below are the experimental results:
[Figure 1: deblurring result]

Summary

The full thesis is still in progress; this post records the results of this stage of the work, for reference only.
