TensorFlow 读取自己的数据集

文章目录

  • 数据读取
  • feeding
  • 从文件读取数据
    • 步骤
    • 产生文件列表
    • 生成文件队列
    • 可配置的文件名乱序(shuffling)
    • 针对输入文件格式的阅读器
        • CSV文件
        • bin(二进制文件)
        • 将数据转换成 `tfrecords`格式后读取
        • 直接读取图片
  • 预加载数据
  • 参考资料

数据读取

TensorFlow程序读取数据一共有3种方法:
1. 供给数据(Feeding): 在TensorFlow程序运行的每一步, 让Python代码来供给数据。
2. 从文件读取数据: 在TensorFlow图的起始, 让一个输入管线从文件中读取数据。
3. 预加载数据: 在TensorFlow图中定义常量或变量来保存所有数据(仅适用于数据量比较小的情况)。

feeding

  1. placeholder占位符,在计算图中占个位置
  2. 在会话中用feed_dict输入数据
with tf.Session():
  input = tf.placeholder(tf.float32)
  classifier = ...
  print classifier.eval(feed_dict={input: my_python_preprocessing_fn()})

大型feed_dict使用例子

从文件读取数据

步骤

  1. 文件名列表
  2. 可配置的 文件名乱序(shuffling)
  3. 可配置的 最大训练迭代数(epoch limit)
  4. 文件名队列
  5. 针对输入文件格式的阅读器
  6. 纪录解析器
  7. 可配置的预处理器
  8. 样本队列

产生文件列表

产生文件列表,方法如下:
["file0", "file1"] 或者[("file%d" % i) for i in range(2)] 或者[("file%d" % i) for i in range(2)]) 或者tf.train.match_filenames_once

生成文件队列

将文件名列表交给tf.train.string_input_producer 函数. string_input_producer来生成一个先入先出的队列, 文件阅读器会需要它来读取数据。

可配置的文件名乱序(shuffling)

设置string_input_producer函数参数,选择是否乱序,设置迭代次数

针对输入文件格式的阅读器

根据你的文件格式, 选择对应的文件阅读器, 然后将文件名队列提供给阅读器的read方法。阅读器的read方法会输出一个key来表征输入的文件和value其中的纪录(对于调试非常有用),同时得到一个字符串标量, 这个字符串标量可以被一个或多个解析器,或者转换操作将其解码为张量并且构造成为样本。

过程:选择的文件读取器,读取文件名队列并解码,输入tf.train.shuffle_batch 函数中,生成 batch 队列,传递给下一层。

CSV文件

假如你要读取的文件是像 CSV 那样的文本文件,用的文件读取器和解码器就是 TextLineReaderdecode_csv

bin(二进制文件)

假如你要读取的数据是像 cifar10 那样的 .bin 格式的二进制文件,就用 tf.FixedLengthRecordReader 和 tf.decode_raw 读取固定长度的文件读取器和解码器。
例子:cifar10_input.py详解


将数据转换成 tfrecords格式后读取

如果你要读取的数据是图片,或者是其他类型的格式,那么可以先把数据转换成 TensorFlow 的标准支持格式 tfrecords ,它其实是一种二进制文件,通过修改 tf.train.ExampleFeatures,将protocol buffer序列化为一个字符串,再通过 tf.python_io.TFRecordWriter 将序列化的字符串写入tfrecords,然后再用跟上面一样的方式读取tfrecords,只是读取器变成了tf.TFRecordReader,之后通过一个解析器tf.parse_single_example ,然后用解码器 tf.decode_raw解码。
例子选段

def convert_to(data_set, name):
  """Converts a dataset to tfrecords."""
  images = data_set.images
  labels = data_set.labels
  num_examples = data_set.num_examples

  if images.shape[0] != num_examples:
    raise ValueError('Images size %d does not match label size %d.' %
                     (images.shape[0], num_examples))
  rows = images.shape[1]
  cols = images.shape[2]
  depth = images.shape[3]

  filename = os.path.join(FLAGS.directory, name + '.tfrecords')
  print('Writing', filename)
  with tf.python_io.TFRecordWriter(filename) as writer:
    for index in range(num_examples):
      image_raw = images[index].tostring()
      example = tf.train.Example(
          features=tf.train.Features(
              feature={
                  'height': _int64_feature(rows),
                  'width': _int64_feature(cols),
                  'depth': _int64_feature(depth),
                  'label': _int64_feature(int(labels[index])),
                  'image_raw': _bytes_feature(image_raw)
              }))
writer.write(example.SerializeToString())

tensorflow 官方例子,更好但是更长
Step by Step, A Tutorial on How to Feed Your Own Image Data to Tensorflow

直接读取图片

  • 首先,设置你的 ROOT_PATH。这个路径是带有你的训练数据和测试数据的目录。

  • 接下来,你可以借助 join() 函数为 ROOT_PATH 增加特定的路径。你将这两个特定的路径存储在 train_data_directory 和 test_data_directory 中

  • 之后,你可以调用 load_data() 函数,并将 train_data_directory 作为它的参数。

  • 现在 load_data() 启动并自己开始收集 train_data_directory 下的所有子目录;为此它借助了一种被称为列表推导式(list comprehension)的方法——这是一种构建列表的自然方法。基本上就是说:如果在 train_data_directory 中发现了一些东西,就双重检查这是否是一个目录;如果是,就将其加入到你的列表中。注意:每个子目录都代表了一个标签。

  • 接下来,你必须循环遍历这些子目录。首先你要初始化两个列表:labels 和 imanges。然后你要收集这些子目录的路径以及存储在这些子目录中的图像的文件名。之后,你可以使用 append() 函数来收集这两个列表中的数据。

  • 参考资料-机器之心交通标志数据集

def load_data(data_directory):
    directories = [d for d in os.listdir(data_directory) 
                   if os.path.isdir(os.path.join(data_directory, d))]
    labels = []
    images = []    for d in directories:
        label_directory = os.path.join(data_directory, d)
        file_names = [os.path.join(label_directory, f) 
                      for f in os.listdir(label_directory) 
                      if f.endswith(".ppm")]        for f in file_names:
            images.append(skimage.data.imread(f))
            labels.append(int(d))    return images, labels

ROOT_PATH = "/your/root/path"train_data_directory = os.path.join(ROOT_PATH, "TrafficSigns/Training")
test_data_directory = os.path.join(ROOT_PATH, "TrafficSigns/Testing")

images, labels = load_data(train_data_directory)

预加载数据

参考资料

TF Boys (TensorFlow Boys ) 养成记(二): TensorFlow 数据读取
tensorflow(二)----线程队列与io操作
极客学院 TensorFlow 官方文档中文版 数据读取

你可能感兴趣的:(TensorFlow)