一个tensorflow多线程读入数据和数据增强的框架:

总是跑cifar10、mnist啥的,其他数据集还是跑一跑吧
闲着没事把数据读取部分重写了一遍
代码比以前要更加通用一点
考虑到大部分数据集都可以有以下的储存形式
(我把cifar10图片也像这样分类了)
一个tensorflow多线程读入数据和数据增强的框架:_第1张图片
一个tensorflow多线程读入数据和数据增强的框架:_第2张图片
一个tensorflow多线程读入数据和数据增强的框架:_第3张图片
写了一个较为通用的数据读取和预处理的class
在处理之前会转化为TFRecord并放在TFRecord文件夹内,这个步骤会对图片resize(本地图片大小随意)
图片统一mode为RGB(此处应有故事…各种mode的储存方式不同)
采用多线程读入数据集,并自动为每个类别分配标签,最后分batch输出
可以选择是否进行数据增强
可以设置图片固定的输出大小
当然还是搞了一些提示语句(英语不好,瞎写了一些)
一个tensorflow多线程读入数据和数据增强的框架:_第4张图片
代码在这儿:
希望能帮到大家

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.utils import shuffle
import os
class DataReader():
    """Multi-threaded TFRecord input pipeline with optional data augmentation.

    ``data_path`` must contain a ``train`` and a ``test`` folder, each holding
    one sub-folder per class. On first use every image is resized, converted
    to RGB and written to ``<data_path>/TFRecord/<split>/data.TFRecord``;
    batches are then read back through TensorFlow 1.x input queues.

    Labels are the indices of the class folder names in ``self.classes``,
    which is sorted so the mapping is the same on every run and machine.
    """

    # With augmentation enabled, images are stored this much larger than the
    # requested output so that the random crop has room to move.
    _CROP_MARGIN = 1.3

    def __init__(self, data_path, enhance=True, output_size=None,
                 min_after_dequeue=350, num_threads=1):
        """Record the configuration, create the TFRecord folders, print stats.

        Args:
            data_path: Dataset root containing ``train`` and ``test`` folders.
            enhance: Whether to apply data augmentation to training batches.
            output_size: ``(height, width)`` of the images in a batch. It is
                required by every method that builds batches or TFRecords.
            min_after_dequeue: Minimum number of examples kept in the shuffle
                queue after a dequeue.
            num_threads: Number of threads feeding the shuffle queue.
        """
        self.num_threads = num_threads
        self.enhance = enhance
        self.min_after_dequeue = min_after_dequeue
        self.output_size = output_size
        self.train_data_path = os.path.join(data_path, "train")
        self.test_data_path = os.path.join(data_path, "test")
        self.tf_path = os.path.join(data_path, "TFRecord")
        self.tf_train_path = os.path.join(self.tf_path, "train")
        self.tf_test_path = os.path.join(self.tf_path, "test")
        # Only directories are classes; sorting pins the label <-> class
        # mapping instead of relying on filesystem listing order.
        self.classes = sorted(
            entry for entry in os.listdir(self.train_data_path)
            if os.path.isdir(os.path.join(self.train_data_path, entry))
        )
        # Create the folders the TFRecord files are saved in.
        for folder in (self.tf_path, self.tf_train_path, self.tf_test_path):
            if not os.path.exists(folder):
                os.makedirs(folder)
        self.detail()

    def class_list(self):
        """Return the class names; a label is an index into this list."""
        return self.classes

    def _require_output_size(self):
        """Return ``(height, width)`` or raise if ``output_size`` is unset."""
        if self.output_size is None:
            raise ValueError(
                "output_size must be set, e.g. output_size=(64, 64), "
                "before creating TFRecord files or batches")
        return self.output_size[0], self.output_size[1]

    def _stored_shape(self):
        """Return the ``(height, width)`` images have inside the TFRecord."""
        height, width = self._require_output_size()
        if self.enhance:
            return (int(height * self._CROP_MARGIN),
                    int(width * self._CROP_MARGIN))
        return height, width

    @staticmethod
    def _count_files(folder):
        """Return the number of entries in ``folder`` (0 if it is missing)."""
        return len(os.listdir(folder)) if os.path.isdir(folder) else 0

    def Batcher(self, batch_size, path, distorted):
        """Build the queue-based graph that yields shuffled batches.

        Args:
            batch_size: Number of examples per batch.
            path: Folder whose files are all read as TFRecord files.
            distorted: Apply the random flips and colour jitter (only takes
                effect when ``self.enhance`` is also true).

        Returns:
            ``(images, labels)`` tensors: per-image standardized images and
            one-hot labels with one column per entry of ``self.classes``.
        """
        out_height, out_width = self._require_output_size()
        stored_height, stored_width = self._stored_shape()
        filenames = [os.path.join(path, entry) for entry in os.listdir(path)]
        filename_queue = tf.train.string_input_producer(
            filenames, shuffle=True, num_epochs=None)
        reader = tf.TFRecordReader()
        _, serialized_example = reader.read(filename_queue)
        feature = tf.parse_single_example(serialized_example, features={
            "image": tf.FixedLenFeature([], tf.string),
            "label": tf.FixedLenFeature([], tf.int64),
        })
        image = tf.decode_raw(feature["image"], tf.uint8)
        image = tf.reshape(image, [stored_height, stored_width, 3])
        # Train and test images are both cropped to the output size; without
        # augmentation the stored size equals the output size, so the crop
        # leaves the image unchanged.
        image = tf.random_crop(image, [out_height, out_width, 3])
        if self.enhance and distorted:
            image = tf.image.random_flip_left_right(image)
            image = tf.image.random_flip_up_down(image)
            image = tf.image.random_brightness(image, max_delta=0.3)
            image = tf.image.random_contrast(image, lower=0.7, upper=1.3)
            image = tf.image.random_saturation(image, lower=0.7, upper=1.3)
            image = tf.image.random_hue(image, max_delta=0.4)
        image = tf.image.per_image_standardization(image)
        label = tf.cast(feature["label"], tf.int32)
        images, labels = tf.train.shuffle_batch(
            [image, label],
            num_threads=self.num_threads,
            batch_size=batch_size,
            min_after_dequeue=self.min_after_dequeue,
            capacity=self.min_after_dequeue + batch_size * 3)
        # The depth must match the number of classes, not a fixed 2.
        labels = tf.one_hot(labels, len(self.classes))
        return images, labels

    def train_batch(self, batch_size):
        """Return ``(images, labels)`` training batches (augmented if enabled).

        Creates the training TFRecord file first when it does not exist.
        """
        file_path = os.path.join(self.tf_train_path, "data.TFRecord")
        if not os.path.exists(file_path):
            print("file:"+file_path+" cannot be found,ready to create")
            self.CreatTFRecordData(self.train_data_path, self.tf_train_path)
        else:
            print("find file:"+file_path)
        return self.Batcher(batch_size, self.tf_train_path, True)

    def test_batch(self, batch_size):
        """Return ``(images, labels)`` test batches (no flips / colour jitter).

        Creates the test TFRecord file first when it does not exist.
        """
        file_path = os.path.join(self.tf_test_path, "data.TFRecord")
        if not os.path.exists(file_path):
            print("file:"+file_path+" cannot be found\nready to create")
            self.CreatTFRecordData(self.test_data_path, self.tf_test_path)
        else:
            print("find file:"+file_path)
        return self.Batcher(batch_size, self.tf_test_path, False)

    def detail(self):
        """Print per-class image counts and store the train/test totals.

        Sets ``num_examples_for_train`` and ``num_examples_for_test``. A class
        folder that is missing from a split is counted as 0 images.
        """
        self.num_examples_for_train = 0
        self.num_examples_for_test = 0
        print("dataset detail:")
        print("train_data:")
        for image_class in self.classes:
            count = self._count_files(
                os.path.join(self.train_data_path, image_class))
            print("class:%s num:%d" % (image_class, count))
            self.num_examples_for_train += count
        print("test_data")
        for image_class in self.classes:
            count = self._count_files(
                os.path.join(self.test_data_path, image_class))
            print("class:%s num:%d" % (image_class, count))
            self.num_examples_for_test += count

    def CreatTFRecordData(self, data_path, save_path):
        """Convert every image under ``data_path`` into one TFRecord file.

        Each image is converted to RGB, resized to the stored shape and
        written, in shuffled order, to ``<save_path>/data.TFRecord`` together
        with its integer label.

        Args:
            data_path: Folder with one sub-folder of images per class.
            save_path: Folder that receives ``data.TFRecord``.
        """
        stored_height, stored_width = self._stored_shape()
        print("extracting from "+data_path+"\\*.jpg")
        image_list = []
        label_list = []
        for label, image_class in enumerate(self.classes):
            class_path = os.path.join(data_path, image_class)
            if not os.path.isdir(class_path):
                continue  # this split has no folder for that class
            for image_name in os.listdir(class_path):
                image_list.append(os.path.join(class_path, image_name))
                label_list.append(label)
        image_total = len(image_list)
        # One uniform shuffle is enough; repeating it adds no randomness.
        image_list, label_list = shuffle(image_list, label_list)
        record_path = os.path.join(save_path, "data.TFRecord")
        writer = tf.python_io.TFRecordWriter(record_path)
        completed = False
        try:
            for done, (image_path, label) in enumerate(
                    zip(image_list, label_list), start=1):
                if done % 2500 == 0:
                    print("Create "+record_path+": %.1f %%" % (done * 100.0 / image_total))
                with Image.open(image_path) as source:
                    # Convert first so palette / 1-bit images are really
                    # interpolated; PIL takes (width, height) while the
                    # reader reshapes to (height, width, 3).
                    pixels = source.convert("RGB").resize(
                        (stored_width, stored_height), Image.BICUBIC).tobytes()
                example = tf.train.Example(features=tf.train.Features(feature={
                    "image": tf.train.Feature(
                        bytes_list=tf.train.BytesList(value=[pixels])),
                    "label": tf.train.Feature(
                        int64_list=tf.train.Int64List(value=[label])),
                }))
                writer.write(example.SerializeToString())
            completed = True
        finally:
            writer.close()
            # A truncated file would be picked up as "found" on the next
            # run, so do not leave one behind after a failure.
            if not completed and os.path.exists(record_path):
                os.remove(record_path)
        print("finish")
