TFRecords文件的存储与读取

将cats和dogs两个文件夹各1000张图片存储为:train.tfrecords
#将图片文件生成train record
import os
import tensorflow as tf
from PIL import Image
#生成catsdogsrecord文件
path='./data/train'
filenames=os.listdir(path)
writer=tf.python_io.TFRecordWriter('./data/train/train.tfrecords')
classes=['cats','dogs']#两类
for index,name in enumerate(classes):
    print(index,name)
#for name in os.listdir(path):
    class_path=path+os.sep+name
    for img_name in os.listdir(class_path):
        img_path=class_path+os.sep+img_name
        img=Image.open(img_path)
        img=img.resize((500,500))
        img_raw=img.tobytes()#图片转化成二进制
        example=tf.train.Example(features=tf.train.Features(
            feature={
            'label':
            tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
            'image':
            tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
        }))
        writer.write(example.SerializeToString())

TFRecords文件的存储与读取_第1张图片

变为

TFRecords文件的存储与读取_第2张图片

读取过程:

# 读取catsdogstrain.tfrecords文件
import tensorflow as tf
import cv2
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

filename='./data/train/train.tfrecords'
#读取并解析.tfrecords文件
def read_and_decode(filename):
    filename_queue=tf.train.string_input_producer([filename])# 按队列的形式读取
    reader=tf.TFRecordReader()
    _,serialized_example=reader.read(filename_queue)#返回文件名和文件
    features=tf.parse_single_example(serialized_example,
                            features={
                                'label':tf.FixedLenFeature([],tf.int64),#与存储的类型一致
                                'image':tf.FixedLenFeature([],tf.string)
                            })
    img=tf.decode_raw(features['image'],tf.uint8)
    img=tf.reshape(img,shape=[500,500,3])
    img = tf.cast(img, dtype=tf.float32) * (1.0 / 128) - 0.5
    label = tf.cast(features['label'], dtype=tf.int32)
    return img,label

img,label=read_and_decode(filename)

img_batch,label_batch=tf.train.shuffle_batch([img,label],batch_size=1,
                                             capacity=10,min_after_dequeue=1)
init=tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    threads=tf.train.start_queue_runners(sess=sess)
    for i in range(10):
        label=sess.run(label_batch)
        imgcv2=sess.run(img_batch)
        imgcv2.resize((500,500,3))
        print(label)
        cv2.imshow('img',imgcv2)
        cv2.waitKey()
按下任意键即可切换图片 共10张是cats,label应该是0

10张打印结果都是0,吻合前面在定义类的时候 0是cats的标签.

TFRecords文件的存储与读取_第3张图片,,

你可能感兴趣的:(TFRecords文件的存储与读取)