# TFRecords 文件保存和读取 (saving and reading TFRecords files)

import os
import tensorflow as tf
from PIL import Image

# 将图片数据保存为 TFRecords 文件
def convert2tfr(path, name, classes=3):
    """Pack a folder-per-class image tree into ``<name>.tfrecords``.

    Images are read from ``<path>/<index>/`` for every ``index`` in
    ``range(classes)``; that sub-directory index is stored as the label.
    Every image is converted to RGB, resized to 32x32 and stored as raw bytes
    (the readers in this file reshape those bytes to ``[32, 32, 3]``).

    Each serialized ``tf.train.Example`` has two features:
      * ``label``   -- int64, the class sub-directory index.
      * ``img_raw`` -- bytes, the output of ``PIL.Image.tobytes()`` on the
        resized RGB image.

    Args:
        path: Root directory holding sub-directories ``0``, ``1``, ...
            ``classes - 1``. A trailing slash is optional.
        name: Output file name without the ``.tfrecords`` extension.
        classes: Number of class sub-directories to read (default 3).
    """
    # The context manager closes (and flushes) the writer even if reading or
    # decoding an image raises part-way through.
    with tf.python_io.TFRecordWriter(name + '.tfrecords') as writer:
        for index in range(classes):
            class_path = os.path.join(path, str(index))
            # os.listdir returns entries in arbitrary order; sort them so the
            # record order is reproducible.
            for img_name in sorted(os.listdir(class_path)):
                img_path = os.path.join(class_path, img_name)
                # Image.open keeps the file handle open until the image is
                # closed; the with-block releases it promptly.
                with Image.open(img_path) as img:
                    img_raw = img.convert("RGB").resize((32, 32)).tobytes()
                example = tf.train.Example(features=tf.train.Features(feature={
                    "label": tf.train.Feature(
                        int64_list=tf.train.Int64List(value=[index])),
                    'img_raw': tf.train.Feature(
                        bytes_list=tf.train.BytesList(value=[img_raw])),
                }))
                writer.write(example.SerializeToString())

# 使用 TensorFlow 输入队列按 batch 读取图片
def read_and_decode(filename, batch_size, num_classes=81, min_after_dequeue=1000):
    """Build a TF1 queue-based pipeline yielding shuffled image/label batches.

    Expects records written by ``convert2tfr``: an int64 ``label`` feature
    and an ``img_raw`` bytes feature holding 32*32*3 uint8 values.

    Args:
        filename: Path of the ``.tfrecords`` file to read.
        batch_size: Number of examples per batch.
        num_classes: Depth of the one-hot label encoding. Defaults to 81 to
            preserve the original behaviour. Note that ``convert2tfr`` writes
            labels ``0 .. classes - 1`` (3 by default), so pass the matching
            value for data produced by it.
        min_after_dequeue: Minimum number of elements kept in the shuffle
            buffer. Larger values mix better but use more memory and delay
            the first batch.

    Returns:
        ``(img_batch, label_batch)``:
          * ``img_batch`` -- float32, shape ``[batch_size, 32, 32, 1]``,
            grayscale, scaled to ``[-0.5, 0.5]``.
          * ``label_batch`` -- float32 one-hot, shape
            ``[batch_size, num_classes]``.

        Evaluating these tensors requires ``tf.train.start_queue_runners``
        to have been started in the session.
    """
    filename_queue = tf.train.string_input_producer([filename])
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)  # (key, record bytes)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'label': tf.FixedLenFeature([], tf.int64),
            'img_raw': tf.FixedLenFeature([], tf.string),
        })

    img = tf.decode_raw(features['img_raw'], tf.uint8)
    img = tf.reshape(img, [32, 32, 3])
    # Collapse RGB to a single channel: 32x32x3 -> 32x32x1.
    img = tf.image.rgb_to_grayscale(img)
    # Map uint8 [0, 255] to float [-0.5, 0.5].
    img = tf.cast(img, tf.float32) * (1. / 255) - 0.5

    # FixedLenFeature(..., tf.int64) already yields int64; no cast needed.
    label = features['label']

    img_batch, label_batch = tf.train.shuffle_batch(
        [img, label],
        batch_size=batch_size,
        # Buffer headroom of 3 batches above the shuffle minimum.
        capacity=min_after_dequeue + 3 * batch_size,
        min_after_dequeue=min_after_dequeue)
    label_batch = tf.one_hot(label_batch, depth=num_classes)
    return img_batch, label_batch

# 将 TFRecords 文件解析还原为图片文件
def tfr2bmp(filename, dir, num):
    """Decode records of a TFRecords file back into BMP images.

    Image ``i`` is saved as ``read_img/<dir>/<i>_Label_<label>.bmp``.

    Args:
        filename: Path of the ``.tfrecords`` file, as written by
            ``convert2tfr`` (``img_raw`` = 32*32*3 uint8 RGB bytes).
        dir: Sub-directory of ``read_img`` to write into. It is created if
            missing, and an existing directory is reused. The name shadows
            the ``dir`` builtin but is kept for caller compatibility.
        num: Number of records to extract. ``string_input_producer`` is
            created without ``num_epochs``, so the input cycles over the
            file. If ``num`` exceeds the number of records, earlier records
            are saved again under new indices. NOTE(review): pass the exact
            record count.
    """
    out_dir = os.path.join('read_img', dir)
    # os.mkdir failed when the directory already existed (e.g. on a re-run)
    # and needed 'read_img' to exist first; makedirs creates both levels.
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)

    filename_queue = tf.train.string_input_producer([filename])
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)  # (key, record bytes)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'label': tf.FixedLenFeature([], tf.int64),
            'img_raw': tf.FixedLenFeature([], tf.string),
        })
    image = tf.decode_raw(features['img_raw'], tf.uint8)
    image = tf.reshape(image, [32, 32, 3])
    label = tf.cast(features['label'], tf.int32)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            for i in range(num):
                pixels, label_value = sess.run([image, label])
                img = Image.fromarray(pixels, 'RGB')
                img_path = os.path.join(
                    out_dir, str(i) + '_Label_' + str(label_value) + '.bmp')
                img.save(img_path)
                # Report progress without dumping the whole pixel array.
                print(img_path, label_value)
        finally:
            # Stop and join the queue-runner threads even if the loop raised,
            # so they do not keep running against a closing session.
            coord.request_stop()
            coord.join(threads)


if __name__ == '__main__':
    # Usage examples -- uncomment as needed.
    #
    # Pack class sub-folders (0/, 1/, 2/) into a TFRecords file:
    #   convert2tfr('G:/tfrecord_test/', 'tfrecord_test')
    #   convert2tfr('test_img/', 'tai_test1')
    #
    # Build batched input tensors (batch_size must be defined first):
    #   read_and_decode('tai_test.tfrecords', batch_size)
    #   read_and_decode('tai_train.tfrecords', batch_size)
    #
    # Dump records back to BMP files:
    #   tfr2bmp('tai_train.tfrecords', 'train_img', 48961)

    # Default action: write the first 15 test records to read_img/test_img/.
    tfr2bmp('tfrecord_test.tfrecords', 'test_img', 15)

# Reference: 原博客 (original blog post; link not included)

# 你可能感兴趣的: TensorFlow学习之路