Tensorflow建立与读取TFrecorder文件

Tensorflow建立与读取TFrecorder文件

除了直接读取数据文件,比如csv和bin文件,tensorflow还可以建立一种自有格式的数据文件,称之为tfrecorder,这种文件储存类似于字典,调用方便,可以直接包含标签集。

首先,要建立起tfrecorder文件,我这里选择了若干人脸图像数据,文件的组织形式为根目录/s+类别号/图片名称
Tensorflow建立与读取TFrecorder文件_第1张图片

基本思路是先遍历文件夹,使用PIL库读取文件,然后写入到tfrecorder中。

import tensorflow as tf
import numpy as np
from PIL import Image

import os

cwd = os.getcwd()

root = cwd+"/data/face_data"

TFwriter = tf.python_io.TFRecordWriter("./data/faceTF.tfrecords")

for className in os.listdir(root):
    label = int(className[1:])
    classPath = root+"/"+className+"/"
    for parent, dirnames, filenames in os.walk(classPath):
        for filename in filenames:
            imgPath = classPath+"/"+filename
            print (imgPath)
            img = Image.open(imgPath)
            print (img.size,img.mode)
            imgRaw = img.tobytes()
            example = tf.train.Example(features=tf.train.Features(feature={
                "label":tf.train.Feature(int64_list = tf.train.Int64List(value=[label])),
                "img":tf.train.Feature(bytes_list = tf.train.BytesList(value=[imgRaw]))
            }) )
            TFwriter.write(example.SerializeToString())

TFwriter.close()

此时就会生成tfrecorder文件了,需要使用时,和之前一样,使用队列将其读出就可以了。


fileNameQue = tf.train.string_input_producer(["./data/faceTF.tfrecords"])
reader = tf.TFRecordReader()
key,value = reader.read(fileNameQue)
features = tf.parse_single_example(value,features={ 'label': tf.FixedLenFeature([], tf.int64),
                                           'img' : tf.FixedLenFeature([], tf.string),})

img = tf.decode_raw(features["img"], tf.uint8)
label = tf.cast(features["label"], tf.int32)

init = tf.initialize_all_variables()

with tf.Session() as sess:

    sess.run(init)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    for i in range(100):
        imgArr = sess.run(img)
        print (imgArr)

    coord.request_stop()
    coord.join(threads)

注意这样几点:

  1. example是用tf.train.Example创建的一个样本实例,也就是一张图片的记录,其中包含有记录的属性,是tf.train.Features创建的一个实例。
  2. 解析tfrecoreder文件的解析器是 parse_single_example,阅读器是tf.TFRecordReader
  3. 注意,这里的多线程队列读取的运行机制是,管道启动,sess每run一次img节点就会执行一次操作,因为fetch到了数据,所以队列就弹出。
  4. 注意,当run后拿到的结果一般都是numpy格式的数据,不再是图上的节点和在图中流动的tensor,而是实实在在可以graph外部操作的数据了。

你可能感兴趣的:(机器学习)