本文以mnist为例,介绍如何使用TFRecord格式数据和队列进行模型训练和预测。
参考:
1、cifar10
2、https://tensorflow.google.cn/guide/datasets
TFRecord格式数据的制作方法,参见《将mnist数据转成原始图片数据再转成TFRecord格式》一文。
目录
一、输入数据的解析和预处理
二、定义模型
三、计算损失并定义训练操作
四、模型训练
五、模型验证
六、对单张图片进行预测
def read_mnist_tfrecords(filename_queue):
    """Read and decode a single MNIST example from a TFRecord file queue.

    Args:
        filename_queue: Queue of TFRecord file names handed to
            ``tf.TFRecordReader.read``.

    Returns:
        An ``(image, label)`` pair. ``image`` is the uint8 tensor decoded from
        the ``img_raw`` feature and reshaped to
        ``[FLAGS.image_height, FLAGS.image_width, 1]``; ``label`` is the
        ``label`` feature as an int64 tensor.
    """
    # Both features are scalars; the defaults apply when a key is missing.
    feature_spec = {
        'img_raw': tf.FixedLenFeature([], tf.string, ''),
        'label': tf.FixedLenFeature([], tf.int64, 0),
    }

    _, record = tf.TFRecordReader().read(filename_queue)
    parsed = tf.parse_single_example(record, features=feature_spec)

    # The image is stored as raw bytes, one uint8 per pixel.
    pixels = tf.decode_raw(parsed['img_raw'], tf.uint8)
    label = tf.cast(parsed['label'], tf.int64)
    image = tf.reshape(pixels, [FLAGS.image_height, FLAGS.image_width, 1])
    return image, label
def inputs(filenames, examples_num, batch_size, shuffle):
    """Build the queue-based input pipeline and return batched examples.

    Args:
        filenames: List of TFRecord file paths to read from.
        examples_num: Number of examples in the data set; used to size the
            batching queue.
        batch_size: Number of examples per returned batch.
        shuffle: If true, batches are drawn with ``tf.train.shuffle_batch``;
            otherwise with ``tf.train.batch``.

    Returns:
        An ``(images, labels)`` pair of batched tensors; images are float32.

    Raises:
        ValueError: If any path in ``filenames`` does not exist.
    """
    for path in filenames:
        if not tf.gfile.Exists(path):
            raise ValueError('Failed to find file: ' + path)

    with tf.name_scope('inputs'):
        file_queue = tf.train.string_input_producer(filenames)
        raw_image, label = read_mnist_tfrecords(file_queue)
        image = tf.cast(raw_image, tf.float32)

        # Keep at least 40% of the data set buffered in the queue.
        min_queue_examples = int(0.4 * examples_num)
        queue_capacity = min_queue_examples + batch_size * 3
        worker_threads = 16

        if shuffle:
            images, labels = tf.train.shuffle_batch(
                [image, label],
                batch_size=batch_size,
                num_threads=worker_threads,
                capacity=queue_capacity,
                min_after_dequeue=min_queue_examples)
        else:
            images, labels = tf.train.batch(
                [image, label],
                batch_size=batch_size,
                num_threads=worker_threads,
                capacity=queue_capacity)

    return images, labels
处理之后,返回的是批量的image和对应的label。
def inference(images, training):
    """Build the CNN: two conv/pool stages, a dense layer with dropout, logits.

    Args:
        images: Batch of input images fed to the first convolution.
        training: Passed to ``tf.layers.dropout``; dropout is only active
            when this is true.

    Returns:
        A ``(logits, predict)`` pair: the raw 10-unit output of the final
        dense layer and its softmax.
    """
    # Settings shared by both convolution layers.
    conv_args = dict(kernel_size=[5, 5], padding='same', activation=tf.nn.relu)

    with tf.variable_scope('conv1'):
        conv1 = tf.layers.conv2d(inputs=images, filters=32, **conv_args)
        # 14*14*32 for a 28*28 input
        pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2)

    with tf.variable_scope('conv2'):
        conv2 = tf.layers.conv2d(inputs=pool1, filters=64, **conv_args)
        # 7*7*64 for a 28*28 input
        pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2)

    with tf.variable_scope('fc1'):
        flattened = tf.reshape(pool2, [-1, 7*7*64])
        hidden = tf.layers.dense(inputs=flattened, units=1024, activation=tf.nn.relu)
        dropped = tf.layers.dropout(inputs=hidden, rate=0.4, training=training)

    with tf.variable_scope('logits'):
        # The logits are what the cross-entropy loss is computed from.
        logits = tf.layers.dense(inputs=dropped, units=10)
        predict = tf.nn.softmax(logits)

    return logits, predict
