I'm always running CIFAR-10, MNIST, and the like, so I figured I should try some other datasets too.
With some free time on my hands, I rewrote the data-loading part from scratch.
The code is a bit more general than what I had before.
Most datasets can be stored in the directory layout sketched below.
(I reorganized my CIFAR-10 images this way as well.)
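For reference, this is the layout the class below expects (inferred from the code; cat and dog are placeholder class names):

data_path/
    train/
        cat/    <- *.jpg files, any size
        dog/
        ...
    test/
        cat/
        dog/
        ...
    TFRecord/   <- created automatically
        train/data.TFRecord
        test/data.TFRecord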
I wrote a fairly general class for reading and preprocessing the data.
Before anything else it converts the images to TFRecord files stored in a TFRecord folder; this step also resizes them, so the local images can be any size.
Every image is unified to RGB mode (there is a story behind this... images on disk come in all kinds of modes; see the small sketch after this list).
The dataset is read with multiple threads, each class is automatically assigned a label, and the data comes out in batches.
Data augmentation can be switched on or off.
A fixed output image size can be configured.
And of course I added some progress messages (my English isn't great, so they're rough).
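As an aside on the RGB point above, here is a minimal sketch (the file name is made up) of what convert("RGB") guards against:

from PIL import Image

img = Image.open("some_image.png")  # hypothetical file
print(img.mode)                     # may be "L", "P", or "RGBA" rather than "RGB"
img = img.convert("RGB")            # after this, every pixel is exactly 3 bytes

Without this, tobytes() would produce a different number of bytes per image, and tf.reshape would fail at read time.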
The code is here:
I hope it helps.
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.utils import shuffle
import os
class DataReader():
    def __init__(self, data_path, enhance=True, output_size=None, min_after_dequeue=350, num_threads=1):
        # data_path: dataset folder, which should contain "train" and "test" subfolders
        # enhance: whether to apply data augmentation
        # output_size: output image size
        # min_after_dequeue: minimum number of examples kept in the queue
        # num_threads: number of threads used to read files
        self.num_threads = num_threads
        self.enhance = enhance
        self.min_after_dequeue = min_after_dequeue
        self.output_size = output_size
        self.train_data_path = os.path.join(data_path, "train")
        self.test_data_path = os.path.join(data_path, "test")
        self.tf_path = os.path.join(data_path, "TFRecord")
        self.tf_train_path = os.path.join(self.tf_path, "train")
        self.tf_test_path = os.path.join(self.tf_path, "test")
        # sort so the label assignment is deterministic across runs
        self.classes = sorted(os.listdir(self.train_data_path))
        # create the folders that will hold the TFRecord files
        if not os.path.exists(self.tf_path): os.makedirs(self.tf_path)
        if not os.path.exists(self.tf_train_path): os.makedirs(self.tf_train_path)
        if not os.path.exists(self.tf_test_path): os.makedirs(self.tf_test_path)
        self.detail()

    def class_list(self):  # class names indexed by label
        return self.classes
    def Batcher(self, batch_size, path, distorted):
        filename = os.listdir(path)
        for i in range(len(filename)): filename[i] = os.path.join(path, filename[i])
        filename_queue = tf.train.string_input_producer(filename, shuffle=True, num_epochs=None)
        reader = tf.TFRecordReader()
        _, serialized_examples = reader.read(filename_queue)
        feature = tf.parse_single_example(serialized_examples, features={
            "image": tf.FixedLenFeature([], tf.string),
            "label": tf.FixedLenFeature([], tf.int64)
        })
        if self.enhance: image_size = [int(self.output_size[0]*1.3), int(self.output_size[1]*1.3), 3]  # with augmentation, the stored image is larger so it can be randomly cropped
        else: image_size = [self.output_size[0], self.output_size[1], 3]
        image = tf.decode_raw(feature["image"], tf.uint8)
        image = tf.reshape(image, image_size)
        image = tf.random_crop(image, [self.output_size[0], self.output_size[1], 3])
        if self.enhance and distorted:  # train and test sets both get the random crop, but only the training set gets the steps below
            image = tf.image.random_flip_left_right(image)
            image = tf.image.random_flip_up_down(image)
            image = tf.image.random_brightness(image, max_delta=0.3)
            image = tf.image.random_contrast(image, lower=0.7, upper=1.3)
            image = tf.image.random_saturation(image, lower=0.7, upper=1.3)
            image = tf.image.random_hue(image, max_delta=0.4)
        image = tf.image.per_image_standardization(image)
        label = tf.cast(feature["label"], tf.int32)
        images, labels = tf.train.shuffle_batch([image, label],
                                                num_threads=self.num_threads,
                                                batch_size=batch_size,
                                                min_after_dequeue=self.min_after_dequeue,
                                                capacity=self.min_after_dequeue + batch_size*3)
        labels = tf.one_hot(labels, len(self.classes))  # was hard-coded to 2 classes; use the actual class count
        return images, labels
    def train_batch(self, batch_size):
        file_path = os.path.join(self.tf_train_path, "data.TFRecord")
        if not os.path.exists(file_path):
            print("file: " + file_path + " cannot be found, ready to create it")
            self.CreatTFRecordData(self.train_data_path, self.tf_train_path)
        else:
            print("found file: " + file_path)
        return self.Batcher(batch_size, self.tf_train_path, True)

    def test_batch(self, batch_size):
        file_path = os.path.join(self.tf_test_path, "data.TFRecord")
        if not os.path.exists(file_path):
            print("file: " + file_path + " cannot be found, ready to create it")
            self.CreatTFRecordData(self.test_data_path, self.tf_test_path)
        else:
            print("found file: " + file_path)
        return self.Batcher(batch_size, self.tf_test_path, False)
    def detail(self):
        # print per-class example counts and record the dataset sizes
        self.num_examples_for_train = 0
        self.num_examples_for_test = 0
        print("dataset detail:")
        print("train_data:")
        for image_class in self.classes:
            class_path = os.path.join(self.train_data_path, image_class)
            print("class: %s num: %d" % (image_class, len(os.listdir(class_path))))
            self.num_examples_for_train += len(os.listdir(class_path))
        print("test_data:")
        for image_class in self.classes:
            class_path = os.path.join(self.test_data_path, image_class)
            print("class: %s num: %d" % (image_class, len(os.listdir(class_path))))
            self.num_examples_for_test += len(os.listdir(class_path))
    def CreatTFRecordData(self, data_path, save_path):
        print("extracting from " + os.path.join(data_path, "*.jpg"))
        image_list = []
        label_list = []
        image_total = 0
        if self.enhance:
            # store the images 1.3x larger so Batcher can randomly crop them
            shape = (int(self.output_size[0]*1.3), int(self.output_size[1]*1.3))
        else:
            shape = self.output_size
        for label, image_class in enumerate(self.classes):
            class_path = os.path.join(data_path, image_class)
            for image_name in os.listdir(class_path):
                image_path = os.path.join(class_path, image_name)
                image_list.append(image_path)
                label_list.append(label)
                image_total += 1
        for i in range(3): image_list, label_list = shuffle(image_list, label_list)
        writer = tf.python_io.TFRecordWriter(os.path.join(save_path, "data.TFRecord"))
        for i in range(image_total):
            if (i+1) % 2500 == 0: print("Create " + os.path.join(save_path, "data.TFRecord") + ": %.1f %%" % ((i+1)*100/image_total))
            # PIL's resize takes (width, height), while output_size is treated
            # as (height, width) when the record is decoded, hence the swap
            image = Image.open(image_list[i]).resize((shape[1], shape[0]), Image.BICUBIC)
            image = image.convert("RGB")  # unify the mode so every example is 3 bytes per pixel
            image = image.tobytes()
            example = tf.train.Example(features=tf.train.Features(feature={
                "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
                "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label_list[i]]))
            }))
            writer.write(example.SerializeToString())
        writer.close()
        print("finish")
'''
cifar10 = DataReader(data_path, output_size=(64,64), enhance=True)
cifar10.detail()                                   # per-class example counts
classes = cifar10.class_list()                     # same as cifar10.classes: class name for each label
train_batch = cifar10.train_batch(batch_size=100)  # training-data batch
test_batch = cifar10.test_batch(batch_size=100)    # test-data batch
'''
Displaying images:
def imshows(classes, images, labels, index, amount, predictions=None):
    # classes: array of class names
    # images: array of images
    # labels: array of labels
    # index, amount: show "amount" images starting at position "index"
    # predictions: prediction results (optional)
    fig = plt.gcf()
    fig.set_size_inches(10, 20)  # adjust the figure size to taste
    for i in range(amount):
        title = "lab:" + classes[np.argmax(labels[index+i])]
        if predictions is not None:
            title = title + " prd:" + classes[np.argmax(predictions[index+i])]  # was "name", an undefined variable
        ax = plt.subplot(5, 6, i+1)  # a 5x6 grid, so at most 30 images
        ax.set_title(title)
        ax.imshow(images[index+i])
    plt.show()
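One caveat of my own: Batcher applies tf.image.per_image_standardization, so the batched images come out as zero-mean floats and imshow will clip them. A minimal helper (my own addition, not part of the class) to rescale one image for display:

def rescale_for_show(image):
    # map a standardized float image back into [0, 1] for plt.imshow
    lo, hi = image.min(), image.max()
    return (image - lo) / (hi - lo + 1e-8)

Passing rescale_for_show(images[index+i]) to ax.imshow inside imshows avoids the washed-out look.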
Testing:
data_path = "D:\\log\\cifar10\\Image"
batch_size = 100
cifar10 = DataReader(data_path, output_size=(64,64), enhance=True)
cifar10.detail()
train_batch = cifar10.train_batch(batch_size=batch_size)
test_batch = cifar10.test_batch(batch_size=batch_size)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    for i in range(10):
        img, lab = sess.run(train_batch)
        imshows(cifar10.classes, img, lab, 0, 5)
    coord.request_stop()
    coord.join(threads)
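One more small sketch of my own: the counters that detail() fills in are handy for sizing the training loop.

steps_per_epoch = cifar10.num_examples_for_train // batch_size  # batches per full pass
for step in range(steps_per_epoch):
    img, lab = sess.run(train_batch)  # inside the Session block above
    # ...feed img and lab to a model here...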