实现TFRecord文件的保存与读取

（图1：实现TFRecord文件的保存与读取）

import os
import cv2
import numpy as np
import tensorflow as tf
"""
train文件夹下的catsdog文件夹处理成train.tfrecords放在train文件夹里
"""
# Collect image paths with their labels and return them as two parallel lists.
def deal(dir):
    """Collect image paths under *dir* and label them by their class folder.

    Every file below a direct sub-folder of *dir* is treated as one sample.
    Its label is 0 when that sub-folder is named ``cats`` and 1 otherwise
    (0 = cat, 1 = dog). Files sitting directly in *dir* itself are ignored,
    for example the ``train.tfrecords`` file this script writes there.

    Args:
        dir: root folder containing the class sub-folders,
            e.g. ``./data/train`` with ``cats/`` and ``dogs/``.

    Returns:
        ``(image_list, label_list)``: a list of file paths and a parallel
        list of ``int`` labels, shuffled together with ``np.random.shuffle``.
    """
    samples = []
    for root, _, files in os.walk(dir):
        rel = os.path.relpath(root, dir)
        if rel == os.curdir:
            # Top-level files have no class folder; skip them so they
            # cannot shift the path/label alignment.
            continue
        # The class is the first path component below `dir`, so files in
        # nested folders keep the label of their top-level class folder.
        class_name = rel.split(os.sep)[0]
        label = 0 if class_name == 'cats' else 1
        samples.extend((os.path.join(root, name), label) for name in files)

    # Shuffle (path, label) pairs together so each label stays with its image.
    np.random.shuffle(samples)
    image_list = [path for path, _ in samples]
    label_list = [label for _, label in samples]
    return image_list, label_list
def int64_feature(value):
    """Wrap one integer *value* in a ``tf.train.Feature`` (int64_list)."""
    wrapped = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=wrapped)
def bytes_feature(value):
    """Wrap one bytes *value* in a ``tf.train.Feature`` (bytes_list)."""
    wrapped = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=wrapped)
def convert_to_tfrecord(image_list, label_list,
                        output_path='./data/train/train.tfrecords'):
    """Serialize images and their labels into a single TFRecord file.

    Each image is loaded with ``cv2.imread``. Its raw pixel bytes are stored
    under the ``image`` key and its integer label under the ``label`` key.
    Only the bytes are stored, not the image shape, so the reader must
    reshape to the original image size.

    Images that OpenCV cannot read are printed, deleted from disk (the
    script's clean-up of corrupt files) and skipped.

    Args:
        image_list: image file paths.
        label_list: labels parallel to *image_list*.
        output_path: destination file, by default
            ``./data/train/train.tfrecords``.
    """
    writer = tf.python_io.TFRecordWriter(output_path)
    print('start transform')
    try:
        for image_path, label in zip(image_list, label_list):
            img = cv2.imread(image_path)
            if img is None:
                # cv2.imread signals an unreadable/corrupt file by returning
                # None rather than raising.
                print(image_path)
                if os.path.exists(image_path):
                    os.remove(image_path)
                continue
            img_raw = img.tobytes()  # raw pixel bytes (no shape information)
            example = tf.train.Example(features=tf.train.Features(
                feature={
                    'label': int64_feature(int(label)),
                    'image': bytes_feature(img_raw),
                }))
            writer.write(example.SerializeToString())
    finally:
        # Close exactly once, after all records are written. Closing inside
        # the loop made later writes fail, and good images were deleted.
        writer.close()
    print('transform end')
"""
下面是读取tfrecord和显示图片证明生成的tfrecord正确
"""
# TFRecord file read back by show(); the same path convert_to_tfrecord writes to.
filename='./data/train/train.tfrecords'
# Read and parse a .tfrecords file with the TF1 queue-based input pipeline.
def read_and_decode(filename, shape=(227, 227, 3)):
    """Build ops that read and decode one example from a TFRecord file.

    Args:
        filename: path to a TFRecord file whose examples hold an int64
            ``label`` feature and a raw-bytes ``image`` feature, as written
            by ``convert_to_tfrecord``.
        shape: shape the decoded uint8 bytes are reshaped to. It defaults to
            ``(227, 227, 3)``. The stored bytes must contain exactly
            ``prod(shape)`` values, so this must match the serialized images.

    Returns:
        ``(img, label)`` tensors: a uint8 image of shape *shape* and an
        int32 label.
    """
    # Queue of input file names that the reader pulls records from.
    filename_queue = tf.train.string_input_producer([filename])
    reader = tf.TFRecordReader()
    # read() returns (key, value); only the serialized record is needed.
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            # Feature types must match the ones used when writing.
            'label': tf.FixedLenFeature([], tf.int64),
            'image': tf.FixedLenFeature([], tf.string),
        })
    img = tf.decode_raw(features['image'], tf.uint8)
    img = tf.reshape(img, shape=list(shape))
    # Optional normalization can go here, e.g.
    # tf.cast(img, tf.float32) * (1.0 / 128) - 0.5
    label = tf.cast(features['label'], dtype=tf.int32)
    return img, label
def show(num_images=10):
    """Display examples read back from ``filename`` to sanity-check the file.

    Builds the reading pipeline with ``read_and_decode``, then pulls
    *num_images* shuffled single-example batches. For each one it prints the
    label and shows the matching image in an OpenCV window; press a key to
    advance.

    Args:
        num_images: how many examples to display (default 10).
    """
    img, label = read_and_decode(filename)
    img_batch, label_batch = tf.train.shuffle_batch(
        [img, label], batch_size=1, capacity=11, min_after_dequeue=5)
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        # A Coordinator lets us stop and join the queue-runner threads
        # before the session closes.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            for _ in range(num_images):
                # Fetch image and label in ONE run call. Two separate calls
                # would dequeue two different batches and mismatch them.
                images, labels = sess.run([img_batch, label_batch])
                print(labels)
                # batch_size=1: drop the batch dimension for display.
                cv2.imshow('img', images[0])
                cv2.waitKey()
        finally:
            coord.request_stop()
            coord.join(threads)
            cv2.destroyAllWindows()

if __name__ == '__main__':
    # Step 1 (disabled): build ./data/train/train.tfrecords from the image
    # folders. Uncomment these two lines to regenerate the file first.
    # image_list, label_list = deal('./data/train')
    # convert_to_tfrecord(image_list,label_list)
    # Step 2: read the TFRecord file back and display a few examples.
    show()

你可能感兴趣的：实现TFRecord文件的保存与读取