本次博文目的是记录下tfrecord数据集的制作与使用方式。(踩了无数坑OTZ)
这里贴上一个数据读取的官方教程:Tensorflow导入数据以及使用数据
接下来举个例子说明怎么用tfrecord,假设我要做个图片分类的任务。首先,我这里有一个txt文件,包含着所有图片的路径以及它们的标签。还有一个包含许多图片的文件夹。类似下图这样:
准备好了数据后,就可以制作与使用TFrecored啦~
当然是先写个制作TFrecord的函数啦。我们先读取图片信息的txt文件,得到每个图片的路径以及它们的标签,然后对这个图片作一些预处理,最后将图片以及它对应的标签序列化,并建立图片和标签的索引(即以下代码的”img_raw”, “label”)。详见代码。
import random
import tensorflow as tf
from PIL import Image
def create_record(records_path, data_path, img_txt):
# 声明一个TFRecordWriter
writer = tf.python_io.TFRecordWriter(records_path)
# 读取图片信息,并且将读入的图片顺序打乱
img_list = []
with open(img_txt, 'r') as fr:
img_list = fr.readlines()
random.shuffle(img_list)
cnt = 0
# 遍历每一张图片信息
for img_info in img_list:
# 图片相对路径
img_name = img_info.split(' ')[0]
# 图片类别
img_cls = int(img_info.split(' ')[1])
img_path = data_path + img_name
img = Image.open(img_path)
# 对图片进行预处理(缩放,减去均值,二值化等等)
img = img.resize((128, 128))
img_raw = img.tobytes()
# 声明将要写入tfrecord的key值(即图片,标签)
example = tf.train.Example(
features=tf.train.Features(feature={
"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[img_cls])),
'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
}))
# 将信息写入指定路径
writer.write(example.SerializeToString())
# 打印一些提示信息~
cnt += 1
if cnt % 1000 == 0:
print "processed %d images" % cnt
writer.close()
# 指定你想要生成tfrecord名称,图片文件夹路径,含有图片信息的txt文件
records_path = '/the/name/of/your/haha.tfrecords'
data_path = '/the/root/of/your/image_folder/'
img_txt = '/image/labels/list.txt'
create_record(records_path, data_path, img_txt)
目前为止,使用TFrecord最方便的方式是用TensorFlow的Dataset ApI。在这里,劝大家千万千万不要用queue的方式读取数据(麻烦且已经过时)。
首先,我们定义好_parse_function,这个函数是用来指定TFrecord中索引的(即上文中的”img_raw”, “label”)。然后我们定义一个TFRecordDataset,并借助_parse_function来读取数据。最后,为了得到每一轮的训练数据,我们只需要再额外声明一个iterator,每次调用get_next()就可以啦。
# 定义如何解析TFrecord数据
def _parse_function(example_proto):
features = tf.parse_single_example(
example_proto,
features={
'label': tf.FixedLenFeature([], tf.int64),
'img_raw': tf.FixedLenFeature([], tf.string)
}
)
# 取出我们需要的数据(标签,图片)
label = features['label']
img = features['img_raw']
img = tf.decode_raw(img, tf.uint8)
# 对标签以及图片作预处理
img = tf.reshape(img, [128, 128, 3])
img = tf.cast(img, tf.float32) * (1. / 255) - 0.5
label = tf.cast(label, tf.int32)
return img, label
# 得到获取data batch的迭代器
def data_iterator(tfrecords):
# 声明TFRecordDataset
dataset = tf.contrib.data.TFRecordDataset(tfrecords)
dataset = dataset.map(_parse_function)
# 打乱顺序,无限重复训练数据,定义好batch size
dataset = dataset.shuffle(buffer_size=1000).repeat().batch(128)
# 定义one_shot_iterator。官方上有许多类型的iterrator,这种是最简单的
iterator = dataset.make_one_shot_iterator()
return iterator
# 指定TFrecords路径,得到training iterator。
train_tfrecords = '/your/path/to/haha.tfrecords'
train_iterator = data_iterator(train_tfrecords)
# 使用方式举例
with tf.Session(config= tfconfig) as sess:
tf.initialize_all_variables().run()
train_batch = train_iterator.get_next()
for step in xrange(50000):
train_x, train_y = sess.run(train_batch)
这篇文章本该到此结束的。但是我仍想说TensorFlow真的有点难用(也可能是我太弱哈哈)。主要原因是它的API太多,更新速度太快。不过,我们也能迅速学习到许多东西(毕竟它的支持者有很多,这就给我们提供了许多实例以及讲解博文),比如这个关于Slim的学习例子。
接下来聊聊Slim这个模块,它是2016年出的新模块,目的是减少构建网络的代码量。个人觉得真是很好用,强烈推荐一试!!!(不信可以去上面网址里看看)好的,下面贴一段代码,展示下slim的使用方式,作为本篇的结尾吧~
slim = tf.contrib.slim
def MyNet(inputs, num_classes=7, is_training=True, keep_prob=0.5, scope='MyNet'):
net = tf.reshape(inputs, [-1, 128, 128, 3])
with slim.arg_scope([slim.conv2d, slim.fully_connected],
activation_fn=tf.nn.relu,
weights_regularizer=slim.l2_regularizer(0.0005)):
with slim.arg_scope([slim.conv2d],
stride=1,
padding='SAME',
weights_initializer=tf.contrib.layers.xavier_initializer_conv2d()):
net = slim.stack(net, slim.conv2d, [(8, [3, 3])], scope='conv1')
net = slim.max_pool2d(net, [2, 2], scope='pool1')
net = slim.stack(net, slim.conv2d, [(16, [3, 3]), (24, [3, 3])], scope='conv2')
net = slim.max_pool2d(net, [2, 2], scope='pool2')
net = slim.stack(net, slim.conv2d, [(24, [3, 3]), (24, [3, 3]), (36, [3, 3]), (36, [3, 3])], scope='conv3')
net = slim.flatten(net)
with slim.arg_scope([slim.fully_connected],
weights_initializer=tf.random_normal_initializer(stddev=0.01)):
net = slim.fully_connected(net, 2048, scope='fc6')
net = slim.dropout(net, keep_prob, scope='dropout6')
net = slim.fully_connected(net, 2048, scope='fc7')
net = slim.dropout(net, keep_prob, scope='dropout7')
net = slim.fully_connected(net, num_classes, activation_fn=None, scope='fc8')
return net
input_data = tf.placeholder(tf.float32, [None, 128, 128, 3])
output_logits = MyNet(input_data)