tfecord文件中的数据是通过tf.train.Example Protocol Buffer的格式存储的,下面是tf.train.Example的定义:
message Example {
Features features = 1;
};
message Features{
map featrue = 1;
};
message Feature{
oneof kind{
BytesList bytes_list = 1;
FloatList float_list = 2;
Int64List int64_list = 3;
}
};
从上述代码可以看到,ft.train.Example 的数据结构相对简洁。tf.train.Example中包含了一个从属性名称到取值的字典,其中属性名称为一个字符串,属性的取值可以为字符串(BytesList ),实数列表(FloatList )或整数列表(Int64List )。例如我们可以将解码前的图片作为字符串,图像对应的类别标号作为整数列表。
先给出数据和程序的链接
数据集 : YaleB_dataset
处理程序 : Batch Generator.py
使用queue读取图片数据方法的大致思路分为三步:
1、根据数据集的具体存储情况生成一个txt清单,清单上记载了每一张图片的存储地址还有一些相关信息(如标签、大小之类的)
2、根据第一步的清单记录,读取数据和信息,并将这些数据和信息按照一定的格式写成Tensorflow的专用文件格式(.tfrecords)
3、从.tfrecords文件中批量的读取数据供给模型使用
具体情况如下:
这里第一张图片的的Class01表示的是第一个类别,00000表示的是第一个类别里的第一张,生成清单的程序如下:
##相关库函数导入
import os
import cv2 as cv
import tensorflow as tf
from PIL import Image
import matplotlib.pyplot as plt
def getTrianList():
root_dir = "/Users/zhuxiaoxiansheng/Desktop/doc/SICA_data/YaleB" #数据存储文件夹地址
with open('/Users/zhuxiaoxiansheng/Desktop'+"/Yaledata.txt","w") as f: #txt文件生成地址
for file in os.listdir(root_dir):
if len(file) == 23: #图片名长为23个字节,避免读入其他的文件
f.write(root_dir+'/'+file+" "+ file[11:13] +"\n") #file[11:13]表示类别编号
生成的清单文件是这样的
在得到txt清单文件以后,根据这份文件就可以进入流程式的步骤了,首先我们需要生成.tfrecords文件,代码如下
def load_file(example_list_file): #从清单中读取地址和类别编号,这里的输入是清单存储地址
lines = np.genfromtxt(example_list_file,delimiter=" ",dtype=[('col1', 'S120'), ('col2', 'i8')])
examples = []
labels = []
for example,label in lines:
examples.append(example)
labels.append(label)
return np.asarray(examples),np.asarray(labels),len(lines)
def trans2tfRecord(trainFile,output_dir): #生成tfrecords文件
_examples,_labels,examples_num = load_file(trainFile)
filename = output_dir + '.tfrecords'
writer = tf.python_io.TFRecordWriter(filename)
for i,[example,label] in enumerate(zip(_examples,_labels)):
example = example.decode("UTF-8")
image = cv.imread(example)
image = cv.resize(image,(192,168)) #这里的格式需要注意,一定要尽量保证图片的大小一致
image_raw = image.tostring() #将图片矩阵转化为字符串格式
example = tf.train.Example(features=tf.train.Features(feature={
'image_raw':tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_raw])),
'label':tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
}))
writer.write(example.SerializeToString())
writer.close() #写入完成,关闭指针
return filename #返回文件地址
这里生成的是.tfrecords不好打开,就不展示了
设置从tfrecords文件中读取文件方式的函数如下:
def read_tfRecord(file_tfRecord): #输入是.tfrecords文件地址
queue = tf.train.string_input_producer([file_tfRecord])
reader = tf.TFRecordReader()
_,serialized_example = reader.read(queue)
features = tf.parse_single_example(
serialized_example,
features={
'image_raw':tf.FixedLenFeature([], tf.string),
'label':tf.FixedLenFeature([], tf.int64)
}
)
image = tf.decode_raw(features['image_raw'],tf.uint8)
image = tf.reshape(image,[192,168,3])
image = tf.cast(image, tf.float32)
image = tf.image.per_image_standardization(image)
label = tf.cast(features['label'], tf.int64) 这里设置了读取信息的格式
return image,label
上面就是主要的代码,其中特别要注意的就是以下两句,非常重要:
coord=tf.train.Coordinator() #创建一个协调器,管理线程
threads=tf.train.start_queue_runners(coord=coord) #启动QueueRunner, 此时文件名队列已经进队
这两句实现的功能就是创建线程并使用QueueRunner对象来提取数据。简单来说:使用tf.train函数添加QueueRunner到tensorflow中。在运行任何训练步骤之前,需要调用tf.train.start_queue_runners函数,否则tensorflow将一直挂起。
tf.train.start_queue_runners 这个函数将会启动输入管道的线程,填充样本到队列中,以便出队操作可以从队列中拿到样本。这种情况下最好配合使用一个tf.train.Coordinator,这样可以在发生错误的情况下正确地关闭这些线程。如果你对训练迭代数做了限制,那么需要使用一个训练迭代数计数器,并且需要被初始化。if __name__ == '__main__':
getTrianList()
dataroad = "/Users/zhuxiaoxiansheng/Desktop/Yaledata.txt"
outputdir = "/Users/zhuxiaoxiansheng/Desktop/Yaledata"
trainroad = trans2tfRecord(dataroad,outputdir)
traindata,trainlabel = read_tfRecord(trainroad)
image_batch,label_batch = tf.train.shuffle_batch([traindata,trainlabel],
batch_size=100,capacity=2000,min_after_dequeue = 1000)
with tf.Session() as sess:
sess.run(tf.local_variables_initializer())
sess.run(tf.global_variables_initializer())
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess,coord = coord)
train_steps = 10
try:
while not coord.should_stop(): # 如果线程应该停止则返回True
example,label = sess.run([image_batch,label_batch])
print(example.shape,label)
train_steps -= 1
print(train_steps)
if train_steps <= 0:
coord.request_stop() # 请求该线程停止
except tf.errors.OutOfRangeError:
print ('Done training -- epoch limit reached')
finally:
# When done, ask the threads to stop. 请求该线程停止
coord.request_stop()
# And wait for them to actually do it. 等待被指定的线程终止
coord.join(threads)
如果成功的话会有下面的输出(输出结果就截自己的图吧):