我们训练文件夹的内部往往会存着成千上万的图片或文本等文件,这些文件被散列存着,这样不仅占用磁盘空间,并且再被一个个读取的时候会非常慢,繁琐。占用大量内存空间(有的大型数据不足以一次性加载)。此时我们TFRecord格式的文件存储形式会很合理的帮我们存储数据。TFRecord内部使用了“Protocol Buffer”二进制数据编码方案,它只占用一个内存块,只需要一次性加载一个二进制文件的方式即可,简单,快速,尤其对大型训练数据很友好。而且当我们的训练数据量比较大的时候,可以将数据分成多个TFRecord文件,来提高处理效率。
分别是TFRecord生成器以及样本Example模块。
writer = tf.python_io.TFRecordWriter(record_path)
writer.write(tf_example.SerializeToString())
writer.close()
这里面writer就是我们TFrecord生成器。接着我们就可以通过writer.write(tf_example.SerializeToString())
来生成我们所要的tfrecord文件了。这里需要注意的是我们TFRecord生成器在写完文件后需要关闭writer.close()
。这里tf_example.SerializeToString()
是将Example中的map压缩为二进制文件,更好的节省空间。接下来讲述tf_example是如何生成。
Example协议块
message Example {
Features features = 1;
};
message Features {
map<string, Feature> feature = 1;
};
message Feature {
oneof kind {
BytesList bytes_list = 1;
FloatList float_list = 2;
Int64List int64_list = 3;
}
};
我们可以看出上面的tf_example可以写入的数据形式有三种,分别是BytesList, FloatList以及Int64List的类型。
def int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
def bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
tf_example = tf.train.Example(
features=tf.train.Features(feature={
'image/encoded': bytes_feature(encoded_jpg),
'image/format': bytes_feature('jpg'.encode()),
'image/class/label': int64_feature(label),
'image/height': int64_feature(height),
'image/width': int64_feature(width)}))
(1)tf.train.Example(features = None)
这里的features是tf.train.Features类型的特征实例。
(2)tf.train.Features(feature = None)
这里的feature是以字典的形式存在,*key:要保存数据的名字 value:要保存的数据,但是格式必须符合tf.train.Feature实例要求。
以上参考:https://www.jianshu.com/p/b480e5fcb638
循环读取图片
import os
import tensorflow as tf
from PIL import Image
cwd = 'opt/pyproject/demo/TFRECORD/matlab//'
classes = {'noarm','onearm','run','twoarms'}
def create_record():
writer = tf.python_io.TFRecordWriter("dog_train.tfrecords")
for index, name in enumerate(classes):
class_path = cwd +"/"+ name+"/"
for img_name in os.listdir(class_path): #已经将四种不同类型的图片分在了四个文件夹内
img_path = class_path + img_name
img = Image.open(img_path)
img = img.resize((64, 64))
img_raw = img.tobytes() #将图片转化为原生bytes
print (index,img_raw)
example = tf.train.Example(
features=tf.train.Features(feature={
"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
}))
writer.write(example.SerializeToString())
writer.close()
def write_test():
writer = tf.python_io.TFRecordWriter('test.tfrecord')
image=Images.open(cwd+'noarm_1.jpeg')
#image=image.resize([500,500])
image_data=image.tobytes()
index=0
# 创建 Example 对象,并且将 Feature 一一对应填充进去。
example = tf.train.Example(features=tf.train.Features(feature={
'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
'image_data': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_data]))
}))
# 将 example 序列化成 string 类型,然后写入。
writer.write(example.SerializeToString())
writer.close()
write_test()
def _get_images_labels(input_file):
dataset=tf.data.TFRecordDataset(input_file)
dataset=dataset.map(_parse_record)
#dataset=dataset.prefetch(-1)
#dataset=dataset.repeat().batch(128)
iterator=dataset.make_one_shot_iterator()
images, labels=iterator.get_next()
return images, labels
def _parse_record(example_proto):
features = {
'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
'image_data': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_data]))
}
parsed_features=tf.parse_single_example(example_proto, features=features)
img=tf.decode_raw(parsed_features['image_data],out_type=uint8)
#img=tf.reshape(img,shape=[500,500,3])
# 如果前面使用了image.resize([500,500])这里就需要还原为[500,500,3]
img=tf.reshape(img, shape=[656,875,3])
label=parsed_features['label']
#这里label不需要reshape
label=tf.cast(label, tf.in32)
return img, label
with tf.Session() as sess:
image, label = sess.run(_get_images_labels('test.tfrecord'))
plt.figure()
plt.imshow(image)
plt.show()
参考:https://blog.csdn.net/briblue/article/details/80789608
import os
import tensorflow as tf
from PIL import Image
cwd = 'E:/train_data/picture_dog//'
classes = {'husky','jiwawa'}
#制作TFRecords数据
def create_record():
writer = tf.python_io.TFRecordWriter("dog_train.tfrecords")
for index, name in enumerate(classes):
class_path = cwd +"/"+ name+"/"
for img_name in os.listdir(class_path):
img_path = class_path + img_name
img = Image.open(img_path)
img = img.resize((64, 64))
img_raw = img.tobytes() #将图片转化为原生bytes
print (index,img_raw)
example = tf.train.Example(
features=tf.train.Features(feature={
"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
}))
writer.write(example.SerializeToString())
writer.close()
#-------------------------------------------------------------------------
#读取二进制数据
def read_and_decode(filename):
# 创建文件队列,不限读取的数量
filename_queue = tf.train.string_input_producer([filename])
# create a reader from file queue
reader = tf.TFRecordReader()
# reader从文件队列中读入一个序列化的样本
_, serialized_example = reader.read(filename_queue)
# get feature from serialized example
# 解析符号化的样本
features = tf.parse_single_example(
serialized_example,
features={
'label': tf.FixedLenFeature([], tf.int64),
'img_raw': tf.FixedLenFeature([], tf.string)
})
label = features['label']
img = features['img_raw']
img = tf.decode_raw(img, tf.uint8)
img = tf.reshape(img, [64, 64, 3])
#img = tf.cast(img, tf.float32) * (1. / 255) - 0.5
label = tf.cast(label, tf.int32)
return img, label
#--------------------------------------------------------------------------
#---------主程序----------------------------------------------------------
if __name__ == '__main__':
create_record()
batch = read_and_decode('dog_train.tfrecords')
init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
with tf.Session() as sess: #开始一个会话
sess.run(init_op)
coord=tf.train.Coordinator()
threads= tf.train.start_queue_runners(coord=coord)
for i in range(40):
example, lab = sess.run(batch)#在会话中取出image和label
img=Image.fromarray(example, 'RGB')#这里Image是之前提到的
img.save(cwd+'/'+str(i)+'_Label_'+str(lab)+'.jpg')#存下图片;注意cwd后边加上‘/’
print(example, lab)
coord.request_stop()
coord.join(threads)
sess.close()
参考:https://blog.csdn.net/ywx1832990/article/details/78462582
整体参考:
读取并训练: https://my.oschina.net/u/3800567/blog/1637798?from=singlemessage
参考: https://www.cnblogs.com/puheng/p/9576521.html
tenforflow 官方文档:
https://tensorflow.google.cn/guide/datasets#parsing_tfexample_protocol_buffer_messages