# Usage example, kept as an inert module-level string literal (never executed).
'''
cifar10=DataReader(data_path,output_size=(64,64),enhance=True)
cifar10.detail()                                #各个类别数量统计
classes=cifar10.class_list()                    #等于cifar10.classes 得到标签对应的类别名称
train_batch=cifar10.train_batch(batch_size=100) #训练数据batch
test_batch=cifar10.test_batch(batch_size=100)   #测试数据batch
'''

显示图片:

def imshows(classes,images,labels,index,amount,predictions=None):
    """Show ``amount`` images starting at ``index`` with their class names.

    Args:
        classes: Sequence of class names, indexed by label id.
        images: Array of images; ``images[k]`` is passed to ``imshow``.
        labels: Per-image label scores (e.g. one-hot); the argmax is the
            label id.
        index: Position of the first image to show.
        amount: Number of consecutive images to show.
        predictions: Optional per-image prediction scores; when given, the
            predicted class name is appended to each title.
    """
    columns = 6
    # Keep the 5-row x 6-column grid, adding rows only when more than 30
    # images are requested (a fixed 5x6 grid cannot hold a 31st subplot).
    rows = max(5, (amount + columns - 1) // columns)
    fig = plt.gcf()
    fig.set_size_inches(10, 20)  # adjust to taste
    for offset in range(amount):
        position = index + offset
        title = "lab:" + classes[np.argmax(labels[position])]
        if predictions is not None:
            # Predictions are looked up in the same class list as labels.
            title = title + "prd:" + classes[np.argmax(predictions[position])]
        ax = plt.subplot(rows, columns, offset + 1)
        ax.set_title(title)
        ax.imshow(images[position])
    plt.show()

测试:

# Demo: read ten training batches and display the first five images of each.
data_path="D:\\log\\cifar10\\Image"
batch_size=100
cifar10=DataReader(data_path,output_size=(64,64),enhance=True)
cifar10.detail()
train_batch=cifar10.train_batch(batch_size=batch_size)
test_batch=cifar10.test_batch(batch_size=batch_size)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    coord=tf.train.Coordinator()
    threads=tf.train.start_queue_runners(sess=sess,coord=coord)
    try:
        for i in range(10):
            img,lab=sess.run(train_batch)
            imshows(cifar10.classes,img,lab,0,5)
    finally:
        # Always stop and join the queue-runner threads, even when reading
        # or plotting fails, so they do not outlive the session.
        coord.request_stop()
        coord.join(threads)

效果:
(没有标准化,有数据增强)
一个tensorflow多线程读入数据和数据增强的框架:_第5张图片
kaggle的猫狗数据集:(使用standardization)

一个tensorflow多线程读入数据和数据增强的框架:_第6张图片
不过训练的时候会进行标准化,只是标准化之后
说实话我啥也认不出来
emm
就这样了吧

你可能感兴趣的:(一些小模块实现)