训练数据转为tfrecord

转载自guoyunfei20专栏 《将voc数据集转换成.tfrecord格式供tensorflow训练用》
https://blog.csdn.net/guoyunfei20/article/details/80626040

把现成的代码捋了一遍。主要是因为我想训练自己的模型，而我的数据格式和分类类别都和VOC不一样，所以先把VOC的转换流程捋顺，为以后改成适合自己数据的版本做准备，在此记录一下。整个转换只用一个文件就能搞定（现在很多代码都拆成多个文件、分散在各处，读起来比较费劲）。


  
    
    
    
    
  1. #coding=utf-8
  2. import os
  3. import sys
  4. import random
  5. import numpy as np
  6. import tensorflow as tf
  7. # process a xml file
  8. import xml.etree.ElementTree as ET
  # Sub-directories of a VOC<year> folder holding the XML annotations and
  # the JPEG images respectively.
  DIRECTORY_ANNOTATIONS = 'Annotations/'
  DIRECTORY_IMAGES = 'JPEGImages/'
  # Seed used by run() when shuffling is requested, so the order is
  # reproducible between runs.
  RANDOM_SEED = 4242
  # Maximum number of examples written into one .tfrecord shard by run().
  SAMPLES_PER_FILES = 2000
  # Class name -> (integer label id, coarse category).  Id 0 is the
  # background entry; the 20 VOC classes follow in alphabetical order, so
  # each id is simply the position of the class in the list below.
  VOC_LABELS = {
      name: (label_id, category)
      for label_id, (name, category) in enumerate([
          ('none', 'Background'),
          ('aeroplane', 'Vehicle'),
          ('bicycle', 'Vehicle'),
          ('bird', 'Animal'),
          ('boat', 'Vehicle'),
          ('bottle', 'Indoor'),
          ('bus', 'Vehicle'),
          ('car', 'Vehicle'),
          ('cat', 'Animal'),
          ('chair', 'Indoor'),
          ('cow', 'Animal'),
          ('diningtable', 'Indoor'),
          ('dog', 'Animal'),
          ('horse', 'Animal'),
          ('motorbike', 'Vehicle'),
          ('person', 'Person'),
          ('pottedplant', 'Indoor'),
          ('sheep', 'Animal'),
          ('sofa', 'Indoor'),
          ('train', 'Vehicle'),
          ('tvmonitor', 'Indoor'),
      ])
  }
  def int64_feature(values):
      """Wrap an integer or a sequence of integers in a ``tf.train.Feature``.

      Args:
          values: a single integer, or a list/tuple of integers.

      Returns:
          A ``tf.train.Feature`` whose ``int64_list`` holds ``values``.
      """
      # A bare scalar is promoted to a one-element list; lists and tuples
      # are passed through unchanged.
      payload = values if isinstance(values, (tuple, list)) else [values]
      return tf.train.Feature(int64_list=tf.train.Int64List(value=payload))
  def float_feature(value):
      """Wrap a float or a sequence of floats in a ``tf.train.Feature``.

      Args:
          value: a single float, or a list/tuple of floats.

      Returns:
          A ``tf.train.Feature`` whose ``float_list`` holds ``value``.
      """
      # Treat tuples as sequences too, matching int64_feature; otherwise a
      # tuple would be nested inside a one-element list, which FloatList
      # cannot accept.
      if not isinstance(value, (tuple, list)):
          value = [value]
      return tf.train.Feature(float_list=tf.train.FloatList(value=value))
  def bytes_feature(value):
      """Wrap a bytes value or a sequence of them in a ``tf.train.Feature``.

      Args:
          value: a single ``bytes`` object, or a list/tuple of them.

      Returns:
          A ``tf.train.Feature`` whose ``bytes_list`` holds ``value``.
      """
      # Treat tuples as sequences too, matching int64_feature; otherwise a
      # tuple would be nested inside a one-element list, which BytesList
      # cannot accept.
      if not isinstance(value, (tuple, list)):
          value = [value]
      return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))
  # Split names accepted on the command line; each selects the list file
  # VOC<year>/ImageSets/Main/<split>.txt that run() reads.
  SPLIT_MAP = 'train val trainval'.split()
  def _read_int_flag(obj, tag):
      """Return the integer text of child ``tag`` of ``obj``, or 0.

      0 is returned when the child element is missing or has no text.
      """
      node = obj.find(tag)
      # Compare with None explicitly: an Element with no sub-elements is
      # falsy, so ``if obj.find(tag):`` is False even when the tag exists.
      if node is None or not (node.text or '').strip():
          return 0
      return int(node.text)


  def _process_image(directory, name):
      """Read one VOC image and its XML annotation.

      Args:
          directory: path of a VOC<year> folder, i.e. the folder containing
              ``JPEGImages/`` and ``Annotations/``.
          name: image id without extension (one line of the split file).

      Returns:
          A tuple ``(image_data, shape, bboxes, labels, labels_text,
          difficult, truncated)``:
            image_data: raw bytes of ``JPEGImages/<name>.jpg``.
            shape: ``[height, width, depth]`` from the XML ``<size>`` node.
            bboxes: list of ``(ymin, xmin, ymax, xmax)`` tuples, y values
                divided by the image height and x values by the width.
            labels: list of integer class ids taken from ``VOC_LABELS``.
            labels_text: list of ASCII-encoded class names.
            difficult: list of integer ``<difficult>`` values (0 if absent).
            truncated: list of integer ``<truncated>`` values (0 if absent).

      Raises:
          KeyError: if an object's class name is not in ``VOC_LABELS``.
      """
      # Read the encoded JPEG as raw bytes.  Binary mode is required: text
      # mode ('r') would try to decode the image bytes as text.
      filename = os.path.join(directory, DIRECTORY_IMAGES, name + '.jpg')
      with tf.gfile.GFile(filename, 'rb') as image_file:
          image_data = image_file.read()

      # Parse the XML annotation file.
      filename = os.path.join(directory, DIRECTORY_ANNOTATIONS, name + '.xml')
      root = ET.parse(filename).getroot()

      # Image shape as [height, width, depth].
      size = root.find('size')
      shape = [int(size.find('height').text),
               int(size.find('width').text),
               int(size.find('depth').text)]

      # Collect the per-object annotations.
      bboxes = []
      labels = []
      labels_text = []
      difficult = []
      truncated = []
      for obj in root.findall('object'):
          # Strip so stray whitespace in <name> does not break the lookup.
          label = obj.find('name').text.strip()
          labels.append(int(VOC_LABELS[label][0]))
          labels_text.append(label.encode('ascii'))
          difficult.append(_read_int_flag(obj, 'difficult'))
          truncated.append(_read_int_flag(obj, 'truncated'))
          bbox = obj.find('bndbox')
          # Normalize pixel coordinates by image height (y) and width (x).
          bboxes.append((float(bbox.find('ymin').text) / shape[0],
                         float(bbox.find('xmin').text) / shape[1],
                         float(bbox.find('ymax').text) / shape[0],
                         float(bbox.find('xmax').text) / shape[1]))
      return image_data, shape, bboxes, labels, labels_text, difficult, truncated
  def _convert_to_example(image_data, labels, labels_text, bboxes, shape, difficult, truncated):
      """Pack one image and its object annotations into a ``tf.train.Example``.

      Args:
          image_data: bytes of the JPEG-encoded image.
          labels: list of integer class ids, one per object.
          labels_text: list of ASCII-encoded class names, one per object.
          bboxes: list of 4-element boxes in ``(ymin, xmin, ymax, xmax)``
              order, one per object.  ``_process_image`` produces them with
              y divided by the image height and x by the image width.
          shape: ``[height, width, channels]`` of the image in pixels.
          difficult: list of integer "difficult" flags, one per object.
          truncated: list of integer "truncated" flags, one per object.

      Returns:
          A ``tf.train.Example`` holding the ``image/*`` size/encoding
          features and the per-object ``image/object/bbox/*`` features.

      Raises:
          ValueError: if a bounding box does not have exactly 4 values.
      """
      # Transpose the per-object boxes into one list per coordinate.
      ymin, xmin, ymax, xmax = [], [], [], []
      for box in bboxes:
          # Explicit check instead of ``assert`` so it survives ``python -O``.
          if len(box) != 4:
              raise ValueError('expected 4 box coordinates, got %r' % (box,))
          top, left, bottom, right = box
          ymin.append(top)
          xmin.append(left)
          ymax.append(bottom)
          xmax.append(right)

      image_format = b'JPEG'
      example = tf.train.Example(features=tf.train.Features(feature={
          'image/height': int64_feature(shape[0]),
          'image/width': int64_feature(shape[1]),
          'image/channels': int64_feature(shape[2]),
          'image/shape': int64_feature(shape),
          'image/object/bbox/xmin': float_feature(xmin),
          'image/object/bbox/xmax': float_feature(xmax),
          'image/object/bbox/ymin': float_feature(ymin),
          'image/object/bbox/ymax': float_feature(ymax),
          'image/object/bbox/label': int64_feature(labels),
          'image/object/bbox/label_text': bytes_feature(labels_text),
          'image/object/bbox/difficult': int64_feature(difficult),
          'image/object/bbox/truncated': int64_feature(truncated),
          'image/format': bytes_feature(image_format),
          'image/encoded': bytes_feature(image_data)}))
      return example
  def _add_to_tfrecord(dataset_dir, name, tfrecord_writer):
      """Convert one VOC sample and append it to an open TFRecord writer.

      Args:
          dataset_dir: path of the VOC<year> folder.
          name: image id (file name without extension) to convert.
          tfrecord_writer: writer whose ``write`` method receives the
              serialized ``tf.train.Example``.
      """
      fields = _process_image(dataset_dir, name)
      image_data, shape, bboxes, labels, labels_text, difficult, truncated = fields
      serialized = _convert_to_example(image_data, labels, labels_text, bboxes,
                                       shape, difficult, truncated).SerializeToString()
      tfrecord_writer.write(serialized)
  def run(voc_root, year, split, output_dir, shuffling=False, samples_per_file=None):
      """Convert one Pascal VOC split into sharded TFRecord files.

      Example for VOC2012: extracting ``VOCtrainval_11-May-2012.tar`` yields
      a ``VOCdevkit`` folder containing ``VOC2012``.  ``voc_root`` is then the
      path of ``VOCdevkit`` and ``year`` is ``2012``.  The image ids of a
      split are listed in ``VOC<year>/ImageSets/Main/<split>.txt``:
      ``train`` holds training images, ``val`` validation images, and
      ``trainval`` is the union of the two.

      Args:
          voc_root: path of the ``VOCdevkit`` folder.
          year: dataset year; selects ``<voc_root>/VOC<year>``.
          split: one of ``'train'``, ``'val'``, ``'trainval'``.
          output_dir: destination folder for the ``<split>_NNN.tfrecord``
              shards; created if it does not exist.
          shuffling: if True, shuffle the image ids (seeded with
              ``RANDOM_SEED``) before writing.
          samples_per_file: maximum number of examples per shard; defaults
              to ``SAMPLES_PER_FILES``.

      Raises:
          ValueError: if ``samples_per_file`` is smaller than 1.
      """
      if samples_per_file is None:
          samples_per_file = SAMPLES_PER_FILES
      if samples_per_file < 1:
          raise ValueError('samples_per_file must be >= 1, got %r' % (samples_per_file,))

      # Create output_dir if it does not exist yet.
      if not tf.gfile.Exists(output_dir):
          tf.gfile.MakeDirs(output_dir)

      # e.g. VOCdevkit/VOC2012/ImageSets/Main/train.txt
      dataset_dir = os.path.join(voc_root, 'VOC%s' % year)
      split_file_path = os.path.join(dataset_dir, 'ImageSets', 'Main', '%s.txt' % split)
      # print() with a single argument behaves the same on Python 2 and 3.
      print('>> %s' % split_file_path)
      with open(split_file_path) as f:
          # Drop newlines and skip blank lines (e.g. a trailing empty line),
          # which would otherwise be looked up as an image named ''.
          filenames = [line.strip() for line in f if line.strip()]

      if shuffling:
          # Dedicated generator: same permutation as random.seed() followed
          # by random.shuffle(), without resetting the global random state.
          random.Random(RANDOM_SEED).shuffle(filenames)

      total = len(filenames)
      # One shard per chunk of at most samples_per_file ids.
      for fidx, start in enumerate(range(0, total, samples_per_file)):
          tf_filename = os.path.join(output_dir, '%s_%03d.tfrecord' % (split, fidx))
          with tf.python_io.TFRecordWriter(tf_filename) as tfrecord_writer:
              chunk = filenames[start:start + samples_per_file]
              for i, filename in enumerate(chunk, start):
                  sys.stdout.write('\r>> Converting image %d/%d' % (i + 1, total))
                  sys.stdout.flush()
                  _add_to_tfrecord(dataset_dir, filename, tfrecord_writer)
      print('\n>> Finished converting the Pascal VOC dataset!')
  if __name__ == '__main__':
      # Usage: python <script>.py <split>, with <split> taken from SPLIT_MAP.
      # Converts ./VOCdevkit/VOC2012 and writes the shards into ./.
      cli_args = sys.argv[1:]
      if not cli_args:
          raise ValueError('>> error. format: python *.py split_name')
      split = cli_args[0]
      if split not in SPLIT_MAP:
          raise ValueError('>> error. split = %s' % split)
      run('./VOCdevkit', 2012, split, './')


你可能感兴趣的:(计算机视觉)