此工程解读链接(建议按顺序阅读):
RNN代码解读之char-RNN with TensorFlow(model.py)
RNN代码解读之char-RNN with TensorFlow(train.py)
RNN代码解读之char-RNN with TensorFlow(util.py)
RNN代码解读之char-RNN with TensorFlow(sample.py)
前面我们看完了model.py的代码,大家可能会产生一个疑惑:模型的参数是怎么传进去的呢?训练的时候又怎么从以往的checkpoint继续训练呢?其实这些都很简单,全部在train.py里实现,代码比model.py里面的代码好理解得多。
和以前一样,我把自己的理解写进了注释,欢迎大家指正。
#-*-coding:utf-8-*-
from __future__ import print_function
import numpy as np
import tensorflow as tf
import argparse
import time
import os
from six.moves import cPickle
from utils import TextLoader
from model import Model
def main():
    """Parse the command-line options and start training.

    Every option is declared in a single table (flag, type, default, help)
    and registered on the parser in table order; the parsed namespace is
    then passed straight to train().
    """
    # Help text for --init_from; kept as one literal because argparse shows
    # it verbatim.
    init_from_help = """continue training from saved model at this path. Path must contain files saved by previous training process:
'config.pkl' : configuration;
'chars_vocab.pkl' : vocabulary definitions;
'checkpoint' : paths to model file(s) (created by tf).
Note: this file contains absolute paths, be careful when moving files around;
'model.ckpt-*' : file(s) with model definition (created by tf)
"""

    option_table = [
        # Directory holding the input data.
        ('--data_dir', str, 'data/tinyshakespeare',
         'data directory containing input.txt'),
        # Directory where checkpoints are written.
        ('--save_dir', str, 'save',
         'directory to store checkpointed models'),
        # Number of units in each RNN cell.
        ('--rnn_size', int, 128,
         'size of RNN hidden state'),
        # Number of stacked RNN layers.
        ('--num_layers', int, 2,
         'number of layers in the RNN'),
        # Cell type.
        ('--model', str, 'lstm',
         'rnn, gru, or lstm'),
        # Mini-batch size.
        ('--batch_size', int, 50,
         'minibatch size'),
        # Length of each training sequence.
        ('--seq_length', int, 50,
         'RNN sequence length'),
        # Number of passes over the data.
        ('--num_epochs', int, 50,
         'number of epochs'),
        # How often (in steps) a checkpoint is saved.
        ('--save_every', int, 1000,
         'save frequency'),
        # Gradient clipping threshold (guards against exploding gradients).
        ('--grad_clip', float, 5.0,
         'clip gradients at this value'),
        # Initial learning rate.
        ('--learning_rate', float, 0.002,
         'learning rate'),
        # Per-epoch multiplier applied to the learning rate.
        ('--decay_rate', float, 0.97,
         'decay rate for rmsprop'),
        # Directory of a previous run to resume from.
        ('--init_from', str, None, init_from_help),
    ]

    parser = argparse.ArgumentParser()
    for flag, value_type, fallback, description in option_table:
        parser.add_argument(flag, type=value_type, default=fallback,
                            help=description)

    train(parser.parse_args())
