deeplabv3+源码之慢慢解析 第二章datasets文件夹(3)cityscapes.py--Cityscapes类

系列文章目录(更新中)

第一章deeplabv3+源码之慢慢解析 根目录(1)main.py–get_argparser函数
第一章deeplabv3+源码之慢慢解析 根目录(2)main.py–get_dataset函数
第一章deeplabv3+源码之慢慢解析 根目录(3)main.py–validate函数
第一章deeplabv3+源码之慢慢解析 根目录(4)main.py–main函数
第一章deeplabv3+源码之慢慢解析 根目录(5)predict.py–get_argparser函数和main函数

第二章deeplabv3+源码之慢慢解析 datasets文件夹(1)voc.py–voc_cmap函数和download_extract函数
第二章deeplabv3+源码之慢慢解析 datasets文件夹(2)voc.py–VOCSegmentation类
第二章deeplabv3+源码之慢慢解析 datasets文件夹(3)cityscapes.py–Cityscapes类
第二章deeplabv3+源码之慢慢解析 datasets文件夹(4)utils.py–6个小函数

第三章deeplabv3+源码之慢慢解析 metrics文件夹stream_metrics.py–[StreamSegMetrics类和AverageMeter类]
第四章deeplabv3+源码之慢慢解析 network文件夹(0)backbone文件夹(a)hrnetv2.py–[4个类,4个函数,1个主函数]
第四章deeplabv3+源码之慢慢解析 network文件夹(0)backbone文件夹(b)mobilenetv2.py–[3个类,3个函数]
第四章deeplabv3+源码之慢慢解析 network文件夹(0)backbone文件夹©resnet.py–[2个类,12个函数]
第四章deeplabv3+源码之慢慢解析 network文件夹(0)backbone文件夹(d)xception.py–[3个类,1个函数]
第四章deeplabv3+源码之慢慢解析 network文件夹(1)_deeplab.py–[7个类和1个函数]
第四章deeplabv3+源码之慢慢解析 network文件夹(2)modeling.py–[15个函数]
第四章deeplabv3+源码之慢慢解析 network文件夹(3)utils.py–[2个类]
第五章deeplabv3+源码之慢慢解析 utils文件夹(1)ext_transforms.py.py–[17个类]
第五章deeplabv3+源码之慢慢解析 utils文件夹(2)loss.py–[1个类]
第五章deeplabv3+源码之慢慢解析 utils文件夹(3)scheduler.py–[1个类]
第五章deeplabv3+源码之慢慢解析 utils文件夹(4)utils.py–[1个类,4个函数]
第五章deeplabv3+源码之慢慢解析 utils文件夹(5)visualizer.py–[1个类]
总结

文章目录

  • 系列文章目录(更新中)
    • 说明
    • cityscapes.py导入
    • Cityscapes类


说明

  1. voc.py的代码已经说完,本节讲述cityscapes.py。
  2. 从代码上看,cityscapes数据集处理方式比voc数据集简单。
  3. cityscapes.py中仅有一个类,其中包含7个小函数,从体量和难度上,可以一次说完。
  4. 思路上和voc.py有部分类似之处,大家可以边学习边对比。

cityscapes.py导入

#导入都是基本包,和之前相比要简单。
import json
import os
from collections import namedtuple

import torch
import torch.utils.data as data
from PIL import Image
import numpy as np

Cityscapes类

提示:可以对比前一节的VOCSegmentation类,多体会这种类的处理思路,以便处理自己的数据集。

