RNN代码解读之char-RNN with TensorFlow(train.py)

此工程解读链接(建议按顺序阅读):
RNN代码解读之char-RNN with TensorFlow(model.py)
RNN代码解读之char-RNN with TensorFlow(train.py)
RNN代码解读之char-RNN with TensorFlow(util.py)
RNN代码解读之char-RNN with TensorFlow(sample.py)

前面我们看完了model.py的代码,大家可能会产生一个疑惑,那就是模型的参数是怎么传进去的呢?在训练的时候怎么从以往的checkpoint继续训练呢?其实这些很简单,都在train.py里实现,代码比model里面的代码好理解的多。

和以前一样,我将理解写进了注释,欢迎大家的指正。

#-*-coding:utf-8-*-
from __future__ import print_function
import numpy as np
import tensorflow as tf

import argparse
import time
import os
from six.moves import cPickle

from utils import TextLoader
from model import Model

def main():
    #命令行参数选项
    #每个参数的"help"说明里都有了详细的解释
    parser = argparse.ArgumentParser()
    #输入数据路径
    parser.add_argument('--data_dir', type=str, default='data/tinyshakespeare',
                       help='data directory containing input.txt')
    #存储模型的路径
    parser.add_argument('--save_dir', type=str, default='save',
                       help='directory to store checkpointed models')
    #rnn的cell内神经元数目
    parser.add_argument('--rnn_size', type=int, default=128,
                       help='size of RNN hidden state')
    #rnn层数
    parser.add_argument('--num_layers', type=int, default=2,
                       help='number of layers in the RNN')
    #rnn类型
    parser.add_argument('--model', type=str, default='lstm',
                       help='rnn, gru, or lstm')
    #batch size。。。
    parser.add_argument('--batch_size', type=int, default=50,
                       help='minibatch size')
    #每个序列的长度
    parser.add_argument('--seq_length', type=int, default=50,
                       help='RNN sequence length')
    #epoch数目
    parser.add_argument('--num_epochs', type=int, default=50,
                       help='number of epochs')
    #保存模型的频率
    parser.add_argument('--save_every', type=int, default=1000,
                       help='save frequency')
    #梯度clip(防止梯度爆炸)
    parser.add_argument('--grad_clip', type=float, default=5.,
                       help='clip gradients at this value')
    #学习率
    parser.add_argument('--learning_rate', type=float, default=0.002,
                       help='learning rate')
    #学习率削减时用到的参数
    parser.add_argument('--decay_rate', type=float, default=0.97,
                       help='decay rate for rmsprop')
    #训练模型时的起始文件
    parser.add_argument('--init_from', type=str, default=None,
                       help="""continue training from saved model at this path. Path must contain files saved by previous training process: 
                            'config.pkl'        : configuration;
                            'chars_vocab.pkl'   : vocabulary definitions;
                            'checkpoint'        : paths to model file(s) (created by tf).
                                                  Note: this file contains absolute paths, be careful when moving files around;
                            'model.ckpt-*'      : file(s) with model definition (created by tf)
                        """)
    args = parser.parse_args()
    train(args)

