There are plenty of datasets out there, but I wanted a generic framework that handles them once and for all.
I don't know whether TensorFlow ships one itself; at least I couldn't find it.
So I spent half a day writing a reasonably generic one, and threw in a progress bar and some status messages along the way.
The expected dataset layout is:
a root folder containing a train and a test folder, with one subfolder per class inside each, holding that class's images.
Most datasets should follow this structure anyway.
Anything else I'll deal with when I run into it.
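For example, CIFAR-10 exported to image files would look like this (the class folder names are whatever your dataset uses):

D:\log\cifar10\Image
├── train
│   ├── airplane
│   │   ├── 0001.png
│   │   └── ...
│   └── automobile
│       └── ...
└── test
    ├── airplane
    └── ...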
Labels are assigned automatically, and classes maps each numeric label to its meaning.
train_batch and test_batch emit one-hot labels.
It is quite fast: one pass over CIFAR-10 takes only about 10 seconds.
import tensorflow as tf
from PIL import Image                     # image loading and resizing
from sklearn.utils import shuffle         # shuffle images and labels in unison
from tqdm import tqdm_notebook as tqdm    # progress bar (notebook flavour)
import os

data_path = "D:\\log\\cifar10\\Image"
tf.reset_default_graph()
class DatasetReader(object):
    def __init__(self, data_path, image_size):  # images are emitted at the fixed size image_size = [height, width]
        self.data_path = data_path
        # copy instead of appending in place, so the caller's list is not mutated
        self.img_size = list(image_size) + [3]
        self.train_path = os.path.join(data_path, "train")  # image folders and the TFRecord folder
        self.test_path = os.path.join(data_path, "test")    # TFRecord speeds up the input pipeline
        self.TF_path = os.path.join(data_path, "TFRecordData")
        self.tf_train_path = os.path.join(self.TF_path, "train")
        self.tf_test_path = os.path.join(self.TF_path, "test")
        self.classes = os.listdir(self.train_path)
        self.__Makedirs()
        self.train_batch_initializer = None  # tf.data iterators must be initialized explicitly
        self.test_batch_initializer = None
        self.__CreateTFRecord(self.train_path, self.tf_train_path)
        self.__CreateTFRecord(self.test_path, self.tf_test_path)
    def __CreateTFRecord(self, read_path, save_path):
        path = os.path.join(save_path, "data.TFRecord")
        if os.path.exists(path):
            print("find file " + path)  # already built, skip rebuilding
            return
        print("cannot find file %s, ready to recreate" % path)
        writer = tf.python_io.TFRecordWriter(path=path)
        image_path = []
        image_label = []
        # random crops are used for augmentation later, so store the images 1.5x larger;
        # note PIL's resize takes (width, height), so swap the [H, W] order here
        image_size = (int(self.img_size[1] * 1.5), int(self.img_size[0] * 1.5))
        for label, class_name in enumerate(self.classes):
            class_path = os.path.join(read_path, class_name)
            for image_name in os.listdir(class_path):
                image_path.append(os.path.join(class_path, image_name))
                image_label.append(label)
        image_path, image_label = shuffle(image_path, image_label)  # one shuffle is enough
        for i in tqdm(range(len(image_path)), desc="TFRecord"):
            image, label = Image.open(image_path[i]).resize(image_size, Image.BICUBIC), image_label[i]
            image = image.convert("RGB")
            image = image.tobytes()
            example = tf.train.Example(features=tf.train.Features(feature={
                "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
                "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image]))
            }))
            writer.write(example.SerializeToString())
        writer.close()
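    # Sanity check (my addition, not part of the original code): the records just
    # written can be counted with the TF 1.x record iterator, e.g.
    #   sum(1 for _ in tf.python_io.tf_record_iterator(path))
    # and the total should equal the number of images found under read_path.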
    def __Makedirs(self):  # create the folders if they do not exist
        if not os.path.exists(self.TF_path):
            os.makedirs(self.TF_path)
        if not os.path.exists(self.tf_train_path):
            os.makedirs(self.tf_train_path)
        if not os.path.exists(self.tf_test_path):
            os.makedirs(self.tf_test_path)
    def __parsed(self, tensor):  # decode one record into an image, then crop and standardize
        raw_image_size = [int(self.img_size[0] * 1.5), int(self.img_size[1] * 1.5), 3]
        feature = tf.parse_single_example(tensor, features={
            "image": tf.FixedLenFeature([], tf.string),
            "label": tf.FixedLenFeature([], tf.int64)
        })
        image = tf.decode_raw(feature["image"], tf.uint8)
        image = tf.reshape(image, raw_image_size)
        image = tf.random_crop(image, self.img_size)
        image = tf.image.per_image_standardization(image)
        label = tf.cast(feature["label"], tf.int32)
        label = tf.one_hot(label, len(self.classes))  # depth follows the class count, not a hard-coded 10
        return image, label
    def __parsed_distorted(self, tensor):  # same as __parsed, plus data augmentation
        raw_image_size = [int(self.img_size[0] * 1.5), int(self.img_size[1] * 1.5), 3]
        feature = tf.parse_single_example(tensor, features={
            "image": tf.FixedLenFeature([], tf.string),
            "label": tf.FixedLenFeature([], tf.int64)
        })
        image = tf.decode_raw(feature["image"], tf.uint8)
        image = tf.reshape(image, raw_image_size)
        image = tf.random_crop(image, self.img_size)
        image = tf.image.random_flip_left_right(image)
        image = tf.image.random_flip_up_down(image)
        image = tf.image.random_brightness(image, max_delta=0.4)
        image = tf.image.random_hue(image, max_delta=0.4)
        image = tf.image.random_contrast(image, lower=0.7, upper=1.3)
        image = tf.image.random_saturation(image, lower=0.7, upper=1.3)
        image = tf.image.per_image_standardization(image)
        label = tf.cast(feature["label"], tf.int32)
        label = tf.one_hot(label, len(self.classes))
        return image, label
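    # Debugging tip (mine): per_image_standardization returns zero-mean floats, so a
    # decoded sample cannot go straight to plt.imshow; min-max rescale it first, e.g.
    #   img = (img - img.min()) / (img.max() - img.min() + 1e-8)
    # to map it back into [0, 1] for display.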
    def __GetBatchIterator(self, path, parsed, batch_size):  # returns (initializer, next_batch)
        filename = [os.path.join(path, name) for name in os.listdir(path)]
        dataset = tf.data.TFRecordDataset(filename)
        dataset = dataset.map(parsed)
        dataset = dataset.shuffle(buffer_size=500)
        dataset = dataset.batch(batch_size)
        dataset = dataset.repeat(None)
        iterator = dataset.make_initializable_iterator()
        return iterator.initializer, iterator.get_next()
    '''
    tf.data also ships some optimized dataset transformations that are a lot faster,
    but these APIs may not exist on older TF versions:
    def __GetBatchIterator(self, path, parsed, batch_size):
        filename = [os.path.join(path, name) for name in os.listdir(path)]
        dataset = tf.data.TFRecordDataset(filename)
        dataset = dataset.prefetch(tf.contrib.data.AUTOTUNE)
        dataset = dataset.apply(tf.data.experimental.shuffle_and_repeat(buffer_size=1000, count=None, seed=233))
        dataset = dataset.apply(tf.data.experimental.map_and_batch(parsed, batch_size))
        dataset = dataset.apply(tf.data.experimental.prefetch_to_device("/gpu:0"))
        iterator = dataset.make_initializable_iterator()
        return iterator.initializer, iterator.get_next()
    '''
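    # (Probing note, mine: if unsure whether your installation has these, checking for
    # the symbols is safer than parsing version strings, e.g.
    #   hasattr(tf.data, "experimental") and hasattr(tf.data.experimental, "map_and_batch")
    # tf.data.experimental replaced tf.contrib.data around TF 1.12, if I remember correctly.)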
    def __detail(self, path):
        Max = -1e9
        Min = 1e9
        print("dataset: " + path)  # print which split this is, rather than always saying "train"
        path = [os.path.join(path, name) for name in self.classes]
        for i in range(len(self.classes)):
            num = len(os.listdir(path[i]))
            print("%-12s:%3d" % (self.classes[i], num))
            Max = max(Max, num)
            Min = min(Min, num)
        print("max:%d min:%d" % (Max, Min))

    def detail(self):  # print per-class image counts for both splits
        self.__detail(self.train_path)
        self.__detail(self.test_path)
    def global_variables_initializer(self):  # bundle all initializers together
        initializer = []
        initializer.append(self.train_batch_initializer)
        initializer.append(self.test_batch_initializer)
        initializer.append(tf.global_variables_initializer())
        return initializer
    def test_batch(self, batch_size):
        self.test_batch_initializer, batch = self.__GetBatchIterator(self.tf_test_path, self.__parsed, batch_size)
        return batch

    def train_batch(self, batch_size):
        self.train_batch_initializer, batch = self.__GetBatchIterator(self.tf_train_path, self.__parsed_distorted, batch_size)
        return batch
Using it is then a lot more convenient:
data_path = "D:\\log\\cifar10\\Image"
data = DatasetReader(data_path, image_size=[128, 128])
train_batch = data.train_batch(batch_size=100)
test_batch = data.test_batch(batch_size=100)
with tf.Session() as sess:
    sess.run(data.global_variables_initializer())  # iterator initializers + variable init in one run
    image, label = sess.run(train_batch)
    # x, y, training and train_op come from your own model definition
    sess.run(train_op, feed_dict={x: image, y: label, training: True})
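The snippet above assumes x, y, training and train_op already exist. For completeness, here is a minimal stand-in under that assumption (a toy convnet with hypothetical names, not the model I actually benchmarked):

# hypothetical placeholders/ops matching the names used in the usage example
x = tf.placeholder(tf.float32, [None, 128, 128, 3])
y = tf.placeholder(tf.float32, [None, 10])  # depth should match len(data.classes)
training = tf.placeholder(tf.bool)          # would gate dropout/batch norm in a real model

h = tf.layers.conv2d(x, 32, 3, activation=tf.nn.relu)
h = tf.layers.max_pooling2d(h, 2, 2)
h = tf.layers.flatten(h)
logits = tf.layers.dense(h, 10)
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=logits))
train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)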