读tfrecords文件,一个一个读/按批次读

import tensorflow as tf
import matplotlib.pyplot as plt

tfrecords_file = '/home/lw/workspace/MicrovideoLSTM/tfrecordData/videoframe.tfrecords'

filename_queue = tf.train.string_input_producer([tfrecords_file])  # 根据文件名生成一个队列
reader = tf.TFRecordReader()                                       # TFRecordReader 用于读取 TFReacord
_, serialized_example = reader.read(filename_queue)                # 返回文件名和文件
features = tf.parse_single_example(serialized_example,
                                   features={
                                       'label': tf.FixedLenFeature([], tf.int64),
                                       'image_raw' : tf.FixedLenFeature([], tf.string),
                                   })                              # 取出包含image和label的feature对象..

image = tf.decode_raw(features['image_raw'], tf.uint8)             # 要结合自己的数据大小来选择tf.uint8,tf.int32

image = tf.reshape(image, [20,20])   # image = tf.reshape(image, [128, 128, 3]) ]要与具体的图像大小保持一致,取灰度图与彩色图

label = tf.cast(features['label'], tf.int32)                       # 读取出标签数据

image_batch, label_batch = tf.train.shuffle_batch([image, label],
                                            batch_size=10, 
                                            capacity=2000,
                                            min_after_dequeue=1000)# 生成批次

with tf.Session() as sess:                                         # 开始一个会话
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    
    coord=tf.train.Coordinator()                                   # 设置多线程协调器
    threads= tf.train.start_queue_runners(sess = sess, coord=coord)# 开始 Queue Runners (队列运行器)
    
    for i in range(2):
        #example, l = sess.run([image,label])              # 在会话中取出image和label              在队列中一个一个取
        example, l = sess.run([image_batch,label_batch])   # 在会话中取出image_batch和label_batch  在队列中按批次取,维度不同
        print example.shape
        print l.shape
        print l
        # plt.imshow(example)                              # 显示单张图像 
        plt.imshow(example[i,:,:])                         # 在批次里面显示单张图像
        plt.show()

    coord.request_stop()
    coord.join(threads)

你可能感兴趣的:(读tfrecords文件,一个一个读/按批次读)