def train(args):
    #加载数据,解释详见util.py文件
    data_loader = TextLoader(args.data_dir, args.batch_size, args.seq_length)
    args.vocab_size = data_loader.vocab_size

    # check compatibility if training is continued from previously saved model
    #如果要从原来的模型基础上继续训练的话,执行这一程序块
    if args.init_from is not None:
        # check if all necessary files exist 
        assert os.path.isdir(args.init_from)," %s must be a a path" % args.init_from
        assert os.path.isfile(os.path.join(args.init_from,"config.pkl")),"config.pkl file does not exist in path %s"%args.init_from
        assert os.path.isfile(os.path.join(args.init_from,"chars_vocab.pkl")),"chars_vocab.pkl.pkl file does not exist in path %s" % args.init_from
        ckpt = tf.train.get_checkpoint_state(args.init_from)
        assert ckpt,"No checkpoint found"
        assert ckpt.model_checkpoint_path,"No model path found in checkpoint"

        # open old config and check if models are compatible
        with open(os.path.join(args.init_from, 'config.pkl')) as f:
            saved_model_args = cPickle.load(f)
        need_be_same=["model","rnn_size","num_layers","seq_length"]
        for checkme in need_be_same:
            assert vars(saved_model_args)[checkme]==vars(args)[checkme],"Command line argument and saved model disagree on '%s' "%checkme

        # open saved vocab/dict and check if vocabs/dicts are compatible
        with open(os.path.join(args.init_from, 'chars_vocab.pkl')) as f:
            saved_chars, saved_vocab = cPickle.load(f)
        assert saved_chars==data_loader.chars, "Data and loaded model disagree on character set!"
        assert saved_vocab==data_loader.vocab, "Data and loaded model disagree on dictionary mappings!"
    #这里的.pkl文件我们在model.py里面保存模型的时候还会看到~
    #没错,这里就是把以前保存的模型又调了出来
    with open(os.path.join(args.save_dir, 'config.pkl'), 'wb') as f:
        cPickle.dump(args, f)
    with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'wb') as f:
        cPickle.dump((data_loader.chars, data_loader.vocab), f)

    #我们的模型文件,类似caffe里面的train_net.prototxt文件
    #前面的参数传递进去,具体解释详见model.py
    model = Model(args)

    #创建一个session,关于什么是session,这是tf使用中必不可少的一个环节,
    # just google it
    with tf.Session() as sess:
        #所有变量初始化并运行
        tf.initialize_all_variables().run()
        #创建一个saver,便于后面的模型保存和重载
        saver = tf.train.Saver(tf.all_variables())
        # restore model
        #模型重载
        if args.init_from is not None:
            saver.restore(sess, ckpt.model_checkpoint_path)
        #e代表每个epoch
        for e in range(args.num_epochs):
            #学习率的dacay,lr = lr*decay_rate^e
            sess.run(tf.assign(model.lr, args.learning_rate * (args.decay_rate ** e)))
            #状态都初始化
            data_loader.reset_batch_pointer()
            state = sess.run(model.initial_state)
            #提出不同的batch然后feed进model
            #b代表每个batch
            for b in range(data_loader.num_batches):
                start = time.time()
                x, y = data_loader.next_batch()
                #通过feed传递参数
                feed = {model.input_data: x, model.targets: y}
                for i, (c, h) in enumerate(model.initial_state):
                    feed[c] = state[i].c
                    feed[h] = state[i].h
                #运行模型,跑出来一个结果
                train_loss, state, _ = sess.run([model.cost, model.final_state, model.train_op], feed)
                end = time.time()
                #每训练一步进行输出
                print("{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                    .format(e * data_loader.num_batches + b,
                            args.num_epochs * data_loader.num_batches,
                            e, train_loss, end - start))
                #当到达保存步数或训练到最后一步时保存模型
                if (e * data_loader.num_batches + b) % args.save_every == 0\
                    or (e==args.num_epochs-1 and b == data_loader.num_batches-1): # save for the last result
                    checkpoint_path = os.path.join(args.save_dir, 'model.ckpt')
                    saver.save(sess, checkpoint_path, global_step = e * data_loader.num_batches + b)
                    print("model saved to {}".format(checkpoint_path))

if __name__ == '__main__':
    main()

那么现在问题又出现了,我们在代码中看到了很多和input有关系的张量,这些张量的形式是什么样的呢?通过追溯发现在68行通过TextLoader传进来的,在util.py中,那下一章就来看一下这个文件。

参考资料:
http://blog.csdn.net/mydear_11000/article/details/52776295
https://github.com/sherjilozair/char-rnn-tensorflow
http://www.tensorfly.cn/tfdoc/api_docs/python/constant_op.html#truncated_normal

你可能感兴趣的:(RNN,TensorFlow)