def train(args):
    """Train the character-level RNN configured by ``args``.

    Loads the corpus through TextLoader, optionally validates and restores a
    previous run from ``args.init_from``, writes the run configuration and
    vocabulary into ``args.save_dir``, and then runs ``args.num_epochs``
    epochs. The loss is printed after every batch; a checkpoint is written
    every ``args.save_every`` steps and after the very last step.

    Args:
        args: argparse namespace built by main(). ``args.vocab_size`` is set
            here from the loaded data before the model is constructed.

    Raises:
        AssertionError: if ``args.init_from`` is given but is not a
            directory, lacks ``config.pkl`` / ``chars_vocab.pkl`` / a
            checkpoint, or its saved settings or vocabulary disagree with
            the current arguments or data.
    """
    # Load the data (see the TextLoader class imported from utils).
    data_loader = TextLoader(args.data_dir, args.batch_size, args.seq_length)
    args.vocab_size = data_loader.vocab_size

    # Only set when resuming; bound up front so the name always exists.
    ckpt = None

    # When continuing from a previously saved model, check compatibility.
    if args.init_from is not None:
        # Check that all necessary files exist.
        assert os.path.isdir(args.init_from), " %s must be a path" % args.init_from
        assert os.path.isfile(os.path.join(args.init_from, "config.pkl")), \
            "config.pkl file does not exist in path %s" % args.init_from
        assert os.path.isfile(os.path.join(args.init_from, "chars_vocab.pkl")), \
            "chars_vocab.pkl file does not exist in path %s" % args.init_from
        ckpt = tf.train.get_checkpoint_state(args.init_from)
        assert ckpt, "No checkpoint found"
        assert ckpt.model_checkpoint_path, "No model path found in checkpoint"

        # NOTE(security): unpickling executes arbitrary code; only point
        # --init_from at directories written by a trusted training run.
        # Pickle data is binary, so the files must be opened with 'rb'
        # (text mode breaks on Python 3 and can corrupt reads on Windows).
        with open(os.path.join(args.init_from, 'config.pkl'), 'rb') as f:
            saved_model_args = cPickle.load(f)
        # These settings fix the graph's shape, so they must match.
        need_be_same = ["model", "rnn_size", "num_layers", "seq_length"]
        for checkme in need_be_same:
            assert vars(saved_model_args)[checkme] == vars(args)[checkme], \
                "Command line argument and saved model disagree on '%s' " % checkme

        # Open the saved vocabulary and check it matches the current data.
        with open(os.path.join(args.init_from, 'chars_vocab.pkl'), 'rb') as f:
            saved_chars, saved_vocab = cPickle.load(f)
        assert saved_chars == data_loader.chars, \
            "Data and loaded model disagree on character set!"
        assert saved_vocab == data_loader.vocab, \
            "Data and loaded model disagree on dictionary mappings!"

    # Create the output directory on first use; without this the writes
    # below fail on a fresh run. An empty save_dir means the current
    # directory and needs no creation.
    if args.save_dir and not os.path.isdir(args.save_dir):
        os.makedirs(args.save_dir)

    # Persist this run's configuration and vocabulary next to the
    # checkpoints; these are the files the --init_from checks above read.
    with open(os.path.join(args.save_dir, 'config.pkl'), 'wb') as f:
        cPickle.dump(args, f)
    with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'wb') as f:
        cPickle.dump((data_loader.chars, data_loader.vocab), f)

    # Build the model from the parsed arguments (see model.py).
    model = Model(args)

    with tf.Session() as sess:
        # Initialise every variable in the graph.
        tf.initialize_all_variables().run()
        # Saver used both for restoring and for writing checkpoints.
        saver = tf.train.Saver(tf.all_variables())
        # Restore the weights when resuming from a previous run.
        if args.init_from is not None:
            saver.restore(sess, ckpt.model_checkpoint_path)

        total_steps = args.num_epochs * data_loader.num_batches
        # e is the epoch index.
        for e in range(args.num_epochs):
            # Learning-rate decay: lr = learning_rate * decay_rate ** e.
            sess.run(tf.assign(model.lr, args.learning_rate * (args.decay_rate ** e)))
            # Restart from the first batch with a fresh initial state.
            data_loader.reset_batch_pointer()
            state = sess.run(model.initial_state)
            # b is the batch index within the epoch.
            for b in range(data_loader.num_batches):
                step = e * data_loader.num_batches + b
                start = time.time()
                x, y = data_loader.next_batch()
                # Feed the inputs and targets for this batch.
                feed = {model.input_data: x, model.targets: y}
                # Carry the previous batch's final state into this batch.
                # Assumes each layer's state is a (c, h) pair, i.e. LSTM
                # state tuples -- TODO confirm behaviour for rnn/gru cells.
                for i, (c, h) in enumerate(model.initial_state):
                    feed[c] = state[i].c
                    feed[h] = state[i].h
                # Run one optimisation step.
                train_loss, state, _ = sess.run(
                    [model.cost, model.final_state, model.train_op], feed)
                end = time.time()
                # Report progress after every step.
                print("{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}"
                      .format(step, total_steps, e, train_loss, end - start))
                # Save every `save_every` steps and after the final step.
                is_last_step = (e == args.num_epochs - 1
                                and b == data_loader.num_batches - 1)
                if step % args.save_every == 0 or is_last_step:
                    checkpoint_path = os.path.join(args.save_dir, 'model.ckpt')
                    saver.save(sess, checkpoint_path, global_step=step)
                    print("model saved to {}".format(checkpoint_path))
if __name__ == '__main__':
    # Start training only when this file is executed as a script, not when
    # it is imported as a module.
    main()
那么现在问题又出现了:我们在代码中看到了很多和input有关的张量,这些张量的形式是什么样的呢?通过追溯可以发现,它们是在train()函数开头通过TextLoader传进来的,而TextLoader定义在utils.py(即本系列标题中的util.py)中,那下一章就来看一下这个文件。
参考资料:
http://blog.csdn.net/mydear_11000/article/details/52776295
https://github.com/sherjilozair/char-rnn-tensorflow
http://www.tensorfly.cn/tfdoc/api_docs/python/constant_op.html#truncated_normal