Tensorflow-制作与使用tfrecord数据集

引言

  本次博文目的是记录下tfrecord数据集的制作与使用方式。(踩了无数坑OTZ)
  这里贴上一个数据读取的官方教程:Tensorflow导入数据以及使用数据
  接下来举个例子说明怎么用tfrecord,假设我要做个图片分类的任务。首先,我这里有一个txt文件,包含着所有图片的路径以及它们的标签。还有一个包含许多图片的文件夹。类似下图这样:
Tensorflow-制作与使用tfrecord数据集_第1张图片
Tensorflow-制作与使用tfrecord数据集_第2张图片

  准备好了数据后,就可以制作与使用TFrecored啦~

制作TFrecord

  当然是先写个制作TFrecord的函数啦。我们先读取图片信息的txt文件,得到每个图片的路径以及它们的标签,然后对这个图片作一些预处理,最后将图片以及它对应的标签序列化,并建立图片和标签的索引(即以下代码的”img_raw”, “label”)。详见代码。

import random
import tensorflow as tf
from PIL import Image

def create_record(records_path, data_path, img_txt):
    # 声明一个TFRecordWriter
    writer = tf.python_io.TFRecordWriter(records_path)
    # 读取图片信息,并且将读入的图片顺序打乱
    img_list = []
    with open(img_txt, 'r') as fr:
        img_list = fr.readlines()
    random.shuffle(img_list)
    cnt = 0
    # 遍历每一张图片信息
    for img_info in img_list:
        # 图片相对路径
        img_name = img_info.split(' ')[0]
        # 图片类别
        img_cls = int(img_info.split(' ')[1])
        img_path = data_path + img_name
        img = Image.open(img_path)
        # 对图片进行预处理(缩放,减去均值,二值化等等)
        img = img.resize((128, 128))
        img_raw = img.tobytes()
        # 声明将要写入tfrecord的key值(即图片,标签)
        example = tf.train.Example(
           features=tf.train.Features(feature={
                "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[img_cls])),
                'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
           }))
        # 将信息写入指定路径
        writer.write(example.SerializeToString())
        # 打印一些提示信息~
        cnt += 1
        if cnt % 1000 == 0:
            print "processed %d images" % cnt
    writer.close()

# 指定你想要生成tfrecord名称,图片文件夹路径,含有图片信息的txt文件
records_path = '/the/name/of/your/haha.tfrecords'
data_path = '/the/root/of/your/image_folder/'
img_txt = '/image/labels/list.txt'
create_record(records_path, data_path, img_txt)

使用TFrecord

  目前为止,使用TFrecord最方便的方式是用TensorFlow的Dataset ApI。在这里,劝大家千万千万不要用queue的方式读取数据(麻烦且已经过时)。
  首先,我们定义好_parse_function,这个函数是用来指定TFrecord中索引的(即上文中的”img_raw”, “label”)。然后我们定义一个TFRecordDataset,并借助_parse_function来读取数据。最后,为了得到每一轮的训练数据,我们只需要再额外声明一个iterator,每次调用get_next()就可以啦。

# 定义如何解析TFrecord数据
def _parse_function(example_proto):
    features = tf.parse_single_example(
        example_proto,
        features={
            'label': tf.FixedLenFeature([], tf.int64),
            'img_raw': tf.FixedLenFeature([], tf.string)
        }
    )
    # 取出我们需要的数据(标签,图片)
    label = features['label']
    img = features['img_raw']
    img = tf.decode_raw(img, tf.uint8)
    # 对标签以及图片作预处理
    img = tf.reshape(img, [128, 128, 3])
    img = tf.cast(img, tf.float32) * (1. / 255) - 0.5
    label = tf.cast(label, tf.int32)
    return img, label

# 得到获取data batch的迭代器
def data_iterator(tfrecords):
    # 声明TFRecordDataset
    dataset = tf.contrib.data.TFRecordDataset(tfrecords)
    dataset = dataset.map(_parse_function)
    # 打乱顺序,无限重复训练数据,定义好batch size
    dataset = dataset.shuffle(buffer_size=1000).repeat().batch(128)
    # 定义one_shot_iterator。官方上有许多类型的iterrator,这种是最简单的
    iterator = dataset.make_one_shot_iterator()
    return iterator

# 指定TFrecords路径,得到training iterator。
train_tfrecords = '/your/path/to/haha.tfrecords'
train_iterator = data_iterator(train_tfrecords)

# 使用方式举例
with tf.Session(config= tfconfig) as sess:
    tf.initialize_all_variables().run()
    train_batch = train_iterator.get_next()
    for step in xrange(50000):
        train_x, train_y = sess.run(train_batch)

再聊聊TensorFlow的Slim模块

  这篇文章本该到此结束的。但是我仍想说TensorFlow真的有点难用(也可能是我太弱哈哈)。主要原因是它的API太多,更新速度太快。不过,我们也能迅速学习到许多东西(毕竟它的支持者有很多,这就给我们提供了许多实例以及讲解博文),比如这个关于Slim的学习例子。
  接下来聊聊Slim这个模块,它是2016年出的新模块,目的是减少构建网络的代码量。个人觉得真是很好用,强烈推荐一试!!!(不信可以去上面网址里看看)好的,下面贴一段代码,展示下slim的使用方式,作为本篇的结尾吧~

slim = tf.contrib.slim

def MyNet(inputs, num_classes=7, is_training=True, keep_prob=0.5, scope='MyNet'):
    net = tf.reshape(inputs, [-1, 128, 128, 3])
    with slim.arg_scope([slim.conv2d, slim.fully_connected],
                         activation_fn=tf.nn.relu,
                         weights_regularizer=slim.l2_regularizer(0.0005)):
        with slim.arg_scope([slim.conv2d],
                            stride=1,
                            padding='SAME',
                            weights_initializer=tf.contrib.layers.xavier_initializer_conv2d()):
            net = slim.stack(net, slim.conv2d, [(8, [3, 3])], scope='conv1')
            net = slim.max_pool2d(net, [2, 2], scope='pool1')
            net = slim.stack(net, slim.conv2d, [(16, [3, 3]), (24, [3, 3])], scope='conv2')
            net = slim.max_pool2d(net, [2, 2], scope='pool2')
            net = slim.stack(net, slim.conv2d, [(24, [3, 3]), (24, [3, 3]), (36, [3, 3]), (36, [3, 3])], scope='conv3')
        net = slim.flatten(net)
        with slim.arg_scope([slim.fully_connected],
                             weights_initializer=tf.random_normal_initializer(stddev=0.01)):
            net = slim.fully_connected(net, 2048, scope='fc6')
            net = slim.dropout(net, keep_prob, scope='dropout6')
            net = slim.fully_connected(net, 2048, scope='fc7')
            net = slim.dropout(net, keep_prob, scope='dropout7')
            net = slim.fully_connected(net, num_classes, activation_fn=None, scope='fc8')

    return net

input_data = tf.placeholder(tf.float32, [None, 128, 128, 3])
output_logits = MyNet(input_data)

你可能感兴趣的:(工具使用简介)