读取tfrecord数据

有了tfrecord数据集,自然就要从数据集中把原始训练数据解出来进行训练,tensorflow提供了一整套方法来处理tfrecord数据集的读取,包括读取函数和多线程处理数据的方法。

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np

def read_and_decode(filename):
    #"C:\\Python34\\tensorflow\\tfrecord_train_plane_1.tfrecords"
    filename_queue = tf.train.string_input_producer([filename], shuffle = False) #使用初始化时提供的文件列表创建一个输入队列,输入队列中原始的元素为文件列表中的所有文件

    reader = tf.TFRecordReader()#创建一个reader来读取TFRecord文件中的样例
    _, serialized_example = reader.read(filename_queue)   #从文件中读出一个样例,返回文件名和文件
    #batch = tf.train.batch(tensors=[serialized_example],batch_size=3)
    #解析读入的一个样例
    features = tf.parse_single_example(serialized_example,
                                       features={
                                           'label': tf.FixedLenFeature([3], tf.int64),
                                           'img_raw' : tf.FixedLenFeature([], tf.string),
                                       })  #取出包含image和label的feature对象
    image = tf.decode_raw(features['img_raw'], tf.uint8)#采用decode_raw将字符串解析成图像对应的像素数组
    image = tf.reshape(image, [128, 128, 1])#将读入的数据重新整理为128*128的图像,1为图像的通道数
    image = (image-tf.reduce_min(image))/(tf.reduce_max(image)-tf.reduce_min(image))#将图像数据归一化
    label = tf.cast(features['label'], tf.float32)#将标签数据改为实数型
    return image, label
if __name__=="__main__":
    img, label = read_and_decode("C:\\tensorflowprogram\\tensorflow\\tfrecord_train_plane\\tfrecord_train_plane_128_330.tfrecords")
    count = 0
    #采用tf.train.shuffle_batch函数来将单个的样例组织成batch的形式输出,[img,label]给出了需要组合的元素,batch_size为每次出队得到的样例数量,capacity给出了队列的最大容量,min_after_dequeue限制了出队时最少元素的个数
    img_batch, label_batch = tf.train.shuffle_batch([img, label],batch_size=20, capacity=1060,min_after_dequeue = 30)                                           
    with tf.Session() as sess: #开始一个会话
        init_op = tf.global_variables_initializer()
        sess.run(init_op)
        coord=tf.train.Coordinator()#声明tf.train.Coordinator来协同不同的进程,并启动线程
        threads= tf.train.start_queue_runners(coord=coord)
         #测试程序,测试程序是否能批量读入数据
        for i in range (15):
            k,l=sess.run([img_batch,label_batch])
            print(type(k))
            print(k)
            print(l)
           #print(type(l))
            print(i)

        coord.request_stop()
        coord.join(threads)

经过测试,该程序合格,能够不间断地读取数据,用来输入神经网络模型进行训练。

你可能感兴趣的:(tensorflow)