DeepLabV3+(tensorflow)工程应用系列(二)—— TFRecord生成VOC2012训练数据

文章目录

      • 一、PASCAL VOC 2012简介
      • 二、什么是TFRecord?
      • 三、TFRecords文件的写入与读取
      • 四、参考链接

  在上一篇博文中,我们介绍如何借助【 Labelme】软件制作原始的语义分割的标注数据。那么,如何将标注生成的数据(比如掩码图)转化为训练直接可以使用的数据?本篇博文将详细如何使用 TFRecord生成【 PASCAL VOC2012】训练数据。

一、PASCAL VOC 2012简介

  该数据集是常用的语义分割数据集,总共有20类分割目标。在我之前的博文中,详细介绍了该数据集的详细内容,请参考博文:PASCAL VOC 2012 数据集解析。

二、什么是TFRecord?

  TFRecord 格式是一种用于存储二进制记录序列的简单格式。协议缓冲区是一个跨平台、跨语言的库,用于高效地序列化结构化数据。协议消息由 .proto 文件定义,这通常是了解消息类型最简单的方法。tf.Example 消息(或 protobuf)是一种灵活的消息类型,表示 {"string": value} 映射。它专为 TensorFlow 而设计,并被用于 TFX 等高级 API。

  1. tf.Example消息的数据类型
    tf.train.BytesList:可以使用的类型包括 stringbyte.
    tf.train.FloatList::可以使用的类型包括 floatdouble.
    tf.train.Int64List:可以使用的类型包括 enum,bool, int32, uint32, int64,uint64.

    类型转换函数如下:

    # The following functions can be used to convert a value to a type compatible
    # with tf.Example.
    
    def _bytes_feature(value):
      """Returns a bytes_list from a string / byte."""
      if isinstance(value, type(tf.constant(0))):
        value = value.numpy() # BytesList won't unpack a string from an EagerTensor.
      return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
    
    def _float_feature(value):
      """Returns a float_list from a float / double."""
      return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))
    
    def _int64_feature(value):
      """Returns an int64_list from a bool / enum / int / uint."""
      return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
    
  2. 创建tf.Example消息
    在实践中,数据集可能来自任何地方,但是从单个观测值创建 tf.Example 消息的过程相同:(a)在每个观测结果中,需要使用上述其中一种函数,将每个值转换为包含三种兼容类型之一的 tf.train.Feature;(b)创建一个从特征名称字符串到第 1 步中生成的编码特征值的映射(字典);(c)将第 2 步中生成的映射转换为 Features 消息。


    比如,使用 NumPy 创建一个数据集,此数据集将具有 4 个特征:
      (a)具有相等 False 或 True 概率的布尔特征;
      (b)从 [0, 5] 均匀随机选择的整数特征;
      (c)通过将整数特征作为索引从字符串表生成的字符串特征;
      (d)来自标准正态分布的浮点特征;


    创建消息的实例如下代码如下,比如创建10000个样本

    # The number of observations in the dataset.
    n_observations = int(10000)
    
    # Boolean feature, encoded as False or True.
    feature0 = np.random.choice([False, True], n_observations)
    
    # Integer feature, random from 0 to 4.
    feature1 = np.random.randint(0, 5, n_observations)
    
    # String feature
    strings = np.array([b'cat', b'dog', b'chicken', b'horse', b'goat'])
    feature2 = strings[feature1]
    
    # Float feature, from a standard normal distribution
    feature3 = np.random.randn(n_observations)
    

    可以使用 _bytes_feature_float_feature_int64_feature 将下面的每个特征强制转换为兼容 tf.Example 的类型。然后,可以通过下面的已编码特征创建 tf.Example 消息:

    def serialize_example(feature0, feature1, feature2, feature3):
      """
      Creates a tf.Example message ready to be written to a file.
      """
      # Create a dictionary mapping the feature name to the tf.Example-compatible
      # data type.
      feature = {
          'feature0': _int64_feature(feature0),
          'feature1': _int64_feature(feature1),
          'feature2': _bytes_feature(feature2),
          'feature3': _float_feature(feature3),
      }
    
      # Create a Features message using tf.train.Example.
      example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
      return example_proto.SerializeToString()
    

    例如,假设我们从数据集中获得了一个观测值 [False, 4, bytes('goat'), 0.9876]。可以使用 serialize_example() 创建和打印此观测值的 tf.Example 消息。如上所述,每个观测值将被写为一条 Features 消息。请注意,tf.Example 消息只是 Features 消息外围的包装器:

    # This is an example observation from the dataset.
    example_observation = []
    
    serialized_example = serialize_example(False, 4, b'goat', 0.9876)
    print(serialized_example)
    

    打印结果如下:

    b'\nR\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04[\xd3|?\n\x14\n\x08feature2\x12\x08\n\x06\n\x04goat\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x04'
    

    要解码消息,使用 tf.train.Example.FromString 方法,代码如下:

    example_proto = tf.train.Example.FromString(serialized_example)
    print(example_proto)
    

    打印结果如下:

    features {
      feature {
        key: "feature0"
        value {
          int64_list {
            value: 0
          }
        }
      }
      feature {
        key: "feature1"
        value {
          int64_list {
            value: 4
          }
        }
      }
      feature {
        key: "feature2"
        value {
          bytes_list {
            value: "goat"
          }
        }
      }
      feature {
        key: "feature3"
        value {
          float_list {
            value: 0.9876000285148621
          }
        }
      }
    }
    
    

    上述内容的完整代码如下所示,

    import tensorflow as tf
    import numpy as np
    
    
    # The following functions can be used to convert a value to a type compatible
    # with tf.Example.
    
    def _bytes_feature(value):
    	"""Returns a bytes_list from a string / byte."""
    	if isinstance(value, type(tf.constant(0))):
    		value=value.numpy()  # BytesList won't unpack a string from an EagerTensor.
    	return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
    
    
    def _float_feature(value):
    	"""Returns a float_list from a float / double."""
    	return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))
    
    
    def _int64_feature(value):
    	"""Returns an int64_list from a bool / enum / int / uint."""
    	return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
    
    
    # The number of observations in the dataset.
    n_observations=int(1e4)
    
    # Boolean feature, encoded as False or True.
    feature0=np.random.choice([False, True], n_observations)
    
    # Integer feature, random from 0 to 4.
    feature1=np.random.randint(0, 5, n_observations)
    
    # String feature
    strings=np.array([b'cat', b'dog', b'chicken', b'horse', b'goat'])
    feature2=strings[feature1]
    
    # Float feature, from a standard normal distribution
    feature3=np.random.randn(n_observations)
    
    
    def serialize_example(feature0, feature1, feature2, feature3):
    	"""
        Creates a tf.Example message ready to be written to a file.
        """
    	# Create a dictionary mapping the feature name to the tf.Example-compatible
    	# data type.
    	feature={
    		'feature0': _int64_feature(feature0),
    		'feature1': _int64_feature(feature1),
    		'feature2': _bytes_feature(feature2),
    		'feature3': _float_feature(feature3),
    	}
    
    	# Create a Features message using tf.train.Example.
    
    	example_proto=tf.train.Example(features=tf.train.Features(feature=feature))
    	return example_proto.SerializeToString()
    
    # This is an example observation from the dataset.
    
    example_observation = []
    
    serialized_example = serialize_example(False, 4, b'goat', 0.9876)
    print(serialized_example)
    example_proto = tf.train.Example.FromString(serialized_example)
    print(example_proto)
    

