一直想将图片制作成tfrecords文件,然后在模型中运行一下。最初想用的数据集是mnist,但是跑的过程中一直出现问题。找到这一篇知乎上的博客,写的非常不错。
原博客地址:https://zhuanlan.zhihu.com/p/32490882
其代码地址:https://github.com/HelloSangShen/Cat-vs-Dog/
猫狗数据集:https://pan.baidu.com/s/13hw4LK8ihR6-6-8mpjLKDA 密码:dmp4
本文以kaggle的猫狗大战为例,完整地描述使用TensorFlow进行一次完整CNN训练的每个步骤。首先介绍如何将图片转为TFRecords文件,然后介绍如何读取该文件的数据并且输入给我们的网络进行训练,并且会展示如何通过hook来监测网络训练的情况(这里没有使用TensorBoard)。最后会简单解读一下MonitoredTrainingSession的使用方法。
3.1 数据处理
有过实践的小伙伴应该能感受到,当有了TensorFlow、PyTorch这样优秀的框架后,构造一个神经网络、进行训练、计算损失函数、预测等都变的相对容易许多。但是数据的预处理仍然是一个相对棘手的问题,尤其是在较大数据集上进行训练时,不能总是使用占位符(placeholder)和feed dict进行数据加载,而TensorFlow提供了另外一种加载方式。这部分就着重介绍如何将图片数据存储为TFRecords,并且通过队列读取给我们的网络。因为网上有非常多介绍TFRecords原理的文章,我这里就不细说了,只给出详细的代码和注释,示范一下如何处理。
def read_images(path):
"""从源文件/路径读取图像
参数:
path: 图像所在的路径即文件夹名称
返回:
返回一个带有所有图像、标签和总数信息的对象
images: 所有的图像数据
labels: 所有标签
num: 数目
"""
# 获取文件夹内所有图像文件的文件名和总数
filenames = next(walk(path))[2]
num_file = len(filenames)
# 初始化图像和标签
images = np.zeros((num_file, IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNEL), dtype=np.uint8)
labels = np.zeros((num_file, ), dtype=np.uint8)
# 遍历读取文件
for index, filename in enumerate(filenames):
# 读取单张图像,并且修改为自定义尺寸
img = imread(join(path, filename))
img = imresize(img, (IMAGE_HEIGHT, IMAGE_HEIGHT))
images[index] = img
# TO DO
# 这里通过文件名获取标签信息,猫狗大战问题中只有两类,故只有0和1
# 可以根据自己的需要进行修改
# 注意:这里不是one-hot编码
if filename[0:3] == 'cat':
labels[index] = int(0)
else:
labels[index] = int(1)
if index % 1000 == 0:
print("Reading the %sth image" % index)
# 创建一个类,该类携带图像、标签和总数信息
class ImgData(object):
pass
result = ImgData()
result.images = images
result.labels = labels
result.num = num_file
return result
通过上述函数,我们可以读取到文件夹内所有的图片。接下来,我们要把这些图片转为TFRecords文件。
def convert(data, destination):
"""将图片存储为.tfrecords文件
参数:
data: 上述函数返回的ImageData对象
destination: 目标文件名
"""
images = data.images
labels = data.labels
num_examples = data.num
# 存储的文件名
filename = destination
# 使用TFRecordWriter来写入数据
writer = tf.python_io.TFRecordWriter(filename)
# 遍历图片
for index in range(num_examples):
# 转为二进制
image = images[index].tostring()
label = labels[index]
# tf.train下有Feature和Features,需要注意其区别
# 层级关系为Example->Features->Feature
example = tf.train.Example(features=tf.train.Features(feature={
'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
}))
# 写入
writer.write(example.SerializeToString())
writer.close()
这两个函数就可以把我们的数据集图片全都写入一个.tfrecords文件。如果文件过大,可以写入多个文件。
下面介绍如何从tfrecords文件中批量读取图片和标签。
def read_and_decode(filename_queue):
"""读取.tfrecords文件
参数:
filename_queue: 文件名, 一个列表
返回:
img, label: **单张图片和对应标签**
"""
# 创建一个图节点,该节点负责数据输入
filename_queue = tf.train.string_input_producer([filename_queue])
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
# 解析单个example
features = tf.parse_single_example(serialized_example, features={
'image': tf.FixedLenFeature([], tf.string),
'label': tf.FixedLenFeature([], tf.int64)
})
image = tf.decode_raw(features['image'], tf.uint8)
image = tf.reshape(image, [IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNEL])
image = tf.cast(image, tf.float32)
label = tf.cast(features['label'], tf.int64)
return image, label
我们将数据读取的功能进行封装,代码如下:
def distorted_input(filename, batch_size):
"""建立一个乱序的输入
参数:
filename: tfrecords文件的文件名. 注:该文件名仅为文件的名称,不包含路径和后缀
batch_size: 每次读取的batch size
返回:
images: 一个4D的Tensor. size: [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3]
labels: 1D的标签. size: [batch_size]
"""
# 完整文件名,文件存储在同一路径下的tfrecords文件夹下,名为filename.tfrecords
filename = './tfrecords/' + filename + '.tfrecords'
# 如果路径下没有该文件,说明没有进行转换工作,则将图片转为tfrecords文件
if not os.path.exists(filename):
print('Transfer images to TF_Records')
raw_data = read_images(FLAGS.raw_data_path)
convert(raw_data, filename)
print('End transfering')
image, label = read_and_decode(filename)
# 乱序读入一个batch
images, labels = tf.train.shuffle_batch([image, label], batch_size=batch_size,
num_threads=16, capacity=3000, min_after_dequeue=1000)
return images, labels
以上,我们就完成了数据的读取部分了。下面用一段代码进行测试。
images, labels = catdog_input.distorted_input(FLAGS.tfrecords_file_name, batch_size=4)
# from matplotlib import pyplot as plt
fig = plt.figure()
a = fig.add_subplot(221)
b = fig.add_subplot(222)
c = fig.add_subplot(223)
d = fig.add_subplot(224)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
coord = tf.train.Coordinator()
# 开启文件读取队列,开启后才能开始读取数据
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
img, label = sess.run([images, labels])
a.imshow(img[0])
a.axis('off')
b.imshow(img[1])
b.axis('off')
c.imshow(img[2])
c.axis('off')
d.imshow(img[3])
d.axis('off')
plt.show()
coord.request_stop()
coord.join(threads)
通过这个简单的测试程序就可以可视化四张图片出来。
3.2 模型
这里,我们使用VGG-16模型来做测试。TensorFlow在搭建网络上非常方便,这里就不给详细代码了(可以参考http://www.cs.toronto.edu/~frossard/post/vgg16/),读者可以在文末的GitHub链接上找到相关代码。
对于准确率、损失函数等,我们参考TensorFlow教程中Cifar10训练的源代码进行实现,将这些函数均封装起来。
def loss(logits, labels):
labels = tf.cast(labels, tf.int64)
# 注意:我们上面定义的标签不是one-hot编码,故这里调用的是sparse方法
# 如果使用one-hot,调用softmax_cross_entropy_with_logits即可
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
labels=labels, logits=logits, name='cross_entropy_per_example')
loss = tf.reduce_mean(cross_entropy, name='cross_entropy')
return loss
def accuracy(logits, labels):
# 将labels转为one-hot编码进行计算
labels = tf.one_hot(labels, NUM_CLASS)
correct_pred = tf.equal(tf.argmax(logits, 1), tf.argmax(labels, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
return accuracy
def train(loss):
train_op = tf.train.AdamOptimizer(LEARNING_RATE).minimize(loss)
return train_op
至此,我们的模型就搭建好了。接下来就是训练步骤。
3.3 训练
下面的train()
函数也是参照cifar10的源码进行实现的。
def train():
# 因为要使用StopAtStepHook,故global_step是必须的
global_step = tf.train.get_or_create_global_step()
# 输入
images, labels = catdog_input.distorted_input(FLAGS.tfrecords_name, BATCH_SIZE)
logits = catdog_model.inference(images)
loss = catdog_model.loss(logits, labels)
# accuracy = catdog_model.accuracy(logits, labels)
train_op = catdog_model.train(loss)
class _LoggerHook(tf.train.SessionRunHook):
"""
该类用来打印训练信息
"""
def begin(self):
self._step = -1
self._start_time = time.time()
def before_run(self, run_context):
self._step += 1
# 该函数在训练运行之前自动调用
# 在这里返回所有你想在运行过程中查看到的信息
# 以list的形式传递,如:[loss, accuracy]
return tf.train.SessionRunArgs(loss)
def after_run(self, run_context, run_values):
# 打印信息的步骤间隔
display_step = 10
if self._step % display_step == 0:
current_time = time.time()
duration = current_time - self._start_time
self._start_time = current_time
# results返回的就是上面before_run()的返回结果,上面是loss故这里是loss
# 若输入的是list,返回也是一个list
loss = run_values.results
# 每秒使用的样本数
examples_per_sec = display_step * BATCH_SIZE / duration
# 每batch使用的时间
sec_per_batch = float(duration / display_step)
format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
'sec/batch)')
print(format_str % (datetime.now(), self._step, loss,
examples_per_sec, sec_per_batch))
with tf.train.MonitoredTrainingSession(
hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_step),
tf.train.NanTensorHook(loss),
_LoggerHook()], # 将上面定义的_LoggerHook传入
config=tf.ConfigProto(
log_device_placement=False)) as sess:
coord = tf.train.Coordinator()
# 开启文件读取队列
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
while not sess.should_stop():
sess.run(train_op)
coord.request_stop()
coord.join(threads)
上面就是在猫狗大战数据集上进行的一个完整的图片数据预处理、数据读取、搭建网络、训练并监测的过程。
3.4 评估
因为实验室设备暂时有点问题,没法训练,故现在没法给出结果,以后训练出结果后再来更新吧。
3.5 关于MonitoredTrainingSession
我们在上面的训练中用到了tf.train.MonitoredTrainingSession(...)
。查阅了一下官方文档,该类继承自MonitoredSession类。我们先看看这个父类,官方文档中给了一段如下示例代码 :
saver_hook = CheckpointSaverHook(...)
summary_hook = SummarySaverHook(...)
with MonitoredSession(session_creator=ChiefSessionCreator(...),
hooks=[saver_hook, summary_hook]) as sess:
while not sess.should_stop():
sess.run(train_op)
首先,当MonitoredSession初始化的时候,会按顺序执行下面操作:
begin()
函数,我们一般在这里进行一些hook内的初始化。比如在上面猫狗大战中的_LoggerHook
里面的_step属性,就是用来记录执行步骤的,但是该参数只在本类中起作用。scaffold.finalize()
初始化计算图Scaffold
提供的操作(op)来初始化模型hook.after_create_session()
然后,当run()
函数运行的时候,按顺序执行下列操作:
hook.before_run()
session.run()
hook.after_run()
session.run()
的结果最后,当调用close()退出时,按顺序执行下列操作:
hook.end()
需要注意的是:该类不是一个tf.Session()
,因为它不能被设置为默认会话,不能被传递给saver.save,也不能被传递给tf.train.start_queue_runners
,这也解释了为什么在开启会话后我们必须手动调用tf.train.start_queue_runners()
而MonitoredTrainingSession则比起父类多了许多其他的参数,可以在官方文档获取各参数的说明,这里我们不详细说。但是根据其父类的执行说明,我们就可以很容易理解上面train()
函数中发生了什么。
首先,我们先将计算图的各个节点/操作定义好,构成了一个计算图。然后开启了一个MonitoredTrainingSession来初始化/注册我们的图和其他信息。其中,我们给其传递了3个hook:
tf.train.StopAtStepHook(last_step)
,该hook主要是在训练到特定步数后即请求停止,使用该hook必须要预先定义一个tf.train.get_or_create_global_step()
。否则会抛出运行时错误,见源码:def begin(self): self._global_step_tensor = training_util._get_or_create_global_step_read() if self._global_step_tensor is None: raise RuntimeError("Global step should be created to use StopAtStepHook.")
tf.train.NanTensorHook(loss)
,该hook用来监测loss,若loss的结果为NaN,抛出异常或者直接停止训练。_LoggerHook()
,该hook是我们自定义的hook,用来监测我们希望在训练过程中能查看的一些数据如loss或者accuracy。首先会随着MonitoredTrainingSession的初始化来调用begin()
函数,我们在这里初始化步数,before_run()
函数会随着sess.run()
的调用而调用。故每训练一步调用一次,这里返回想要打印的信息,随后就调用after_run()
函数,在这里,我们就将需要查看的信息打印出来即可。随后,我们开启文件读取队列进行数据的输入。然后就一直调用sess.run()
训练直到停下。
首先得生成tfrecords文件,在当前文件夹下新建一个create_tfrecords.py,然后将下面的代码放进去(其实就是上面的代码)
import tensorflow as tf
import numpy as np
import os
from scipy.misc import imread,imresize
from os.path import join
from os import walk
IMAGE_WIDTH = 224
IMAGE_HEIGHT = 224
IMAGE_CHANNEL = 3
NUM_CLASS = 2
def read_images(path):
"""Read image from source file/directory
Args:
path: source derectory
Return:
An object representing all images and labels, fields:
images: all image data
labels: all labels
num: number of images
"""
# Get a list filenames
filenames = next(walk(path))[2]
num_file = len(filenames)
# Initialize images and labels.
images = np.zeros((num_file, IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNEL), dtype=np.uint8)
labels = np.zeros((num_file, ), dtype=np.uint8)
# Iterate/Read all files
for index, filename in enumerate(filenames):
# Read single image and resize it to your expected size
img = imread(join(path, filename))
img = imresize(img, (IMAGE_HEIGHT, IMAGE_HEIGHT))
images[index] = img
# TO DO:
if filename[0:3] == 'cat':
labels[index] = int(0)
else:
labels[index] = int(1)
if index % 1000 == 0:
print("Reading the %sth image" % index)
class ImgData(object):
pass
result = ImgData()
result.images = images
result.labels = labels
result.num = num_file
return result
def convert(data, destination):
"""Convert images to tfrecords
Args:
data: an object of ImgData, consisting of images, labels and number of images
destination: destination filename of tfrecords
"""
images = data.images
labels = data.labels
num_examples = data.num
# filenale of tfrecords
filename = destination
writer = tf.python_io.TFRecordWriter(filename)
for index in range(num_examples):
image = images[index].tostring()
label = labels[index]
# Attention: Example -> Features -> Feature
example = tf.train.Example(features=tf.train.Features(feature={
'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
}))
writer.write(example.SerializeToString())
writer.close()
if __name__ == '__main__':
path = 'kaggle/train'
tfrecords_path = 'tfrecords/cat_dog.tfrecords'
data = read_images(path)
convert(data,tfrecords_paths)
然后直接命令python create_tfrecords.py
然后直接命令python catdog_train.py --tfrecords_name cat_dog \ --max_step 5000
运行结果: