Tensorflow图像分类数据集读取处理和分批次的代码通用框架(tf.data APIs)

数据集有很多,但我希望能有一个通用的框架来一劳永逸
不知道tf自己有没有,我似乎是没找到
花了半天来写了一个看起来比较通用的框架,顺便加了点进度条和提示用语
要求数据集的结构是
一个文件夹内有train和test两个文件夹,内部每个类别一个文件夹,文件夹内是图片数据
应该大部分都是这种结构吧
其他的遇到再说
自动分配标签,classes就是每个标签对应的含义
train_batch和test_batch输出是one_hot标签
速度很快,cifar10跑一遍仅需要10秒左右

import tensorflow as tf
from PIL import Image                     #处理图片
from sklearn.utils import shuffle         #打乱图片顺序
from tqdm import tqdm_notebook as tqdm    #这个是为了显示进度条
import os
from time import time
# Root folder of the dataset; DatasetReader joins "train" and "test" onto it.
data_path="D:\\log\\cifar10\\Image"
# Clear the default TensorFlow graph before any ops are created below.
tf.reset_default_graph()
class DatasetReader(object):
    """Read a folder-based image-classification dataset through tf.data.

    Expected layout::

        data_path/
            train/<class_name>/<image files>
            test/<class_name>/<image files>

    Every sub-folder of ``train`` is one class; its position in the sorted
    ``classes`` list is the integer label.  On construction the images are
    converted once into ``data_path/TFRecordData/{train,test}/data.TFRecord``
    (skipped when the file already exists).  ``train_batch`` / ``test_batch``
    then return ``(image, one_hot_label)`` tensors fed from those files.
    """

    def __init__(self, data_path, image_size=None):
        """Prepare paths, discover the classes and build missing TFRecords.

        Args:
            data_path: dataset root containing ``train`` and ``test`` folders.
            image_size: ``[height, width]`` of the images produced by the
                batches.  Required; the ``None`` default is kept only for
                signature compatibility.

        Raises:
            ValueError: if ``image_size`` is missing or is not two positive
                numbers.
        """
        self.data_path = data_path
        # Stored as [height, width, 3]; built from a copy so the caller's
        # list is never modified.
        self.img_size = self._resolve_image_size(image_size)
        # Source image folders and the folders holding the TFRecord caches.
        self.train_path = os.path.join(data_path, "train")
        self.test_path = os.path.join(data_path, "test")
        self.TF_path = os.path.join(data_path, "TFRecordData")
        self.tf_train_path = os.path.join(self.TF_path, "train")
        self.tf_test_path = os.path.join(self.TF_path, "test")
        # NOTE: sorted + directories only, so the label <-> class mapping is
        # the same on every run and stray files are not counted as classes.
        self.classes = sorted(
            name for name in os.listdir(self.train_path)
            if os.path.isdir(os.path.join(self.train_path, name))
        )
        self.__Makedirs()
        # tf.data initializable iterators must be initialised explicitly;
        # these are filled in by train_batch() / test_batch().
        self.train_batch_initializer = None
        self.test_batch_initializer = None
        self.__CreateTFRecord(self.train_path, self.tf_train_path)
        self.__CreateTFRecord(self.test_path, self.tf_test_path)

    @staticmethod
    def _resolve_image_size(image_size):
        """Validate ``image_size`` and return a new ``[height, width, 3]`` list."""
        if image_size is None:
            raise ValueError("image_size must be given as [height, width]")
        size = [int(value) for value in image_size]
        if len(size) != 2 or min(size) <= 0:
            raise ValueError(
                "image_size must be two positive numbers, got %r" % (image_size,)
            )
        return size + [3]

    def __raw_size(self):
        """Return ``[height, width]`` of the stored images.

        Images are stored 1.5x larger than the output size so that the
        random crop applied while parsing has room to move.
        """
        return [int(self.img_size[0] * 1.5), int(self.img_size[1] * 1.5)]

    def __CreateTFRecord(self, read_path, save_path):
        """Serialise every image below ``read_path`` into one TFRecord file.

        Does nothing when ``save_path/data.TFRecord`` already exists.  A
        partially written file is removed if the conversion is interrupted,
        so that the next run does not mistake it for a complete cache.
        """
        path = os.path.join(save_path, "data.TFRecord")
        if os.path.exists(path):
            print("find file " + path)  # already converted, skip
            return
        print("cannot find file %s,ready to recreate" % path)

        image_path = []
        image_label = []
        for label, class_name in enumerate(self.classes):
            class_path = os.path.join(read_path, class_name)
            for image_name in os.listdir(class_path):
                image_path.append(os.path.join(class_path, image_name))
                image_label.append(label)
        if image_path:
            # One shuffle already yields a uniformly random order.
            image_path, image_label = shuffle(image_path, image_label)

        raw_height, raw_width = self.__raw_size()
        finished = False
        try:
            with tf.python_io.TFRecordWriter(path=path) as writer:
                for file_name, label in tqdm(zip(image_path, image_label),
                                             total=len(image_path),
                                             desc="TFRecord"):
                    with Image.open(file_name) as source:
                        # Convert first so palette images are resampled in
                        # RGB; PIL's resize takes (width, height).
                        image = source.convert("RGB").resize(
                            (raw_width, raw_height), Image.BICUBIC)
                    example = tf.train.Example(features=tf.train.Features(feature={
                        "label": tf.train.Feature(
                            int64_list=tf.train.Int64List(value=[label])),
                        "image": tf.train.Feature(
                            bytes_list=tf.train.BytesList(value=[image.tobytes()])),
                    }))
                    writer.write(example.SerializeToString())
            finished = True
        finally:
            if not finished and os.path.exists(path):
                os.remove(path)

    def __Makedirs(self):
        """Create the TFRecord output folders when they do not exist yet."""
        for folder in (self.TF_path, self.tf_train_path, self.tf_test_path):
            if not os.path.exists(folder):
                os.makedirs(folder)

    def __decode(self, tensor):
        """Parse one serialized example into a cropped uint8 image and a one-hot label."""
        feature = tf.parse_single_example(tensor, features={
            "image": tf.FixedLenFeature([], tf.string),
            "label": tf.FixedLenFeature([], tf.int64),
        })
        image = tf.decode_raw(feature["image"], tf.uint8)
        image = tf.reshape(image, self.__raw_size() + [3])
        image = tf.random_crop(image, self.img_size)
        label = tf.cast(feature["label"], tf.int32)
        # Depth follows the number of discovered classes instead of a fixed 10.
        label = tf.one_hot(label, len(self.classes))
        return image, label

    def __parsed(self, tensor):
        """Decode, crop and standardise one example (used by test_batch)."""
        image, label = self.__decode(tensor)
        image = tf.image.per_image_standardization(image)
        return image, label

    def __parsed_distorted(self, tensor):
        """Like ``__parsed`` plus random augmentation (used by train_batch)."""
        image, label = self.__decode(tensor)
        image = tf.image.random_flip_left_right(image)
        image = tf.image.random_flip_up_down(image)
        image = tf.image.random_brightness(image, max_delta=0.4)
        image = tf.image.random_hue(image, max_delta=0.4)
        image = tf.image.random_contrast(image, lower=0.7, upper=1.3)
        image = tf.image.random_saturation(image, lower=0.7, upper=1.3)
        image = tf.image.per_image_standardization(image)
        return image, label

    def __GetBatchIterator(self, path, parsed, batch_size):
        """Build the input pipeline over every file in ``path``.

        Returns:
            ``(initializer, next_batch)`` of an initializable iterator that
            maps ``parsed`` over the records, shuffles, batches and repeats
            indefinitely.
        """
        filename = [os.path.join(path, name) for name in os.listdir(path)]
        dataset = tf.data.TFRecordDataset(filename)
        dataset = dataset.map(parsed)
        dataset = dataset.shuffle(buffer_size=500)
        dataset = dataset.batch(batch_size)
        dataset = dataset.repeat(None)
        iterator = dataset.make_initializable_iterator()
        return iterator.initializer, iterator.get_next()

    # NOTE: according to the original author, tf.data offers optimised
    # transformations that make the pipeline much faster, but older
    # TensorFlow versions may not provide these APIs.  Alternative
    # implementation kept for reference:
    #
    #     def __GetBatchIterator(self,path,parsed,batch_size):
    #         filename=[os.path.join(path,name)for name in os.listdir(path)]
    #         dataset=tf.data.TFRecordDataset(filename)
    #         dataset=dataset.prefetch(tf.contrib.data.AUTOTUNE)
    #         dataset=dataset.apply(tf.data.experimental.shuffle_and_repeat(buffer_size=1000,count=None,seed=233))
    #         dataset=dataset.apply(tf.data.experimental.map_and_batch(parsed,batch_size))
    #         dataset=dataset.apply(tf.data.experimental.prefetch_to_device("/gpu:0"))
    #         iterator=dataset.make_initializable_iterator()
    #         return iterator.initializer,iterator.get_next()

    def __detail(self, path, name="train"):
        """Print the number of files per class below ``path`` and the max/min."""
        print("%s dataset:" % name)
        counts = []
        for class_name in self.classes:
            num = len(os.listdir(os.path.join(path, class_name)))
            print("%-12s:%3d" % (class_name, num))
            counts.append(num)
        largest = max(counts) if counts else 0
        smallest = min(counts) if counts else 0
        print("max:%d min:%d" % (largest, smallest))

    def detail(self):
        """Print per-class image counts for the train and the test split."""
        self.__detail(self.train_path, "train")
        self.__detail(self.test_path, "test")

    def global_variables_initializer(self):
        """Return every initializer op to run at session start.

        Contains the iterator initializers of the batches requested so far
        (splits never requested are left out) followed by
        ``tf.global_variables_initializer()``.
        """
        initializer = [
            op for op in (self.train_batch_initializer,
                          self.test_batch_initializer)
            if op is not None
        ]
        initializer.append(tf.global_variables_initializer())
        return initializer

    def test_batch(self, batch_size):
        """Return the ``(image, one_hot_label)`` batch tensors of the test split."""
        self.test_batch_initializer, batch = self.__GetBatchIterator(
            self.tf_test_path, self.__parsed, batch_size)
        return batch

    def train_batch(self, batch_size):
        """Return the augmented ``(image, one_hot_label)`` batch tensors of the train split."""
        self.train_batch_initializer, batch = self.__GetBatchIterator(
            self.tf_train_path, self.__parsed_distorted, batch_size)
        return batch

在使用的时候就会方便很多

# Example usage: build the reader, request the batch tensors, then run them.
data_path="D:\\log\\cifar10\\Image"
data=DatasetReader(data_path,image_size=[128,128])
train_batch=data.train_batch(batch_size=100)
test_batch=data.test_batch(batch_size=100)
with tf.Session() as sess:
    # Runs the iterator initializers together with the variable initializer.
    sess.run(data.global_variables_initializer())
    image,label=sess.run(train_batch)
    # NOTE(review): train_op, x, y and training are not defined in this
    # snippet; presumably they come from the reader's own model code.
    sess.run(train_op,feed_dict={x:image,y:label,training:True})

你可能感兴趣的:(一些小模块实现)