SRCNN 代码解读【main.py】

最近在学习SRCNN,阅读代码做好笔记
代码下载链接https://github.com/tegg89/SRCNN-Tensorflow
下面开始

from model import SRCNN
from utils import input_setup
import numpy as np
import tensorflow as tf
import pprint
import os
flags = tf.app.flags
flags.DEFINE_integer("epoch", 2000,"训练多少波")
#flags.DEFINE_integer("batch_size", 128, "The size of batch images [128]")
#一开始将batch size设为128和64,不仅参数初始loss很大,而且往往一段时间后训练就发散
#batch中每个样本产生梯度竞争可能比较激烈,所以导致了收敛过慢
#后来改回了128
flags.DEFINE_integer("batch_size", 128, "batch size")
flags.DEFINE_integer("image_size", 33, "图像使用的尺寸")
flags.DEFINE_integer("label_size", 21, "label_制作的尺寸")
#学习率文中设置为 前两层1e-4 第三层1e-5
#SGD+指数学习率10-2作为初始
flags.DEFINE_float("learning_rate", 1e-2, "学习率")
flags.DEFINE_integer("c_dim", 1, "图像维度")
flags.DEFINE_integer("scale", 3, "sample的scale大小")
#stride训练采用14,测试采用21
flags.DEFINE_integer("stride", 21 , "步长为14或者21")
flags.DEFINE_string("checkpoint_dir", "checkpoint", "checkpoint directory名字")
flags.DEFINE_string("sample_dir", "sample", "sample directory名字")
flags.DEFINE_boolean("is_train", False, "True for training, False for testing")#测试
#flags.DEFINE_boolean("is_train", True, "True for training, False for testing")#训练
FLAGS = flags.FLAGS
pp = pprint.PrettyPrinter()
def main(_):
  pp.pprint(flags.FLAGS.__flags)
  if not os.path.exists(FLAGS.checkpoint_dir):
    os.makedirs(FLAGS.checkpoint_dir)
  if not os.path.exists(FLAGS.sample_dir):
    os.makedirs(FLAGS.sample_dir)
  with tf.Session() as sess:
    srcnn = SRCNN(sess, 
                  image_size=FLAGS.image_size, 
                  label_size=FLAGS.label_size, 
                  batch_size=FLAGS.batch_size,
                  c_dim=FLAGS.c_dim, 
                  checkpoint_dir=FLAGS.checkpoint_dir,
                  sample_dir=FLAGS.sample_dir)
    srcnn.train(FLAGS)
if __name__ == '__main__':
  tf.app.run()

你可能感兴趣的:(超分重建)