三、TFRecords文件的写入与读取

此次我们将官方的PASCAL VOC 2012分割数据集制作为TFRecords,以便用于训练语义分割网络(DeepLab-V3+)。官方数据下载地址:http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar. 下面,我们将逐步介绍官方数据集以及如何制作【TFRecords】。

  1. PASCAL VOC 2012文件结构

    +VOCtrainval_11-May-2012
      + VOCdevkit
         + VOC2012
           + Annotations
           + ImageSets
             +Action
             +Layout
             +Main
             +Segmentation
           + JPEGImages
           + SegmentationClass
           + SegmentationObject

    Annotations:存放的XML文件(17125个文件),与JPEGImage文件夹的图片一一对应,主要记录了图片的基本信息,比如文件路径,来源,标注信息等,打开其中文件如下图所示,

    DeepLabV3+(tensorflow)工程应用系列(二)—— TFRecord生成VOC2012训练数据_第1张图片

    <annotation>
        <folder>VOC2012</folder>  # 图片的存放路径
        <filename>2007_000027.jpg</filename> # 图片名称
        <source>                  # 图片来源相关信息
            <database>The VOC2007 Database</database>
            <annotation>PASCAL VOC2007</annotation>
            <image>flickr</image>
        </source>
        <size>     # 图像尺寸
            <width>486</width>
            <height>500</height>
            <depth>3</depth> 
        </size>
        <segmented>0</segmented> # 是否有语义分割标注,0:表示没有,1:表示有
        <object>  # 检测的目标,如果有多个会有多个标签
            <name>person</name>      # 目标类别,人
            <pose>Unspecified</pose> # 拍摄角度
            <truncated>0</truncated> # 是否被截断, 0表示完整  
            <difficult>0</difficult> # 目标是否难以识别, 0表示容易识别
            <bndbox>  # bounding-box, 包含左上角和右下角xy坐标
                <xmin>174</xmin>
                <ymin>101</ymin>
                <xmax>349</xmax>
                <ymax>351</ymax>
            </bndbox>
            <part> 
                <name>head</name> # 头
                <bndbox>
                    <xmin>169</xmin>
                    <ymin>104</ymin>
                    <xmax>209</xmax>
                    <ymax>146</ymax>
                </bndbox>
            </part>
            <part>   
                <name>hand</name> # 手
                <bndbox>
                    <xmin>278</xmin>
                    <ymin>210</ymin>
                    <xmax>297</xmax>
                    <ymax>233</ymax>
                </bndbox>
            </part>
            <part>  
                <name>foot</name> # 脚
                <bndbox>
                    <xmin>273</xmin>
                    <ymin>333</ymin>
                    <xmax>297</xmax>
                    <ymax>354</ymax>
                </bndbox>
            </part>
            <part>
                <name>foot</name> # 脚
                <bndbox>
                    <xmin>319</xmin>
                    <ymin>307</ymin>
                    <xmax>340</xmax>
                    <ymax>326</ymax>
                </bndbox>
            </part>
        </object>
    </annotation>
    
    

    ImageSets:将各个标注任务的文件名列表和相应的标注信息存放于(txt)文本。

    ImageSets/子目录 子目录内容简述
    Action 存放的是人的动作标注信息,10种动作类包括【 jumping、phoning、playinginstrument、reading、ridingbike、ridinghorse、running、takingphoto、usingcomputer、walking】;
    Layout 存放的是具有人体部位的数据(人的head、hand、feet等);
    Main 存放的是图像物体识别的数据标签文本,总共20类,包含了20个分类的,class_train.txt(存放训练集的图片编号,每一个class的train数据都有5717个)、 class_val.txt(存放验证集的图片编号,每一个class的val数据都有5832个) 和 class_trainval.txt(存放以上两者的完全合并集,每一个class有5717+5832=11540个)
    Segmentation 存放语义分割数据集的文件名列表,存放与三个文本(train.txt,1464个条数据,trainval.txt,2913条数据,val.txt,1449条数据)。

    JPEGImages:存放所有的原始图片,格式为JPG,总共17125个文件。
    DeepLabV3+(tensorflow)工程应用系列(二)—— TFRecord生成VOC2012训练数据_第2张图片


    SegmentationClass:用于语义分割的图片,2913个文件(训练集+验证集)。
    DeepLabV3+(tensorflow)工程应用系列(二)—— TFRecord生成VOC2012训练数据_第3张图片


    SegmentationObject:用于目标检测的图片,2913个文件(训练集+验证集)。

    DeepLabV3+(tensorflow)工程应用系列(二)—— TFRecord生成VOC2012训练数据_第4张图片

  2. 移除标签颜色图
    我们知道,官方的标签数据是单通道的彩色图,在制作TFRecords之前,需要移除。

    关于颜色图的简单解释:该图的存储为【MxN,index matrix】和【Cx3,colormap matrix】的两个矩阵,索引值其实是在取颜色。下面的Matlab代码可以运行看下具体效果,

    [img, cmap] = imread('2007_000033.png');
    if isempty(cmap)
    	 % Process data as a grayscale or RGB image
    	disp('Process data as a grayscale or RGB image');
    else
    	% Process data as an indexed image
    	disp('Process data as an indexed image');
    end
    imshow(img, cmap);
    imwrite(img, cmap, 'outfile.png');
    

    我们选取 PASCAL VOC 22012的一张标签图为例【2007_000033.png】,用上述的Matlab代码运行,得到颜色矩阵和图像的索引矩阵:

    DeepLabV3+(tensorflow)工程应用系列(二)—— TFRecord生成VOC2012训练数据_第5张图片DeepLabV3+(tensorflow)工程应用系列(二)—— TFRecord生成VOC2012训练数据_第6张图片

  3. 制作TFRecords
    在DeepLab的开源代码中,我们需要先移除标签的颜色,tensorflow代码如下:

    
    """Removes the color map from segmentation annotations.
    
    Removes the color map from the ground truth segmentation annotations and save
    the results to output_dir.
    """
    import glob
    import os.path
    import numpy as np
    
    from PIL import Image
    
    import tensorflow as tf
    
    FLAGS = tf.compat.v1.flags.FLAGS
    
    tf.compat.v1.flags.DEFINE_string('original_gt_folder',
                                     './VOCdevkit/VOC2012/SegmentationClass',
                                     'Original ground truth annotations.')
    
    tf.compat.v1.flags.DEFINE_string('segmentation_format', 'png', 'Segmentation format.')
    
    tf.compat.v1.flags.DEFINE_string('output_dir',
                                     './VOCdevkit/VOC2012/SegmentationClassNoColor',
                                     'folder to save modified ground truth annotations.')
    
    
    def _remove_colormap(filename):
        """Removes the color map from the annotation.
    
      Args:
        filename: Ground truth annotation filename.
    
      Returns:
        Annotation without color map.
      """
        return np.array(Image.open(filename))
    
    
    def _save_annotation(annotation, filename):
        """Saves the annotation as png file.
    
      Args:
        annotation: Segmentation annotation.
        filename: Output filename.
      """
        pil_image = Image.fromarray(annotation.astype(dtype=np.uint8))
        with tf.io.gfile.GFile(filename, mode='w') as f:
            pil_image.save(f, 'PNG')
    
    
    def main(unused_argv):
        # Create the output directory if not exists.
        if not tf.io.gfile.isdir(FLAGS.output_dir):
            tf.io.gfile.makedirs(FLAGS.output_dir)
    
        annotations = glob.glob(os.path.join(FLAGS.original_gt_folder, '*.' + FLAGS.segmentation_format))
        for annotation in annotations:
            raw_annotation = _remove_colormap(annotation)
            filename = os.path.basename(annotation)[:-4]
            _save_annotation(raw_annotation,
                             os.path.join(
                                 FLAGS.output_dir,
                                 filename + '.' + FLAGS.segmentation_format))
    
    
    if __name__ == '__main__':
        tf.compat.v1.app.run()
    
    

    运行上述代码,指定移除颜色的新标签保存路径(SegmentationClassNoColor),左图是原图标签图,右图是移除颜色的结果,图片内的红色框分别是原始标签的路径以及移除结果保存的路径,具体见下图展示:
    DeepLabV3+(tensorflow)工程应用系列(二)—— TFRecord生成VOC2012训练数据_第7张图片 DeepLabV3+(tensorflow)工程应用系列(二)—— TFRecord生成VOC2012训练数据_第8张图片


    参考tensorflow 官方开源的DeepLab代码(build_data.pybuild_voc2012_data.py),代码的流程与第二部分(二、什么是TFRecords?)的流程一致。简单来说,首先,定义好【Example Proto】域的属性;其次,转换每一个属性的格式,使其与【tf.Example】兼容;最后,将每一个【Example】序列化,写入tesorflow定义的文件中。以下是build_voc2012_data.py代码,

    """Converts PASCAL VOC 2012 data to TFRecord file format with Example protos.
    
    PASCAL VOC 2012 dataset is expected to have the following directory structure:
    
      + pascal_voc_seg
        - build_data.py
        - build_voc2012_data.py (current working directory).
        + VOCdevkit
          + VOC2012
            + JPEGImages
            + SegmentationClass
            + ImageSets
              + Segmentation
        + tfrecord
    
    Image folder:
      ./VOCdevkit/VOC2012/JPEGImages
    
    Semantic segmentation annotations:
      ./VOCdevkit/VOC2012/SegmentationClass
    
    list folder:
      ./VOCdevkit/VOC2012/ImageSets/Segmentation
    
    This script converts data into sharded data files and save at tfrecord folder.
    
    The Example proto contains the following fields:
      image/encoded: encoded image content.
      image/filename: image filename.
      image/format: image file format.
      image/height: image height.
      image/width: image width.
      image/channels: image channels.
      image/segmentation/class/encoded: encoded semantic segmentation content.
      image/segmentation/class/format: semantic segmentation file format.
    """
    
    from __future__ import absolute_import
    from __future__ import division
    from __future__ import print_function
    import math
    import os.path
    import sys
    import build_data
    from six.moves import range
    import tensorflow as tf
    
    FLAGS = tf.app.flags.FLAGS
    
    tf.app.flags.DEFINE_string('image_folder',
                               './pascal_voc_seg/VOCdevkit/VOC2012/JPEGImages',
                               'Folder containing images.')
    
    tf.app.flags.DEFINE_string('semantic_segmentation_folder',
                               './pascal_voc_seg/VOCdevkit/VOC2012/SegmentationClassAug/SegmentationClassAug',
                               'Folder containing semantic segmentation annotations.')
    
    tf.app.flags.DEFINE_string('list_folder',
                               './pascal_voc_seg/VOCdevkit/VOC2012/ImageSets/Segmentation',
                               'Folder containing lists for training and validation')
    
    tf.app.flags.DEFINE_string('output_dir',
                               './pascal_voc_seg/tfrecord',
                               'Path to save converted SSTable of TensorFlow examples.')
    
    _NUM_SHARDS = 4
    
    
    def _convert_dataset(dataset_split):
        """Converts the specified dataset split to TFRecord format.
    
        Args:
            dataset_split: The dataset split (e.g., train, test).
    
        Raises:
            RuntimeError: If loaded image and label have different shape.
        """
        dataset = os.path.basename(dataset_split)[:-4]
        sys.stdout.write('Processing ' + dataset)
        filenames = [x.strip('\n') for x in open(dataset_split, 'r')]
        num_images = len(filenames)
        print(num_images)
    
        num_per_shard = int(math.ceil(num_images / _NUM_SHARDS))
    
        image_reader = build_data.ImageReader('jpeg', channels=3)
        label_reader = build_data.ImageReader('png', channels=1)
    
        for shard_id in range(_NUM_SHARDS):
            output_filename = os.path.join(FLAGS.output_dir, '%s-%05d-of-%05d.tfrecord' % (dataset, shard_id, _NUM_SHARDS))
    
            with tf.io.TFRecordWriter(output_filename) as tfrecord_writer:
                start_idx = shard_id * num_per_shard
                end_idx = min((shard_id + 1) * num_per_shard, num_images)
                for i in range(start_idx, end_idx):
                    sys.stdout.write('\r>> Converting image %d/%d shard %d' % (i + 1, len(filenames), shard_id))
                    sys.stdout.flush()
    
                    # Read the image.
                    image_filename = os.path.join(FLAGS.image_folder, filenames[i] + '.' + FLAGS.image_format)
                    print('image_format: ', FLAGS.image_format)
                    print('image_filename: ', image_filename)
                    image_data = tf.io.gfile.GFile(image_filename, 'rb').read()
                    height, width = image_reader.read_image_dims(image_data)
    
                    # Read the semantic segmentation annotation.
                    seg_filename = os.path.join(FLAGS.semantic_segmentation_folder, filenames[i] + '.' + FLAGS.label_format)
                    seg_data = tf.io.gfile.GFile(seg_filename, 'rb').read()
                    seg_height, seg_width = label_reader.read_image_dims(seg_data)
                    if height != seg_height or width != seg_width:
                        raise RuntimeError('Shape mismatched between image and label.')
    
                    # Convert to tf example.
                    example = build_data.image_seg_to_tfexample(image_data, filenames[i], height, width, seg_data)
                    tfrecord_writer.write(example.SerializeToString())
    
            sys.stdout.write('\n')
            sys.stdout.flush()
    
    
    def main(unused_argv):
        dataset_splits = tf.gfile.Glob(os.path.join(FLAGS.list_folder, '*.txt'))
        print(dataset_splits)
    
        for dataset_split in dataset_splits:
            _convert_dataset(dataset_split)
    
    
    if __name__ == '__main__':
        tf.compat.v1.app.run()
    
    

    以下为build_data.py代码,

    
    """Contains common utility functions and classes for building dataset.
    
    This script contains utility functions and classes to converts dataset to
    TFRecord file format with Example protos.
    
    The Example proto contains the following fields:
    
      image/encoded: encoded image content.
      image/filename: image filename.
      image/format: image file format.
      image/height: image height.
      image/width: image width.
      image/channels: image channels.
      image/segmentation/class/encoded: encoded semantic segmentation content.
      image/segmentation/class/format: semantic segmentation file format.
    """
    import collections
    import six
    import tensorflow as tf
    
    FLAGS = tf.app.flags.FLAGS
    
    tf.app.flags.DEFINE_enum('image_format', 'jpg', ['jpg', 'jpeg', 'png'], 'Image format.')
    
    tf.app.flags.DEFINE_enum('label_format', 'png', ['png'], 'Segmentation label format.')
    
    # A map from image format to expected data format.
    _IMAGE_FORMAT_MAP = {
        'jpg': 'jpeg',
        'jpeg': 'jpeg',
        'png': 'png',
    }
    
    
    class ImageReader(object):
        """Helper class that provides TensorFlow image coding utilities."""
    
        def __init__(self, image_format='jpeg', channels=3):
            """Class constructor.
    
            Args:
                image_format: Image format. Only 'jpeg', 'jpg', or 'png' are supported.
                channels: Image channels.
            """
            with tf.Graph().as_default():
                self._decode_data = tf.placeholder(dtype=tf.string)
                self._image_format = image_format
                self._session = tf.Session()
                if self._image_format in ('jpeg', 'jpg'):
                    self._decode = tf.image.decode_jpeg(self._decode_data,
                                                        channels=channels)
                elif self._image_format == 'png':
                    self._decode = tf.image.decode_png(self._decode_data,
                                                       channels=channels)
    
        def read_image_dims(self, image_data):
            """Reads the image dimensions.
    
            Args:
                image_data: string of image data.
    
            Returns:
                image_height and image_width.
            """
            image = self.decode_image(image_data)
            return image.shape[:2]
    
        def decode_image(self, image_data):
            """Decodes the image data string.
    
            Args:
                image_data: string of image data.
    
            Returns:
                Decoded image data.
    
            Raises:
                ValueError: Value of image channels not supported.
            """
            image = self._session.run(self._decode,
                                      feed_dict={self._decode_data: image_data})
            if len(image.shape) != 3 or image.shape[2] not in (1, 3):
                raise ValueError('The image channels not supported.')
    
            return image
    
    
    def _int64_list_feature(values):
        """Returns a TF-Feature of int64_list.
    
        Args:
            values: A scalar or list of values.
    
        Returns:
            A TF-Feature.
        """
        if not isinstance(values, collections.Iterable):
            values = [values]
    
        return tf.train.Feature(int64_list=tf.train.Int64List(value=values))
    
    
    def _bytes_list_feature(values):
        """Returns a TF-Feature of bytes.
    
        Args:
            values: A string.
    
        Returns:
            A TF-Feature.
        """
    
        def norm2bytes(value):
            return value.encode() if isinstance(value, str) and six.PY3 else value
    
        return tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[norm2bytes(values)]))
    
    
    def image_seg_to_tfexample(image_data, filename, height, width, seg_data):
        """Converts one image/segmentation pair to tf example.
    
        Args:
            image_data: string of image data.
            filename: image filename.
            height: image height.
            width: image width.
            seg_data: string of semantic segmentation data.
    
        Returns:
            tf example of one image/segmentation pair.
        """
        return tf.train.Example(features=tf.train.Features(feature={
            'image/encoded': _bytes_list_feature(image_data),
            'image/filename': _bytes_list_feature(filename),
            'image/format': _bytes_list_feature(_IMAGE_FORMAT_MAP[FLAGS.image_format]),
            'image/height': _int64_list_feature(height),
            'image/width': _int64_list_feature(width),
            'image/channels': _int64_list_feature(3),
            'image/segmentation/class/encoded': (_bytes_list_feature(seg_data)),
            'image/segmentation/class/format': _bytes_list_feature(FLAGS.label_format),
        }))
    
    

    运行上述代码,注意路径的设置,得到如下TFRecords文件:

    DeepLabV3+(tensorflow)工程应用系列(二)—— TFRecord生成VOC2012训练数据_第9张图片

  4. 四、参考链接

    1. 官方【TFRecord】中文教程(五星推荐)
    2. https://blog.csdn.net/qq_39938666/article/details/89511383
    3. https://blog.csdn.net/wenxueliu/article/details/80327316
    4. https://cloud.tencent.com/developer/article/1486521
    5. Index image and colormap的区别与联系
    6. https://www.it1352.com/1601798.html

    你可能感兴趣的:(DeepLabV3+工程应用,TFRecords制作,DeepLab-V3+,语义分割标签图,Index,Image,colormap)