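This article uses MNIST as an example to show how to train a model and run prediction with TFRecord-format data and the tf.estimator API.
References:
1. https://tensorflow.google.cn/tutorials/estimators/cnn
Contents
1. Data input
2. Model definition
3. Model training and validation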
import tensorflow as tf

# parse_data and FLAGS are defined elsewhere in the project.
def input_fn(filenames, training):
    dataset = tf.data.TFRecordDataset(filenames)
    dataset = dataset.map(parse_data)
    if training:
        dataset = dataset.shuffle(buffer_size=50000)
    dataset = dataset.batch(FLAGS.batch_size)
    if training:
        dataset = dataset.repeat()
    iterator = dataset.make_one_shot_iterator()
    features, labels = iterator.get_next()
    return features, labels
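The tf.data API is used to parse the TFRecord data. For the implementation of parse_data, see the previous post: TensorFlow Learning in Practice (2): Training and Prediction with TFRecord Data and the tf.data API. The official documentation (datasets) states that dataset.make_one_shot_iterator() should be used together with an estimator: "Note: Currently, one-shot iterators are the only type that is easily usable with an Estimator."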
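The order of operations in input_fn (shuffle, then batch, then repeat) determines what the estimator sees. The following is a minimal pure-Python sketch of that order, not tf.data itself. The name toy_input_fn and its parameters are hypothetical, and the shuffle is modeled as a full shuffle, which assumes the shuffle buffer is at least as large as the dataset.

```python
import itertools
import random


def toy_input_fn(records, batch_size, training, seed=0):
    """Yield batches in the same order of operations as input_fn:
    (shuffle if training) -> batch -> (repeat if training)."""
    items = list(records)
    if not items:
        raise ValueError("records must not be empty")
    rng = random.Random(seed)

    def one_epoch():
        epoch_items = list(items)
        if training:
            # Stands in for dataset.shuffle with a buffer >= dataset size.
            rng.shuffle(epoch_items)
        # Stands in for dataset.batch: the last batch of an epoch may be short.
        for start in range(0, len(epoch_items), batch_size):
            yield epoch_items[start:start + batch_size]

    if training:
        while True:  # stands in for dataset.repeat()
            yield from one_epoch()
    else:
        yield from one_epoch()


# Evaluation: a single ordered pass that then stops.
eval_batches = list(toy_input_fn(range(10), batch_size=4, training=False))
print(eval_batches)

# Training: an endless stream, so take only the first 7 batches.
train_batches = list(itertools.islice(
    toy_input_fn(range(10), batch_size=4, training=True), 7))
print([len(b) for b in train_batches])
```

Because batching happens before repeating, every epoch ends with its own short batch rather than borrowing examples from the next epoch. In evaluation mode the stream simply ends, which is what lets evaluate() stop on its own.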
def model_fn(features, labels, mode):
    with tf.variable_scope('conv1'):
        conv1 = tf.layers.conv2d(inputs=features,
                                 filters=32,
                                 kernel_size=[5, 5],
                                 padding='same',
                                 activation=tf.nn.relu)
        pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2)  # 14*14*32
    with tf.variable_scope('conv2'):
        conv2 = tf.layers.conv2d(inputs=pool1,
                                 filters=64,
                                 kernel_size=[5, 5],
                                 padding='same',
                                 activation=tf.nn.relu)
        pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2)  # 7*7*64
    with tf.variable_scope('fc1'):
        pool2_flat = tf.reshape(pool2, [-1, 7*7*64])
        fc1 = tf.layers.dense(inputs=pool2_flat, units=1024, activation=tf.nn.relu)
        # Dropout is only active in TRAIN mode.
        dropout1 = tf.layers.dropout(inputs=fc1, rate=0.4, training=mode == tf.estimator.ModeKeys.TRAIN)
    with tf.variable_scope('logits'):
        logits = tf.layers.dense(inputs=dropout1, units=10)  # used to compute the cross-entropy loss
        predict = tf.nn.softmax(logits)
    predictions = {
        # Generate predictions (for PREDICT and EVAL mode)
        "classes": tf.argmax(input=logits, axis=1),
        # Add `softmax_tensor` to the graph. It is used for PREDICT and by the
        # `logging_hook`.
        "probabilities": tf.nn.softmax(logits, name="softmax_tensor")
    }
    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
    accuracy = tf.metrics.accuracy(labels=labels, predictions=predictions["classes"])
    tf.summary.scalar('accuracy', accuracy[1])
    if mode == tf.estimator.ModeKeys.TRAIN:
        global_step = tf.train.get_global_step()
        # train() is a helper defined elsewhere that builds the training op.
        train_op = train(loss, global_step)
        return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)
    # Add evaluation metrics (for EVAL mode)
    eval_metric_ops = {"eval_accuracy": accuracy}
    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)
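model_fn returns a different EstimatorSpec depending on the value of tf.estimator.ModeKeys: only predictions for PREDICT, loss and train_op for TRAIN, and loss and eval_metric_ops for EVAL.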
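The shape comments in model_fn (14*14*32 and 7*7*64) and the reshape to 7*7*64 can be checked with a little arithmetic. The sketch below is plain Python and assumes that features has shape [batch, 28, 28, 1], which is not shown in this article since parse_data lives in the previous post. The helper names are hypothetical. The parameter count assumes the default biases of tf.layers.conv2d and tf.layers.dense.

```python
import math


def conv_same(size, stride=1):
    """Spatial size after a convolution with padding='same'."""
    return math.ceil(size / stride)


def pool_valid(size, pool=2, stride=2):
    """Spatial size after pooling with padding='valid' (the tf.layers default)."""
    return (size - pool) // stride + 1


def conv_params(kernel, c_in, c_out):
    """Weights plus biases of a square-kernel convolution."""
    return kernel * kernel * c_in * c_out + c_out


def dense_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out


side = 28  # assumption: 28x28 single-channel MNIST images
side = pool_valid(conv_same(side))  # conv1 + pool1
after_pool1 = (side, side, 32)
side = pool_valid(conv_same(side))  # conv2 + pool2
after_pool2 = (side, side, 64)
flat = side * side * 64  # size of pool2_flat per example

total_params = (conv_params(5, 1, 32) + conv_params(5, 32, 64)
                + dense_params(flat, 1024) + dense_params(1024, 10))

print(after_pool1, after_pool2, flat, total_params)
```

Under these assumptions almost all of the parameters sit in the fc1 layer, which is why the tutorial network places dropout there.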
def train():
    my_checkpoint_config = tf.estimator.RunConfig(save_checkpoints_steps=100, keep_checkpoint_max=5)
    mnist_classifier = tf.estimator.Estimator(model_fn=mnist.model_fn, model_dir=FLAGS.train_dir,
                                              config=my_checkpoint_config)
    tensor_to_log = {'probabilities': 'softmax_tensor'}
    logging_hook = tf.train.LoggingTensorHook(tensors=tensor_to_log, every_n_iter=100)
    mnist_classifier.train(input_fn=lambda: mnist.input_fn(['./train_img.tfrecords'], True),
                           # hooks=[logging_hook],
                           steps=FLAGS.max_step)
    eval_results = mnist_classifier.evaluate(input_fn=lambda: mnist.input_fn(['./validation_img.tfrecords'], False))
    print(eval_results)
Training output:
...
INFO:tensorflow:loss = 4.674489e-05, step = 11500 (0.549 sec)
INFO:tensorflow:Saving checkpoints for 11600 into ./train\model.ckpt.
INFO:tensorflow:global_step/sec: 181.878
INFO:tensorflow:loss = 0.0001392595, step = 11600 (0.550 sec)
INFO:tensorflow:Saving checkpoints for 11700 into ./train\model.ckpt.
INFO:tensorflow:global_step/sec: 73.6977
INFO:tensorflow:loss = 1.4009732e-05, step = 11700 (1.356 sec)
INFO:tensorflow:Saving checkpoints for 11800 into ./train\model.ckpt.
INFO:tensorflow:global_step/sec: 179.256
INFO:tensorflow:loss = 0.00017982977, step = 11800 (0.558 sec)
INFO:tensorflow:Saving checkpoints for 11900 into ./train\model.ckpt.
INFO:tensorflow:global_step/sec: 182.132
INFO:tensorflow:loss = 0.00027710196, step = 11900 (0.549 sec)
INFO:tensorflow:Saving checkpoints for 12000 into ./train\model.ckpt.
INFO:tensorflow:Loss for final step: 4.0867322e-05.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2018-08-29-13:18:34
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from ./train\model.ckpt-12000
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Finished evaluation at 2018-08-29-13:18:34
INFO:tensorflow:Saving dict for global step 12000: eval_accuracy = 0.9934, global_step = 12000, loss = 0.0540578
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 12000: ./train\model.ckpt-12000
{'eval_accuracy': 0.9934, 'loss': 0.0540578, 'global_step': 12000}
Process finished with exit code 0
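Training ran for 12,000 steps with a batch_size of 128. Evaluating on the validation set after training finished gave an accuracy of 99.34%.
Here the evaluation only runs once training has finished. I wanted to evaluate on the validation set every 100 training steps to watch the accuracy, but I could not find how to do this in the official documentation. The documentation does state that each repeated call to tf.estimator.Estimator().train() accumulates training steps, that is, training continues from the result of the previous call.
In other words, the following two forms train for the same total number of steps: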
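To put 12,000 steps at batch size 128 into perspective, the step count can be converted into epochs. The sketch below is only an estimate: the number of training examples is not given in this article, so NUM_TRAIN_EXAMPLES = 55000 is an assumption (a commonly used MNIST training split), and the helper names are hypothetical. Because input_fn batches before it repeats, each epoch is rounded up to a whole number of steps.

```python
import math


def steps_per_epoch(num_examples, batch_size):
    """Batching before repeating means the last batch of each epoch may be short."""
    return math.ceil(num_examples / batch_size)


def epochs_trained(total_steps, num_examples, batch_size):
    """Number of passes over the training set after total_steps steps."""
    return total_steps / steps_per_epoch(num_examples, batch_size)


MAX_STEP = 12000            # from the training log
BATCH_SIZE = 128            # from the text
NUM_TRAIN_EXAMPLES = 55000  # assumption, not stated in the article

per_epoch = steps_per_epoch(NUM_TRAIN_EXAMPLES, BATCH_SIZE)
epochs = epochs_trained(MAX_STEP, NUM_TRAIN_EXAMPLES, BATCH_SIZE)
print(per_epoch, round(epochs, 1))
```

With a different training set size only NUM_TRAIN_EXAMPLES needs to change.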
mnist_classifier.train(input_fn=lambda: mnist.input_fn(['./train_img.tfrecords'], True), steps=100)
mnist_classifier.train(input_fn=lambda: mnist.input_fn(['./train_img.tfrecords'], True), steps=100)
mnist_classifier.train(input_fn=lambda: mnist.input_fn(['./train_img.tfrecords'], True), steps=100)
and
mnist_classifier.train(input_fn=lambda: mnist.input_fn(['./train_img.tfrecords'], True), steps=300)
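The accumulation of steps across train() calls can be pictured with a toy model. This is not the Estimator implementation: ToyEstimator is a hypothetical stand-in whose "checkpoint" is just a dict, and each train() call is modeled as opening a fresh session, restoring the saved global step, running the requested number of steps, and saving again.

```python
class ToyEstimator:
    """Hypothetical stand-in for an Estimator with a dict as its checkpoint."""

    def __init__(self):
        self.checkpoint = {"global_step": 0}
        self.sessions_created = 0

    def train(self, steps):
        self.sessions_created += 1              # every call starts a new session
        step = self.checkpoint["global_step"]   # restore from the last checkpoint
        for _ in range(steps):
            step += 1                           # one training step
        self.checkpoint["global_step"] = step   # save a new checkpoint
        return self


# Three calls of 100 steps each.
a = ToyEstimator()
for _ in range(3):
    a.train(steps=100)

# One call of 300 steps.
b = ToyEstimator()
b.train(steps=300)

print(a.checkpoint["global_step"], b.checkpoint["global_step"])
print(a.sessions_created, b.sessions_created)
```

Both variants end at the same global step. The difference is in how many times a session is set up and the checkpoint is restored, which grows with the number of train() calls.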
So I modified the training code as follows:
def train():
    my_checkpoint_config = tf.estimator.RunConfig(save_checkpoints_steps=100, keep_checkpoint_max=5)
    mnist_classifier = tf.estimator.Estimator(model_fn=mnist.model_fn, model_dir=FLAGS.train_dir,
                                              config=my_checkpoint_config)
    tensor_to_log = {'probabilities': 'softmax_tensor'}
    logging_hook = tf.train.LoggingTensorHook(tensors=tensor_to_log, every_n_iter=100)
    for i in range(FLAGS.max_step//100):
        mnist_classifier.train(input_fn=lambda: mnist.input_fn(['./train_img.tfrecords'], True),
                               # hooks=[logging_hook],
                               steps=100)
        eval_results = mnist_classifier.evaluate(input_fn=lambda: mnist.input_fn(['./validation_img.tfrecords'], False))
        print(eval_results)
Training output:
...
{'eval_accuracy': 0.9922, 'loss': 0.06816594, 'global_step': 11300}
{'eval_accuracy': 0.9922, 'loss': 0.068535455, 'global_step': 11400}
{'eval_accuracy': 0.9922, 'loss': 0.06853329, 'global_step': 11500}
{'eval_accuracy': 0.9924, 'loss': 0.06850766, 'global_step': 11600}
{'eval_accuracy': 0.9922, 'loss': 0.068637684, 'global_step': 11700}
{'eval_accuracy': 0.992, 'loss': 0.069534324, 'global_step': 11800}
{'eval_accuracy': 0.9916, 'loss': 0.07054804, 'global_step': 11900}
{'eval_accuracy': 0.9916, 'loss': 0.07171986, 'global_step': 12000}
Process finished with exit code 0
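This approach reduces efficiency. Looking at the source code shows that every call to train ends up creating a session through MonitoredTrainingSession and restoring the result of the previous training run. For now I have not found a good way to evaluate on the validation set while training is in progress.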
https://github.com/buptlj/learn_tf