Reading TFRecord

Requirement: read a previously generated TFRecord file and display some of the images stored in it.

Solution: implement this with TensorFlow, NumPy, and Matplotlib.

Note: adapted from code found online.
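For background: the file that tf.TFRecordReader consumes in step 1 is a sequence of length-prefixed records. Per TensorFlow's TFRecord format notes, each record is a little-endian uint64 length, a uint32 masked CRC-32C of those 8 length bytes, the payload, and a uint32 masked CRC-32C of the payload. The pure-Python sketch below writes and reads that framing, assuming an uncompressed file; the helper names (crc32c, masked_crc, write_record, iter_records) are hypothetical, not TensorFlow APIs. Each payload it yields corresponds to one serialized_example, i.e. a serialized tf.train.Example that tf.parse_single_example then decodes; the sketch does not parse the protobuf itself.

```python
import io
import struct

# Record layout (little-endian), per TensorFlow's TFRecord format notes:
#   uint64 length | uint32 masked_crc32c(length bytes) | payload | uint32 masked_crc32c(payload)


def _make_crc32c_table():
    """Lookup table for CRC-32C (Castagnoli, reflected polynomial 0x82F63B78)."""
    table = []
    for i in range(256):
        crc = i
        for _ in range(8):
            crc = (crc >> 1) ^ 0x82F63B78 if crc & 1 else crc >> 1
        table.append(crc)
    return table


_CRC32C_TABLE = _make_crc32c_table()


def crc32c(data):
    crc = 0xFFFFFFFF
    for b in data:
        crc = _CRC32C_TABLE[(crc ^ b) & 0xFF] ^ (crc >> 8)
    return crc ^ 0xFFFFFFFF


def masked_crc(data):
    """Rotate the CRC right by 15 bits and add a constant, as TFRecord does."""
    crc = crc32c(data)
    rotated = ((crc >> 15) | (crc << 17)) & 0xFFFFFFFF
    return (rotated + 0xA282EAD8) & 0xFFFFFFFF


def write_record(f, payload):
    header = struct.pack("<Q", len(payload))
    f.write(header)
    f.write(struct.pack("<I", masked_crc(header)))
    f.write(payload)
    f.write(struct.pack("<I", masked_crc(payload)))


def _read_exact(f, n):
    chunk = f.read(n)
    if len(chunk) != n:
        raise ValueError("truncated TFRecord file")
    return chunk


def iter_records(f):
    """Yield each payload (one serialized tf.train.Example in a real file)."""
    while True:
        header = f.read(8)
        if not header:
            return  # clean end of file
        if len(header) != 8:
            raise ValueError("truncated TFRecord file")
        (length,) = struct.unpack("<Q", header)
        (length_crc,) = struct.unpack("<I", _read_exact(f, 4))
        if length_crc != masked_crc(header):
            raise ValueError("length CRC mismatch")
        payload = _read_exact(f, length)
        (payload_crc,) = struct.unpack("<I", _read_exact(f, 4))
        if payload_crc != masked_crc(payload):
            raise ValueError("payload CRC mismatch")
        yield payload
```

Opening a real file with open(path, "rb") and passing it to iter_records should walk its records the same way, provided the file was written without GZIP/ZLIB compression.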

1)  Write the Python code that reads the TFRecord, as follows:

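# Written against the TensorFlow 1.x queue-based input API
# (tf.TFRecordReader, tf.train.shuffle_batch, queue runners).
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt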

def read_and_decode(filename_queue, shuffle_batch=True):
    """Read serialized tf.train.Example records from the queue, decode them, and batch them."""
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)

    features = tf.parse_single_example(serialized_example, features={
        'image_raw': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([], tf.int64)
    })

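    # Assumes the writer stored each image as raw float32 bytes of shape 28x28x3,
    # presumably normalized to [0, 1], hence the scaling back to [0, 255].
    image = tf.decode_raw(features['image_raw'], tf.float32)
    image = tf.reshape(image, [28, 28, 3])
    image = image * 255.0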

    labels = features['label']

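    # shuffle_batch keeps at least min_after_dequeue (2000) examples buffered and
    # draws batches of 4 from them at random; tf.train.batch does not shuffle.
    if shuffle_batch: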
        images, labels = tf.train.shuffle_batch(
            [image, labels],
            batch_size=4,
            capacity=8000,
            num_threads=4,
            min_after_dequeue=2000)
    else:
        images, labels = tf.train.batch([image, labels],
                                        batch_size=4,
                                        capacity=8000,
                                        num_threads=4)
    return images, labels


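def TFrcords2Img(tfrecord_filename):
    # Queue the file name; num_epochs=3 creates a local variable, which is why
    # tf.local_variables_initializer() is needed below.
    filename_queue = tf.train.string_input_producer([tfrecord_filename],
                                                    num_epochs=3)
    images, labels = read_and_decode(filename_queue)

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())

    with tf.Session() as sess:
        sess.run(init_op)
        # Start the background threads that fill the input queues.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        # Fetch and display a single batch of 4 images.
        for i in range(1):
            # Use new names so the `images`/`labels` tensors are not overwritten.
            imgs, labs = sess.run([images, labels])
            print('batch ' + str(i) + ':')

            for j in range(4):
                print(str(labs[j]))
                # Clip before casting: floats outside [0, 255] do not convert reliably to uint8.
                img = np.uint8(np.clip(imgs[j], 0, 255))
                plt.subplot(4, 2, j * 2 + 1)
                plt.imshow(img)
            plt.show()

        coord.request_stop()
        coord.join(threads)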


if __name__ == '__main__':
    TFrcords2Img('E:/Python/mnist_img_output/a4.tfrecords')
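
The decoding in read_and_decode only works if it mirrors how the file was written. The NumPy sketch below assumes (the original post does not say) that each image was stored as a 28x28x3 float32 array normalized to [0, 1] and serialized with ndarray.tobytes(); encode_image and decode_image are hypothetical names. np.frombuffer plays the role of tf.decode_raw, followed by the same reshape and * 255.0 scaling, then a conversion to uint8 for plotting. The rounding and clipping are choices of this sketch: a plain np.uint8 cast truncates fractional values, and floats outside 0 to 255 do not convert reliably.

```python
import numpy as np

# Assumed layout (not stated in the original post): 28x28x3 float32 in [0, 1].
HEIGHT, WIDTH, CHANNELS = 28, 28, 3


def encode_image(img01):
    """Hypothetical writer side: float32 image in [0, 1] -> raw bytes."""
    return np.asarray(img01, dtype=np.float32).tobytes()


def decode_image(raw):
    """Mirror of read_and_decode's image ops, plus the uint8 cast used for plotting."""
    image = np.frombuffer(raw, dtype=np.float32)    # like tf.decode_raw(..., tf.float32)
    image = image.reshape(HEIGHT, WIDTH, CHANNELS)  # like tf.reshape(image, [28, 28, 3])
    image = image * 255.0                           # same scaling as the TF code
    # Round and clip before casting (a choice of this sketch, not the original code).
    return np.clip(np.round(image), 0, 255).astype(np.uint8)


if __name__ == "__main__":
    demo = np.zeros((HEIGHT, WIDTH, CHANNELS), dtype=np.float32)
    demo[0, 0, 0] = 1.0
    pixels = decode_image(encode_image(demo))
    print(pixels.shape, pixels.dtype, pixels[0, 0, 0])
```

If the reshape raises a ValueError, the record does not contain 28 * 28 * 3 * 4 = 9408 bytes, which usually means the writer used a different dtype or image shape.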



2)  Run the script and check the result, as shown in the figure below:

[Figure 1: Reading TFRecord, result of running the script]