class Cityscapes(data.Dataset):
    """Cityscapes  Dataset.
    
    **Parameters:**   #详细的参数介绍,如VOCSegmentation类一样。
        - **root** (string): Root directory of dataset where directory 'leftImg8bit' and 'gtFine' or 'gtCoarse' are located.
        - **split** (string, optional): The image split to use, 'train', 'test' or 'val' if mode="gtFine" otherwise 'train', 'train_extra' or 'val'
        - **mode** (string, optional): The quality mode to use, 'gtFine' or 'gtCoarse' or 'color'. Can also be a list to output a tuple with all specified target types.
        - **transform** (callable, optional): A function/transform that takes in a PIL image and returns a transformed version. E.g, ``transforms.RandomCrop``
        - **target_transform** (callable, optional): A function/transform that takes in the target and transforms it.
    """

    # Based on https://github.com/mcordts/cityscapesScripts
    #以下时Cityscapes数据集的分类情况,共35类(包含未标注类别,实际区分34类)。此处没有和VOC数据集一样用字典,而是用元组,之后组成列表。CityscapesClass列出了各个字段对应的语义。
    CityscapesClass = namedtuple('CityscapesClass', ['name', 'id', 'train_id', 'category', 'category_id',
                                                     'has_instances', 'ignore_in_eval', 'color'])
    classes = [
        CityscapesClass('unlabeled',            0, 255, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('ego vehicle',          1, 255, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('rectification border', 2, 255, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('out of roi',           3, 255, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('static',               4, 255, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('dynamic',              5, 255, 'void', 0, False, True, (111, 74, 0)),
        CityscapesClass('ground',               6, 255, 'void', 0, False, True, (81, 0, 81)),
        CityscapesClass('road',                 7, 0, 'flat', 1, False, False, (128, 64, 128)),
        CityscapesClass('sidewalk',             8, 1, 'flat', 1, False, False, (244, 35, 232)),
        CityscapesClass('parking',              9, 255, 'flat', 1, False, True, (250, 170, 160)),
        CityscapesClass('rail track',           10, 255, 'flat', 1, False, True, (230, 150, 140)),
        CityscapesClass('building',             11, 2, 'construction', 2, False, False, (70, 70, 70)),
        CityscapesClass('wall',                 12, 3, 'construction', 2, False, False, (102, 102, 156)),
        CityscapesClass('fence',                13, 4, 'construction', 2, False, False, (190, 153, 153)),
        CityscapesClass('guard rail',           14, 255, 'construction', 2, False, True, (180, 165, 180)),
        CityscapesClass('bridge',               15, 255, 'construction', 2, False, True, (150, 100, 100)),
        CityscapesClass('tunnel',               16, 255, 'construction', 2, False, True, (150, 120, 90)),
        CityscapesClass('pole',                 17, 5, 'object', 3, False, False, (153, 153, 153)),
        CityscapesClass('polegroup',            18, 255, 'object', 3, False, True, (153, 153, 153)),
        CityscapesClass('traffic light',        19, 6, 'object', 3, False, False, (250, 170, 30)),
        CityscapesClass('traffic sign',         20, 7, 'object', 3, False, False, (220, 220, 0)),
        CityscapesClass('vegetation',           21, 8, 'nature', 4, False, False, (107, 142, 35)),
        CityscapesClass('terrain',              22, 9, 'nature', 4, False, False, (152, 251, 152)),
        CityscapesClass('sky',                  23, 10, 'sky', 5, False, False, (70, 130, 180)),
        CityscapesClass('person',               24, 11, 'human', 6, True, False, (220, 20, 60)),
        CityscapesClass('rider',                25, 12, 'human', 6, True, False, (255, 0, 0)),
        CityscapesClass('car',                  26, 13, 'vehicle', 7, True, False, (0, 0, 142)),
        CityscapesClass('truck',                27, 14, 'vehicle', 7, True, False, (0, 0, 70)),
        CityscapesClass('bus',                  28, 15, 'vehicle', 7, True, False, (0, 60, 100)),
        CityscapesClass('caravan',              29, 255, 'vehicle', 7, True, True, (0, 0, 90)),
        CityscapesClass('trailer',              30, 255, 'vehicle', 7, True, True, (0, 0, 110)),
        CityscapesClass('train',                31, 16, 'vehicle', 7, True, False, (0, 80, 100)),
        CityscapesClass('motorcycle',           32, 17, 'vehicle', 7, True, False, (0, 0, 230)),
        CityscapesClass('bicycle',              33, 18, 'vehicle', 7, True, False, (119, 11, 32)),
        CityscapesClass('license plate',        -1, 255, 'vehicle', 7, False, True, (0, 0, 142)),
    ]

    train_id_to_color = [c.color for c in classes if (c.train_id != -1 and c.train_id != 255)] #将不需要训练的目标,train_id设定为255即可。本段代码设定0-18,即19个类别。
    train_id_to_color.append([0, 0, 0])   #增加一个[0, 0, 0],目的在于对应没有的类别。
    train_id_to_color = np.array(train_id_to_color)
    id_to_train_id = np.array([c.train_id for c in classes])
    
    #train_id_to_color = [(0, 0, 0), (128, 64, 128), (70, 70, 70), (153, 153, 153), (107, 142, 35),
    #                      (70, 130, 180), (220, 20, 60), (0, 0, 142)]
    #train_id_to_color = np.array(train_id_to_color)
    #id_to_train_id = np.array([c.category_id for c in classes], dtype='uint8') - 1

    def __init__(self, root, split='train', mode='fine', target_type='semantic', transform=None): #对应VOCSegmentation类的第14-65行。
        self.root = os.path.expanduser(root)
        self.mode = 'gtFine'
        self.target_type = target_type
        self.images_dir = os.path.join(self.root, 'leftImg8bit', split)  #文件夹,后文补图。

        self.targets_dir = os.path.join(self.root, self.mode, split)
        self.transform = transform

        self.split = split
        self.images = []
        self.targets = []

        if split not in ['train', 'test', 'val']:   #对split参数的提示。
            raise ValueError('Invalid split for mode! Please use split="train", split="test"'
                             ' or split="val"')

        if not os.path.isdir(self.images_dir) or not os.path.isdir(self.targets_dir): #确保"split" and "mode"参数存在对应的文件夹路径。
            raise RuntimeError('Dataset not found or incomplete. Please make sure all required folders for the'
                               ' specified "split" and "mode" are inside the "root" directory')
        
        for city in os.listdir(self.images_dir):   #获得对应的输入和目标。
            img_dir = os.path.join(self.images_dir, city)
            target_dir = os.path.join(self.targets_dir, city)

            for file_name in os.listdir(img_dir):
                self.images.append(os.path.join(img_dir, file_name))
                target_name = '{}_{}'.format(file_name.split('_leftImg8bit')[0],
                                             self._get_target_suffix(self.mode, self.target_type))
                self.targets.append(os.path.join(target_dir, target_name))

    @classmethod
    def encode_target(cls, target):    #编码函数,仅仅是返回train_id列表。
        return cls.id_to_train_id[np.array(target)]

    @classmethod
    def decode_target(cls, target):  #对应VOCSegmentation类的第86-88行。
        target[target == 255] = 19    #此代码中选0-18个类别,255即不训练的类别,定为19。
        #target = target.astype('uint8') + 1
        return cls.train_id_to_color[target]

    def __getitem__(self, index):   #按索引得到对应的输入(image)和输出(target)图片。对应VOCSegmentation类的第67-79行。
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is a tuple of all target types if target_type is a list with more
            than one item. Otherwise target is a json object if target_type="polygon", else the image segmentation.
        """
        image = Image.open(self.images[index]).convert('RGB')
        target = Image.open(self.targets[index])
        if self.transform:
            image, target = self.transform(image, target)
        target = self.encode_target(target)
        return image, target

    def __len__(self):      #返回图片(长度)数量。对应VOCSegmentation类的第82行。
        return len(self.images)

    def _load_json(self, path):    #按输入的path,打开json文件。
        with open(path, 'r') as file:
            data = json.load(file)
        return data

    def _get_target_suffix(self, mode, target_type):   #根据target_type的不同内容,得到不同的后缀。
        if target_type == 'instance':
            return '{}_instanceIds.png'.format(mode)
        elif target_type == 'semantic':
            return '{}_labelIds.png'.format(mode)
        elif target_type == 'color':
            return '{}_color.png'.format(mode)
        elif target_type == 'polygon':
            return '{}_polygons.json'.format(mode)
        elif target_type == 'depth':
            return '{}_disparity.png'.format(mode)

Tips

1.相对VOCSegmentation类而言,Cityscapes类相对简单。
2. 补充cityscapes文件夹,可以看到对应文件夹,相对容易理解。deeplabv3+源码之慢慢解析 第二章datasets文件夹(3)cityscapes.py--Cityscapes类_第1张图片
3. 对于具体文件压缩包,split参数对应的文件夹图。
deeplabv3+源码之慢慢解析 第二章datasets文件夹(3)cityscapes.py--Cityscapes类_第2张图片
4. datasets文件夹下的两个特定数据集代码已解析完毕。下一节介绍utils.py。

你可能感兴趣的:(技术,人工智能,deeplabv3+,语义分割,深度学习)