Tensorflow 中Tfrecords的使用心得

  这篇文章主要讲一下如何用Tensorflow中的标准数据读取方式简单的实现对自己数据的读取操作.

主要分为以下两个步骤:(1)将自己的数据集转化为 xx.tfrecords的形式;(2):在自己的程序中读取并使用.tfrecords进行操作.

数据集转换:为了便于讲解,我们简单制作了一个数据,如下图所示:

Tensorflow 中Tfrecords的使用心得_第1张图片

程序:

import tensorflow as tf
import numpy as np
import os
from PIL import Image
def _int64_feature(value):
  return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _bytes_feature(value):
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def img_to_tfrecord(data_path):
    rows = 256
    cols = 256
    depth = 3
    writer = tf.python_io.TFRecordWriter('test.tfrecords')
    labelfile=open("random.txt")
    lines=labelfile.readlines()
    for line in lines:
        #print line
        img_name = line.split(" ")[0]#name
        label = line.split(" ")[1]#label
        img_path = data_path+img_name
        img = Image.open(img_path)
        img = img.resize((rows,cols))
        #img_raw = img.tostring()    
        img_raw = img.tobytes() 
        example = tf.train.Example(features = tf.train.Features(feature = {
                            'height': _int64_feature(rows),
                           'weight': _int64_feature(cols),
                            'depth': _int64_feature(depth),
                        'image_raw': _bytes_feature(img_raw),
                'label': _bytes_feature(label)}))
                
            writer.write(example.SerializeToString())    
    writer.close() 



if __name__ == '__main__':
    current_dir = os.getcwd()    
    data_path = current_dir + '/data/'    
    #name = current_dir + '/data'
    print('Convert start')   
    img_to_tfrecord(data_path)
    print('done!')

运行该段程序可以看到在dataset_tfrecord文件夹下面有test.tfrecord文件生成。
在TF的Session中调用这个生成的文件

#encoding=utf-8 
# 设置utf-8编码,方便在程序中加入中文注释.
import os
import scipy.misc
import tensorflow as tf
import numpy as np
from test import *
import matplotlib.pyplot as plt

def read_and_decode(filename_queue):
        
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    
    features = tf.parse_single_example(serialized_example,features = {
                        'image_raw':tf.FixedLenFeature([], tf.string)})
    image = tf.decode_raw(features['image_raw'], tf.uint8)
    image = tf.reshape(image, [OUTPUT_SIZE, OUTPUT_SIZE, 3])
    image = tf.cast(image, tf.float32)
    #image = image / 255.0
    
    return image

data_dir = '/home/sanyuan/dataset_animal/dataset_tfrecords/' 

filenames = [os.path.join(data_dir,'train%d.tfrecords' % ii) for ii in range(1)] #如果有多个文件,直接更改这里即可
filename_queue = tf.train.string_input_producer(filenames)
image = read_and_decode(filename_queue)
with tf.Session() as sess:    
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    for i in xrange(2):
        img = sess.run([image])
        print(img[0].shape)  # 设置batch_size等于1.每次读出来只有一张图
        plt.imshow(img[0])
        plt.show()
    coord.request_stop()
    coord.join(threads)
	

程序到这里就已经处理完成了,当然在decorde的过程中也是可以进行一些预处理操作的,不过建议还是在制作数据集的时候进行,TFrecord使用的是队列的方式进行读取数据,这个对于多线程操作来说还是很方便的,只需要设置好格式,每次直接读取就可以了.

你可能感兴趣的:(Tensorflow学习心得)