模型定义采用tf.layers API,返回值中的logits用于计算损失。
def loss(logits, labels):
    """Return the mean sparse softmax cross-entropy over the batch.

    Args:
        logits: Unnormalised class scores.
        labels: Integer class labels; cast to int64 before use.

    Returns:
        A scalar tensor: the per-example cross-entropy averaged with
        ``tf.reduce_mean``.
    """
    per_example = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=tf.cast(labels, tf.int64),
        logits=logits,
        name='cross_entropy')
    return tf.reduce_mean(per_example)
def train(total_loss, global_step):
    """Create the training op: Adam with a staircase-decayed learning rate.

    Args:
        total_loss: Scalar loss tensor to minimise.
        global_step: Step counter variable; drives the learning-rate decay and
            is passed to ``apply_gradients``.

    Returns:
        The op that applies one gradient update.
    """
    # The learning rate is decayed once every 10 epochs' worth of steps.
    batches_per_epoch = TRAIN_EXAMPLES_NUM / FLAGS.batch_size
    steps_per_decay = int(batches_per_epoch * 10)

    # Starts at 0.001 and is multiplied by 0.1 at each decay boundary.
    learning_rate = tf.train.exponential_decay(
        learning_rate=0.001,
        global_step=global_step,
        decay_steps=steps_per_decay,
        decay_rate=0.1,
        staircase=True)

    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
    grads_and_vars = optimizer.compute_gradients(total_loss)
    return optimizer.apply_gradients(grads_and_vars, global_step)
学习率初始值为0.001,每过10个epoch衰减一次,衰减为上一次的1/10。
def train():
    """Train the MNIST model from ``train_img.tfrecords`` and save checkpoints.

    Builds the shuffled input pipeline, model, loss and training op, then runs
    ``FLAGS.max_step`` training steps. Every 100 steps the loss is printed and
    a checkpoint is written to ``FLAGS.train_dir``.
    """
    images, labels = mnist.inputs(['train_img.tfrecords'], mnist.TRAIN_EXAMPLES_NUM,
                                  FLAGS.batch_size, shuffle=True)
    global_step = tf.train.get_or_create_global_step()
    logits, _ = mnist.inference(images, training=True)
    loss = mnist.loss(logits, labels)
    train_op = mnist.train(loss, global_step)
    saver = tf.train.Saver()

    # Saver.save() fails if the checkpoint's parent directory is missing;
    # MakeDirs succeeds silently when the directory already exists.
    tf.gfile.MakeDirs(FLAGS.train_dir)
    ckpt = os.path.join(FLAGS.train_dir, 'model.ckpt')

    with tf.Session() as sess:
        init_op = tf.group(
            tf.local_variables_initializer(),
            tf.global_variables_initializer())
        sess.run(init_op)

        coord = tf.train.Coordinator()
        threads = []
        try:
            # The queue runners must be started, otherwise the first
            # sess.run() blocks forever waiting for input.
            threads = tf.train.start_queue_runners(sess, coord=coord)
            for i in range(1, FLAGS.max_step + 1):
                _, train_loss = sess.run([train_op, loss])
                if i % 100 == 0:
                    print('step: {}, loss: {}'.format(i, train_loss))
                    saver.save(sess, ckpt, global_step=i)
        finally:
            # Always stop the reader threads, even if a step raised.
            coord.request_stop()
            coord.join(threads)
