tf.train.batch的第一个参数 tensors 传入的是tenor列表或者字典,返回的是tensor列表或字典
注意在tfrecord文件中存储的是一序列,并没有形状 所以需要reshape( [ ] )
推荐:介绍tfrecord两种方法说的比较好 https://www.cnblogs.com/polly333/p/7489699.html
一 TFRecords
TFRecords是一种将图像数据和标签放在一起的二进制文件,能够快速的复制、移动、读取、存储等操作,优点:可以更好的利用内存空间,缺点:生成过程比较耗费时间,特别是数据量比较大时;
TFRecords生成:用 tf.train.Eample的缓冲协议 ( protocol buffer ) 包含字段Features 获取数据;接下来,将其填充Example的协议缓冲区( protocol buffer );接着将缓冲区序列化为字符串,最后通过tf.python_io.TFRecordWriter将序列化后字符串写入tfrecords文件中。
TFRecords解码:获取 TFRecords数据后,利用 tf.TFRecordReader和tf.sparse_single_example解码器,将Example中的缓冲协议中的内容解析为Tensor张量。
reference: https://blog.csdn.net/u010911921/article/details/70991194
二 生成TFRecords
硬盘数据集---> image_list(图片路径) label_list(样本标签)--->生成.TFRecords(利用tf.train.Example
和tf.train.Features
这两个函数来构建一个example
然后将其序列化到文件中)--->.TFRecords---解码decode(tf,TFRecordReader()产生的reader来读取,接着将其读取到的内容,用tf.sparse_single_example函数,用tf.decode.raw() ---> )--->
首先从数据集中,把图片路径放置到一个image_list,样本标签放置到label_list中。
# --*-- encoding:utf-8 --*--
import tensorflow as tf
import numpy as np
import os
import matplotlib.pyplot as plt
import skimage.io as io
def get_file(file_dir):
"""
get full image directory and correspond labels
:param file_dir:
:return:
"""
images =[]
temp =[]
for root ,sub_folders,files in os.walk(file_dir):
#image directories
for name in files:
images.append(os.path.join(root,name))
#get 10 sub-folder names
for name in sub_folders:
temp.append(os.path.join(root,name))
labels =[]
for one_folder in temp:
n_img = len(os.listdir(one_folder))
letter = one_folder.split('/')[-1]
if letter =='A':
labels = np.append(labels,n_img*[1])
elif letter =="B":
labels = np.append(labels,n_img*[2])
elif letter =='C':
labels = np.append(labels,n_img*[3])
elif letter =="D":
labels = np.append(labels,n_img*[4])
elif letter =="E":
labels = np.append(labels,n_img*[5])
elif letter =="F":
labels = np.append(labels,n_img*[6])
elif letter =="G":
labels = np.append(labels,n_img*[7])
elif letter =="H":
labels = np.append(labels,n_img*[8])
elif letter =="I":
labels =np.append(labels,n_img*[9])
else:
labels = np.append(labels,n_img*[10])
#shuffle
temp = np.array([images,labels])
temp = temp.transpose()
np.random.shuffle(temp)
image_list = list(temp[:,0])
label_list = list(temp[:,1])
label_list = [int(float(i)) for i in label_list]
return image_list,label_list
其次,当取得image_list
和label_list
以后,读取图片数据,然后利用tf.train.Example
和tf.train.Features
这两个函数来构建一个example
然后将其序列化到文件中(.TFRecord)。基本上就是一个Example
中包含Features
,Features
中包含Feature
字典,Feature
字典是由float_list
、bytes_List
或int64_list
等构成。
# 生成TFRecords---灰度图
#将label转换成int64类型,为了构建tf.train.Feature
def int64_feature(value):
if not isinstance(value,list):
value = [value]
return tf.train.Feature(int64_list = tf.train.Int64List(value=value))
#将image转换成bytes类型,同样也是为了构建tf.train.Feature
def bytes_feature(value):
return tf.train.Feature(bytes_list= tf.train.BytesList(value=[value]))
def convert_to_tfrecord(images,labels,save_dir,name):
"""
convert all images and labels to one tfrecord file
:param images:
:param labels:
:param save_dir:
:param name:
:return:
"""
filename = os.path.join(save_dir,name+".tfrecords")
n_samples = len(labels)
if np.shape(images)[0] != n_samples:
raise ValueError('Image size %d does not '
'match label size %d'%(images.shape[0],n_samples))
#wait some time
writer = tf.python_io.TFRecordWriter(filename)
print("\n Transform start....")
for i in np.arange(0,n_samples):
try:
image = io.imread(images[i])
image_raw = image.tostring()
label= int(labels[i])
example = tf.train.Example(features =tf.train.Features(feature= {'label':int64_feature(label),
"image_raw":bytes_feature(image_raw)}))
writer.write(example.SerializeToString())
except IOError as e:
print("could not read :",images[i])
print("error:%s"%e)
print('Skip it')
writer.close()
print("Transform done!")
# 生成TFRecord---彩色图
def data_to_tfrecord(images, labels, filename): # images中存储的是所有图像路径的一个列表
""" Save data into TFRecord """ # labels是images中每个图像对应的标签
if os.path.isfile(filename): # filename是tfrecord文件名称
print("%s exists" % filename)
return
print("Converting data into %s ..." % filename)
writer = tf.python_io.TFRecordWriter(filename)
for index, img_file in zip(labels, images):
img1 = Image.open(img_file) # 通过PIL包中的Images函数读取、解码图片
width, height = img1.size # 获取图像的宽、高参数
img_raw = img1.tobytes() # 将图像转换成二进制序列
label = int(index) # 图片对应的标签
example = tf.train.Example(
features=tf.train.Features(
feature={
'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label])), # 保存标签
'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])), # 保存二进制序列
'img_width': tf.train.Feature(int64_list=tf.train.Int64List(value=[width])), # 保存图像的宽度
'img_height': tf.train.Feature(int64_list=tf.train.Int64List(value=[height])) # 保存图像的高
}
)
)
writer.write(example.SerializeToString()) # Serialize To String
writer.close()
通过上述程序完成TFRecords文件, 但这个过程会花费较长时间。
TFRecords解码
读取一个文件采用队列(queue)形式;首先,生成一个文件名层队列,然后利用tf,TFRecordReader()产生的reader来读取,接着将其读取到的内容,用tf.sparse_single_example函数将image_raw及label读取以及分离开来,为以后操作做准备。
# 灰度图解码decode ------ tf.reshape(img, [28,28])
def read_and_decode(tfrecords_file,batch_size):
filename_queue = tf.train.string_input_producer([tfrecords_file])
reader = tf.TFRecordReader()
_,serialized_example = reader.read(filename_queue)
img_features = tf.parse_single_example(serialized_example,
features={"label":tf.FixedLenFeature([],tf.int64),
"image_raw":tf.FixedLenFeature([],tf.string),})
image = tf.decode_raw(img_features['image_raw'],tf.uint8)
################################################################
#
#put dataaugmentation here
################################################################
image = tf.reshape(image,[28,28]) # 待修改
label = tf.cast(img_features['label'],tf.int32)
image_batch, label_batch = tf.train.batch([image,label],
batch_size = batch_size,
num_threads = 64,
capacity=2000)
return image_batch,tf.reshape(label_batch,[batch_size])
# 彩色图像解码decode ------ tf.reshape(img, [height, width, 3])
import numpy as np
import tensorflow as tf
import tensorlayer as tl
def read_and_decode(filename):
""" Return tensor to read from TFRecord """
filename_queue = tf.train.string_input_producer([filename])
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(
serialized_example, features={
'label': tf.FixedLenFeature([], tf.int64), # 从tfrecord文件中读取各种信息
'img_raw': tf.FixedLenFeature([], tf.string),
'img_width': tf.FixedLenFeature([], tf.int64),
'img_height': tf.FixedLenFeature([], tf.int64)
}
)
# You can do more image distortion here for training data
width = tf.cast(features['img_width'], tf.int32) # 转型
height = tf.cast(features['img_height'], tf.int32)
img = tf.decode_raw(features['img_raw'], tf.uint8) # 从二进制文件转成uint8
img = tf.reshape(img, [height, width, 3]) # 对图像进行reshape,注意在tfrecord文件中存储的是一序列,并没有形状
img = tf.image.resize_images(img, [32, 32]) # 将图像统一到同一尺寸
# img = tf.cast(img, tf.float32) #* (1. / 255) - 0.5
label = tf.cast(features['label'], tf.int32)
return img, label
# Example to visualize data
img, label = read_and_decode("train.tfrecord")
img_batch, label_batch = tf.train.shuffle_batch([img, label],
batch_size=4,
capacity=5000,
min_after_dequeue=100,
num_threads=1)
print("img_batch : %s" % img_batch._shape)
print("label_batch : %s" % label_batch._shape)
init = tf.initialize_all_variables()
with tf.Session() as sess:
sess.run(init)
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
for i in range(3): # number of mini-batch (step)
print("Step %d" % i)
val, l = sess.run([img_batch, label_batch])
# exit()
print(val.shape, l)
tl.visualize.images2d(val, second=1, saveable=False, name='batch'+str(i), dtype=np.uint8, fig_idx=2020121)
coord.request_stop()
coord.join(threads)
reference : https://blog.csdn.net/qq_37541097/article/details/80218111
显示TFRecord文件中的图像
由于tf.train()函数在graph中增加了tf.train.QueueRunner类(在线程中运行线程中的队列数据),tf.train.start_queue_runner启动所有graph中的线程;用tf.train.Coordinator来管理线程(启动多少线程 何时终止线程...)
# initialize global & local variables
init_op = tf.group(tf.global_variables_initializer(),tf.local_variables_initializer())
sess.run(init_op)
# create a coordinate and run queue runner objects
# 启动多线程处理数据
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
for batch_index in range(3):
batch_images, batch_labels = sess.run([images, labels])
for i in range(10):
plt.imshow(batch_images[i, ...])
plt.show()
print "Current image label is: ", batch_lables[i]
# close threads 结束线程
coord.request_stop()
coord.join(threads)
sess.close()
如何显示xxx.tfrecords文件中的图片
tfrecords_file = 'E:/Anaconda3/tensorflow//dataset/train.tfrecords'
Batch_size = 6
image_batch, label_batch = read_and_decode(tfrecords_file,Batch_size)
with tf.Session() as sess:
i = 0
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
try:
while not coord.should_stop() and i<1:
# just plot one batch size
image, label = sess.run([image_batch, label_batch])
for j in np.arange(4):
print('label: %d' % label[j])
plt.imshow(image[j,:,:,:])
plt.show()
i+=1
except tf.errors.OutOfRangeError:
print('done!')
finally:
coord.request_stop()
coord.join(threads)
扩展:
tf.train.slice_input_prodcer 和 tf.train.batch深度解析
tensorflow中为了充分利用GPU,减少GPU等待数据的空闲时间,使用了两个线程分别执行数据读入和数据计算。具体来说就是使用一个线程源源不断的将硬盘中的图片数据读入到一个内存队列中,另一个线程负责计算任务,所需数据直接从内存队列中获取。
tf在内存队列之前,还设立了一个文件名队列,文件名队列存放的是参与训练的文件名,要训练 N个epoch,则文件名队列中就含有N个批次的所有文件名。 示例图如下:
在N个epoch的文件名最后是一个结束标志,当tf读到这个结束标志的时候,会抛出一个 OutofRange 的异常,外部捕获到这个异常之后就可以结束程序了。而创建tf的文件名队列就需要使用到 tf.train.slice_input_producer 函数。
tf.train.slice_input_producer是一个tensor生成器,作用是按照设定,每次从一个tensor列表中按顺序或者随机(shuffle),抽取出一个tensor放入文件名队列。
slice_input_producer(tensor_list, num_epochs=None, shuffle=True, seed=None,
capacity=32, shared_name=None, name=None)
tf.train.slice_input_producer定义了样本放入文件名队列的方式,包括迭代次数,是否乱序等,要真正将文件放入文件名队列,还需要调用tf.train.start_queue_runners 函数来启动执行文件名队列填充的线程,之后计算单元才可以把数据读出来,否则文件名队列为空的,计算单元就会处于一直等待状态,导致系统阻塞。
tf.train.batch是一个tensor队列生成器,作用是按照给定的tensor顺序,把batch_size个tensor推送到内存队列,作为训练一个batch的数据,等待tensor出队执行计算。
batch(tensors, batch_size, num_threads=1, capacity=32,
enqueue_many=False, shapes=None, dynamic_pad=False,
allow_smaller_final_batch=False, shared_name=None, name=None)
reference:
tf.train.slice_input_producer 很好的解释 https://blog.csdn.net/weixin_42052460/article/details/80714379
两种TFRecord文件,并读取 https://blog.csdn.net/qq_37541097/article/details/80218111
1). 利用常用图像处理库读取图像并解码,转换成二进制文件进行存储,网络上找到的基本上都是这种方式
img = tf.decode_raw(features['img_raw'], tf.uint8) # 从二进制文件转成uint8
img = tf.reshape(img, [height, width, 3]) # 对图像进行reshape,注意在tfrecord文件中存储的是一序列,并没有形状
2). 利用tf.gfile.FastGFile读取图像信息(貌似并没有解码),转换成二进制文件存储。
img = tf.image.decode_jpeg(features['img_raw']) # 与方式一的不同点在于需要用decode_jpeg解码
生成多个.TFRecord文件 --- multi
https://blog.csdn.net/fu6543210/article/details/80263425
解码decode后进行resize
https://blog.csdn.net/qq_37541097/article/details/80218111
TFRecords是tensorflow官网提供的一种二进制文件(),它能方便的进行数据复制 移动 和更好的利用内存,同时不需要单独的标签文件(在读取数据文件是自动添加标签,下面有介绍);在训练时,使用TFRecords中数据的流程:首先生成xxx.tfrecord文件,接着使用input pipeline读取xxx.tfrecords文件/其他支持格式,then读取并解码数据,随机乱序(shuffle),生成文件序列(batch);最后输入到模型中。
如果有一串jpg图片地址和相应的标签:Images
和 Labels
1. 生成TFRecords
存入TFRecords文件需要数据先存入名为example的protocol buffer,然后将其serialize成为string才能写入。example中包含features,用于描述数据类型:bytes,float,int64;具体来说,TFRecords文件中的数据都是通过tf.train.Example Protocol Buffer的格式存储的。以下的代码给出了tf.train.Example的定义。
message Example {
Features features = 1;
};
message Features {
map feature = 1;
};
message Feature {
oneof kind {
BytesList bytes_list = 1;
FloatList float_list = 2;
Int64List int64_list = 3;
}
};
# -*- coding: utf-8 -*-
import os
import tensorflow as tf
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
cwd = "E:/Anaconda3/tensorflow/Dataset/data/"
classes = {'cats', 'dogs'} #预先自己定义的类别
#将数据转化TFRecord文件对应的属性
def _int64_feature(value):
return tf.train.Feature(int64_list = tf.train.Int64List(value = [value]))
def _bytes_feature(value):
return tf.train.Feature(bytes_list = tf.train.BytesList(value = [value]))
# 开始将数据写入TFRecord文件(xxx.tfrecord)
train_filename = 'tensorflow/train.tfrecords' # 输出文件地址
# 创建一个writer来写TFRecords文件(写TFRecords <==> 输出TFRecords文件)
writer = tf.python_io.TFRecordWriter(train_filename) #输出成tfrecord文件
for index, name in enumerate(classes): # 从classes中自动获取类别 (label)
class_path = cwd + name + '//'
for img_name in os.listdir(class_path):
img_path = class_path + img_name #每张picture的绝对地址
img = Image.open(img_path)
img = img.resize((640, 320))
img_raw = img.tobytes() #将图片转化为二进制格式
# 创建一个属性(feature)
example = tf.train.Example(features = tf.train.Features(feature = {
"label":_int64_feature(index),
"img_raw":_bytes_feature(img_raw),
}))
# 将上面的example protocol buffer 写入文件
writer.write(example.SerializeToString()) #序列化为字符串
writer.close()
输入: 数据文件路径 path
输出: xxx.tfrecords文件
reference:https://blog.csdn.net/hjxu2016/article/details/76165559
2. TFRecord 文件---解码decode
(1). 用tf.train.string_input_producer 读取tfrecords文件(xxx.tfrecords)的list建立文件名队列(FIFO序列),同时,可以申明num_epoches和shuffle参数表示需要读取数据的次数以及时候将tfrecords文件读入顺序打乱;结果:图像路径list
(2). 定义TFRecordReader读取(1)中的序列(图像路径list)返回下一个record;结果:serialize example和feature字典
(3). 用tf.parse_string_example对读取的TFRecords文件进行解码,抽取((2) serialize example和feature字典)中,返回feature对应的值,此时对应的值都是string,需要经过tf.decode(...) 和 tf.cast(...)等操作,将string类型的图像数据还原原始图像;同时也可以进行一些preprocessing操作;
(4). 利用tf.train.shuffle_batch(...)和tf.train.batch(...)将(3)中还原原始图像生成batch图像序列
#读取文件
def read_and_decode(filename,batch_size):
#根据文件名生成一个队列
filename_queue = tf.train.string_input_producer([filename])
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue) #返回文件名和文件
features = tf.parse_single_example(serialized_example,
features={
'label': tf.FixedLenFeature([], tf.int64),
'img_raw' : tf.FixedLenFeature([], tf.string),
})
img = tf.decode_raw(features['img_raw'], tf.uint8)
img = tf.reshape(img, [300, 300, 3]) #图像归一化大小
# img = tf.cast(img, tf.float32) * (1. / 255) - 0.5 #图像减去均值处理,根据自己的需要决定要不要加上
label = tf.cast(features['label'], tf.int32)
#特殊处理,去数据的batch,如果不要对数据做batch处理,也可以把下面这部分不放在函数里
img_batch, label_batch = tf.train.shuffle_batch([img, label],
batch_size= batch_size,
num_threads=64,
capacity=200,
min_after_dequeue=150)
return img_batch, tf.reshape(label_batch,[batch_size])
在读取到队列中后,数据输出之前还要作解码的操作从,可以使用tf.TFRecordReader的tf.parse_single_example解析器。这个操作可以将Example协议内存块(protocol buffer)解析为张量;
输入:XXX.tfrecords batch_size
输出: image_batch label_batch
3. 扩展
由于tf.train()函数在graph中增加了tf.train.QueueRunner类(在线程中运行线程中的队列数据),tf.train.start_queue_runner启动所有graph中的线程;用tf.train.Coordinator来管理线程(启动多少线程 何时终止线程...)
# initialize global & local variables
init_op = tf.group(tf.global_variables_initializer(),tf.local_variables_initializer())
sess.run(init_op)
# create a coordinate and run queue runner objects
# 启动多线程处理数据
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
for batch_index in range(3):
batch_images, batch_labels = sess.run([images, labels])
for i in range(10):
plt.imshow(batch_images[i, ...])
plt.show()
print "Current image label is: ", batch_lables[i]
# close threads 结束线程
coord.request_stop()
coord.join(threads)
sess.close()
4. 如何显示xxx.tfrecords文件中的图片
tfrecords_file = 'E:/Anaconda3/tensorflow//dataset/train.tfrecords'
Batch_size = 6
image_batch, label_batch = read_and_decode(tfrecords_file,Batch_size)
with tf.Session() as sess:
i = 0
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
try:
while not coord.should_stop() and i<1:
# just plot one batch size
image, label = sess.run([image_batch, label_batch])
for j in np.arange(4):
print('label: %d' % label[j])
plt.imshow(image[j,:,:,:])
plt.show()
i+=1
except tf.errors.OutOfRangeError:
print('done!')
finally:
coord.request_stop()
coord.join(threads)
batch_size这里可以大家任意设定,显示几幅图片都可以,这里设置为6 同时i 控制显示张数
5. 完整代码
# -*- coding: utf-8 -*-
import os
import tensorflow as tf
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
cwd = "E:/Anaconda3/tensorflow/dataset/data/"
classes = {'cats', 'dogs'}
writer = tf.python_io.TFRecordWriter('train.tfrecords')
def _int64_feature(value):
return tf.train.Feature(int64_list = tf.train.Int64List(value = [value]))
def _bytes_feature(value):
return tf.train.Feature(bytes_list = tf.train.BytesList(value = [value]))
for index, name in enumerate(classes):
class_path = cwd + name + '//'
for img_name in os.listdir(class_path):
img_path = class_path + img_name #每张图片的绝对地址
img = Image.open(img_path)
img = img.resize((640, 320))
img_raw = img.tobytes() #将图片转化为二进制格式
example = tf.train.Example(features = tf.train.Features(feature = {
"label":_int64_feature(index),
"img_raw":_bytes_feature(img_raw),
}))
writer.write(example.SerializeToString()) #序列化为字符串
writer.close()
def read_and_decode(filename, batch_size): # read train.tfrecords
filename_queue = tf.train.string_input_producer([filename])# create a queue
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)#return file_name and file
features = tf.parse_single_example(serialized_example,
features={
'label': tf.FixedLenFeature([], tf.int64),
'img_raw' :tf.FixedLenFeature([],tf.string),
})#return image and label
img = tf.decode_raw(features['img_raw'], tf.uint8)
img = tf.reshape(img, [208, 208, 3]) #reshape image to 512*80*3
label = tf.cast(features['label'], tf.int32) #throw label tensor
img_batch, label_batch = tf.train.shuffle_batch([img, label],
batch_size= batch_size,
num_threads=64,
capacity=2000,
min_after_dequeue=1500,
)
return img_batch, tf.reshape(label_batch,[batch_size])
tfrecords_file = 'D:/Anaconda3/tensorflow/dataset/train.tfrecords'
Batch_size = 6
image_batch, label_batch = read_and_decode(tfrecords_file, Batch_size)
with tf.Session() as sess:
i = 0
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
try:
while not coord.should_stop() and i<1:
# just plot one batch size
image, label = sess.run([image_batch, label_batch])
for j in np.arange(BATCH_SIZE):
print('label: %d' % label[j])
plt.imshow(image[j,:,:,:])
plt.show()
i+=1
except tf.errors.OutOfRangeError:
print('done!')
finally:
coord.request_stop()
coord.join(threads)
6. 参考文献
1. https://blog.csdn.net/u012222949/article/details/72875281 有imageFile 和 labelFile, 将imageFile和 labelFile分成train_set test_set
2. https://blog.csdn.net/wiinter_fdd/article/details/72835939 imageFile_train + class{} 类别自动生成 + imageFile_test
3. https://blog.csdn.net/gybheroin/article/details/79800679 同上
4. http://www.cnblogs.com/arkenstone/p/7507261.html 结构特别清晰
5. https://www.cnblogs.com/Charles-Wan/p/6197019.html 读取数据分类清晰