SRCNN Code Walkthrough [model.py]

I have recently been studying SRCNN and am taking notes as I read through the code.
Code download link: https://github.com/tegg89/SRCNN-Tensorflow
Here we go.

from utils import (
  read_data, 
  input_setup, 
  imsave,
  psnr,
  merge
)
import time
import os
import cv2
import matplotlib.pyplot as plt
from skimage import data, exposure, img_as_float

import numpy as np
import tensorflow as tf
# Define the SRCNN class
class SRCNN(object):

  def __init__(self, 
               sess, 
               image_size=33,
               label_size=21, 
               batch_size=64,
               c_dim=1, 
               checkpoint_dir=None, 
               sample_dir=None):
    self.sess = sess
    self.is_grayscale = (c_dim == 1)
    self.image_size = image_size
    self.label_size = label_size
    self.batch_size = batch_size
    self.c_dim = c_dim
    self.checkpoint_dir = checkpoint_dir
    self.sample_dir = sample_dir
    self.build_model()
  # Build the network graph
  def build_model(self):
    self.images = tf.placeholder(tf.float32, [None, self.image_size, self.image_size, self.c_dim], name='images')
    self.labels = tf.placeholder(tf.float32, [None, self.label_size, self.label_size, self.c_dim], name='labels')
    # Layer 1: extract features from the input image (64 filters of size 9 x 9)
    # Layer 2: non-linear mapping of the layer-1 features (32 filters of size 1 x 1)
    # Layer 3: reconstruct the mapped features into the high-resolution output (1 filter of size 5 x 5)
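    # With VALID padding and stride 1, the spatial size shrinks at each layer:
    #   layer 1: 33 - 9 + 1 = 25, layer 2: 25 - 1 + 1 = 25, layer 3: 25 - 5 + 1 = 21,
    # which is why label_size defaults to 21 when image_size is 33.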
    # Weights
    self.weights = {
      # The paper suggests n1=32, n2=16 to speed up training; the default n1=64, n2=32 is used here
      'w1': tf.Variable(tf.random_normal([9, 9, 1, 64], stddev=1e-3), name='w1'),
      'w2': tf.Variable(tf.random_normal([1, 1, 64, 32], stddev=1e-3), name='w2'),
      'w3': tf.Variable(tf.random_normal([5, 5, 32, 1], stddev=1e-3), name='w3')
    }
    # Biases
    self.biases = {
      'b1': tf.Variable(tf.zeros([64]), name='b1'),
      'b2': tf.Variable(tf.zeros([32]), name='b2'),
      'b3': tf.Variable(tf.zeros([1]), name='b3')
    }
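    # Initialization follows the paper: weights drawn from a zero-mean Gaussian with
    # standard deviation 0.001 (stddev=1e-3 above), biases initialized to 0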
    self.pred = self.model()
    # Use MSE as the loss function
    self.loss = tf.reduce_mean(tf.square(self.labels - self.pred))
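    # The evaluation metric PSNR (imported from utils) is derived from this MSE:
    # PSNR = 10 * log10(MAX^2 / MSE), so driving MSE down directly pushes PSNR up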
    self.saver = tf.train.Saver()
  # Called from main for both training and testing
  def train(self, config):
    if config.is_train:  # Check whether this is a training run (set in main)
      input_setup(self.sess, config)
    else:
      nx, ny = input_setup(self.sess, config)  
    # Training reads train.h5 under the checkpoint directory;
    # testing reads test.h5 under the checkpoint directory
    if config.is_train:     
      data_dir = os.path.join('./{}'.format(config.checkpoint_dir), "train.h5")
    else:
      data_dir = os.path.join('./{}'.format(config.checkpoint_dir), "test.h5")
    train_data, train_label = read_data(data_dir)  # Read the .h5 file (train.h5 or test.h5, depending on the mode)
    global_step = tf.Variable(0, trainable=False)  # Define global_step; passing it to minimize() increments it automatically
    # Generate a decaying learning rate with exponential_decay
    learning_rate_exp = tf.train.exponential_decay(config.learning_rate, global_step, 1480, 0.98, staircase=True)  # Multiply the learning rate by 0.98 every 1480 steps (about one epoch); a worked example follows the optimizer setup below
    # Stochastic gradient descent with standard backpropagation
    # self.train_op = tf.train.GradientDescentOptimizer(config.learning_rate).minimize(self.loss)  # Fixed learning rate; minimizes self.loss
    self.train_op = tf.train.GradientDescentOptimizer(learning_rate_exp).minimize(self.loss, global_step=global_step)
    # Adam: replaces the four consecutive lines above
    #self.train_op = tf.train.AdamOptimizer(config.learning_rate).minimize(self.loss, global_step=global_step)
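    # With staircase=True the decayed rate is a step function of global_step:
    #   lr(step) = config.learning_rate * 0.98 ** (step // 1480)
    # e.g. lr keeps its initial value for steps 0-1479, then drops to 0.98x, then 0.98^2 x, and so on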
    
    # The original tf.initialize_all_variables() triggers a warning: initialize_all_variables (from tensorflow.python.ops.variables) is deprecated and will be removed after 2017-03-02.
    #tf.initialize_all_variables().run()
    tf.global_variables_initializer().run()  # Replaces the deprecated call above
    counter = 0
    start_time = time.time()
    if self.load(self.checkpoint_dir):
      print(" [*] Load SUCCESS")
    else:
      print(" [!] Load failed...")
    # Training
    if config.is_train:     
      print("Training...")
      for ep in range(config.epoch):  # Loop over epochs
        # Iterate batch by batch
        batch_idxs = len(train_data) // config.batch_size
        for idx in range(0, batch_idxs):
          batch_images = train_data[idx*config.batch_size : (idx+1)*config.batch_size]
          batch_labels = train_label[idx*config.batch_size : (idx+1)*config.batch_size]
          counter += 1
          _, err = self.sess.run([self.train_op, self.loss], feed_dict={self.images: batch_images, self.labels: batch_labels})
          if counter % 10 == 0:  # Print progress every 10 steps
            print("Epoch: [%2d], step: [%2d], time: [%4.4f], loss: [%.8f]" \
              % ((ep+1), counter, time.time()-start_time, err))
          if counter % 500 == 0:  # Save a checkpoint every 500 steps
            self.save(config.checkpoint_dir, counter)
    # Testing
    else:   
      print("Testing...")
      result = self.pred.eval({self.images: train_data, self.labels: train_label})  # In test mode these come from test.h5
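      # merge() from utils stitches the nx*ny predicted sub-image patches back into one full image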
      result = merge(result, [nx, ny])
      result = result.squeeze()  # Remove dimensions of size 1
      # result = exposure.adjust_gamma(result, 1.07)  # Darken slightly
      image_path = os.path.join(os.getcwd(), config.sample_dir)
      image_path = os.path.join(image_path, "MySRCNN.bmp")
      imsave(result, image_path)
  def model(self):
    conv1 = tf.nn.relu(tf.nn.conv2d(self.images, self.weights['w1'], strides=[1,1,1,1], padding='VALID') + self.biases['b1'])
    conv2 = tf.nn.relu(tf.nn.conv2d(conv1, self.weights['w2'], strides=[1,1,1,1], padding='VALID') + self.biases['b2'])
    conv3 = tf.nn.conv2d(conv2, self.weights['w3'], strides=[1,1,1,1], padding='VALID') + self.biases['b3']
    return conv3

  def save(self, checkpoint_dir, step):
    model_name = "SRCNN.model"
    model_dir = "%s_%s" % ("srcnn", self.label_size)
    checkpoint_dir = os.path.join(checkpoint_dir, model_dir)  # Final path is checkpoint/srcnn_21
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    self.saver.save(self.sess,
                    os.path.join(checkpoint_dir, model_name),  # Saved as SRCNN.model-<step>
                    global_step=step)
  def load(self, checkpoint_dir):
    print(" [*] Reading checkpoints...")
    model_dir = "%s_%s" % ("srcnn", self.label_size)
    checkpoint_dir = os.path.join(checkpoint_dir, model_dir)  # Path is checkpoint/srcnn_<label_size>, i.e. checkpoint/srcnn_21
    # Load the model found at this path (.meta stores the graph structure; .index stores the parameter names; .data stores the parameter values)
    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)  
    if ckpt and ckpt.model_checkpoint_path:
        ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
        self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))  # Given the model-<n> checkpoint prefix, saver.restore() automatically locates the index/data files and loads the parameters
        return True
    else:
        return False
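To see how this class is driven end to end, here is a minimal usage sketch. It assumes a FLAGS configuration object built with tf.app.flags, standing in for the repo's main.py; the flag names mirror the config attributes referenced in train() and input_setup(), and the values are illustrative rather than the repo's exact defaults.

# Minimal driver sketch (hypothetical; main.py in the repo plays this role)
import tensorflow as tf

flags = tf.app.flags
flags.DEFINE_boolean("is_train", True, "True for training, False for testing")
flags.DEFINE_integer("epoch", 15000, "Number of training epochs")
flags.DEFINE_integer("batch_size", 128, "Batch size")
flags.DEFINE_integer("image_size", 33, "Size of the input sub-images")
flags.DEFINE_integer("label_size", 21, "Size of the label sub-images")
flags.DEFINE_float("learning_rate", 1e-4, "Initial learning rate")
flags.DEFINE_integer("c_dim", 1, "Number of image channels")
flags.DEFINE_integer("scale", 3, "Upscaling factor")
flags.DEFINE_integer("stride", 14, "Stride used when cropping sub-images")
flags.DEFINE_string("checkpoint_dir", "checkpoint", "Directory for checkpoints")
flags.DEFINE_string("sample_dir", "sample", "Directory for output samples")
FLAGS = flags.FLAGS

with tf.Session() as sess:
  srcnn = SRCNN(sess,
                image_size=FLAGS.image_size,
                label_size=FLAGS.label_size,
                batch_size=FLAGS.batch_size,
                c_dim=FLAGS.c_dim,
                checkpoint_dir=FLAGS.checkpoint_dir,
                sample_dir=FLAGS.sample_dir)
  srcnn.train(FLAGS)  # Both training and testing go through train(); FLAGS.is_train selects the branch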