训练时通过参数shuffle=True对数据进行shuffle处理。注意必须调用tf.train.start_queue_runners(sess, coord=coord)来启动队列线程,否则队列不会启动,程序会一直卡住。
def eval_once(saver, top_k_op):
    """Restore the latest checkpoint and print precision on the validation set.

    Args:
        saver: ``tf.train.Saver`` used to restore the model variables.
        top_k_op: Op yielding, per example in a batch, whether the prediction
            was correct.

    Prints 'no checkpoint file' and returns if ``FLAGS.train_dir`` holds no
    checkpoint. Exceptions raised during evaluation are printed and forwarded
    to the coordinator rather than propagated directly.
    """
    with tf.Session() as sess:
        ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            # The step number is the suffix after the last '-' in the
            # checkpoint name, e.g. 'model.ckpt-1000' -> '1000'.
            global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
        else:
            print('no checkpoint file')
            return

        coord = tf.train.Coordinator()
        # Bound before the try block so the finally clause can always join,
        # even when start_queue_runners itself raises.
        threads = []
        try:
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            iter_per_epoch = int(math.ceil(mnist.VALIDATION_EXAMPLES_NUM / FLAGS.batch_size))
            total_sample = iter_per_epoch * FLAGS.batch_size
            correct_predict = 0
            step = 0
            while step < iter_per_epoch and not coord.should_stop():
                predict = sess.run(top_k_op)
                correct_predict += np.sum(predict)
                step += 1
            precision = correct_predict / total_sample
            print('step: {}, model: {}, precision: {}'.format(global_step, ckpt.model_checkpoint_path, precision))
        except Exception as e:
            print('exception: ', e)
            coord.request_stop(e)
        finally:
            coord.request_stop()
            coord.join(threads)
def evaluation():
    """Evaluate the model on ``./validation_img.tfrecords``.

    Builds the unshuffled validation pipeline and the model in inference mode,
    then calls ``eval_once`` repeatedly, sleeping
    ``FLAGS.eval_interval_secs`` between rounds. Stops after the first round
    when ``FLAGS.run_once`` is set.
    """
    images, labels = mnist.inputs(
        ['./validation_img.tfrecords'],
        mnist.VALIDATION_EXAMPLES_NUM,
        batch_size=FLAGS.batch_size,
        shuffle=False)
    logits, _ = mnist.inference(images, training=False)

    # True for each example whose top-1 prediction matches its label.
    top_k_op = tf.nn.in_top_k(logits, labels, 1)
    saver = tf.train.Saver()

    eval_once(saver, top_k_op)
    while not FLAGS.run_once:
        time.sleep(FLAGS.eval_interval_secs)
        eval_once(saver, top_k_op)
模型验证时不需要对数据进行shuffle(即shuffle=False)。
def pred(filename, train_dir):
    """Classify a single image file with the latest checkpoint and print it.

    Args:
        filename: Path of the image to classify; it is loaded as grayscale
            and reshaped to ``[-1, 28, 28, 1]``.
        train_dir: Directory containing the model checkpoints.

    Raises:
        ValueError: If the image cannot be read from ``filename``.

    Prints 'no checkpoint file' and returns if ``train_dir`` holds no
    checkpoint.
    """
    img = cv2.imread(filename, flags=cv2.IMREAD_GRAYSCALE)
    # cv2.imread signals a missing or unreadable file by returning None
    # instead of raising; fail here with a message that names the file.
    if img is None:
        raise ValueError('Failed to read image: ' + filename)

    img = tf.cast(img, tf.float32)
    img = tf.reshape(img, [-1, 28, 28, 1])
    logits, predict = mnist.inference(img, training=False)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        ckpt = tf.train.get_checkpoint_state(train_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            print('no checkpoint file')
            return
        pre = sess.run(predict)
        # Report the most probable class and its softmax probability.
        print('model:{}, file:{}, label: {} ({:.2f}%)'.
              format(ckpt.model_checkpoint_path, filename, np.argmax(pre[0]), np.max(pre[0]) * 100))
if __name__ == '__main__':
    # Classify one sample image using the checkpoints stored in ./train.
    pred(filename='./img_test/2_2098.jpg', train_dir='./train')
输出:
model:./train\model.ckpt-1000, file:./img_test/2_2098.jpg, label: 2 (96.27%)