图像分割实战-系列教程14:deeplabV3+ VOC分割实战2-------数据读取

图像分割实战-系列教程 总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Pycharm中进行
本篇文章配套的代码资源已经上传

deeplab系列算法概述
deeplabV3+ VOC分割实战1
deeplabV3+ VOC分割实战2
deeplabV3+ VOC分割实战3
deeplabV3+ VOC分割实战4
deeplabV3+ VOC分割实战5

3、数据集读取

在第2部分我们介绍了项目用到的参数,这部分介绍怎么构建数据集

main.py -------def get_dataset()函数

VOC的数据集,2007版有公开过测试集,2012版没有公开过测试集,要进行测试需要使用它们特殊的工具进行测试。
数据集的处理,一般是指定Dataloader然后遍历取数据就行了。这里的get_dataset函数主要就是进行数据增强。

def get_dataset(opts):
    """ Dataset And Augmentation
    """
    if opts.dataset == 'voc':
        train_transform = et.ExtCompose([
            #et.ExtResize(size=opts.crop_size),
            et.ExtRandomScale((0.5, 2.0)),
            et.ExtRandomCrop(size=(opts.crop_size, opts.crop_size), pad_if_needed=True),
            et.ExtRandomHorizontalFlip(),
            et.ExtToTensor(),
            et.ExtNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

随机缩放、随机裁剪、随机翻转、转换为张量格式、转换成Tensor格式、使用指定的均值和标准差对图像进行标准化,等方式对训练集进行数据增强

        if opts.crop_val:
            val_transform = et.ExtCompose([
                et.ExtResize(opts.crop_size),
                et.ExtCenterCrop(opts.crop_size),
                et.ExtToTensor(),
                et.ExtNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ])
        else:
            val_transform = et.ExtCompose([
                et.ExtToTensor(),
                et.ExtNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ])

验证集的数据增强方式稍有不同,具体取决于 opts.crop_val 的值。如果启用裁剪,它会包括尺寸调整和中心裁剪;否则,只包括张量转换和标准化

        train_dst = VOCSegmentation(root=opts.data_root, year=opts.year, image_set='train', download=opts.download, transform=train_transform)
        val_dst = VOCSegmentation(root=opts.data_root, year=opts.year, image_set='val', download=False, transform=val_transform)
    return train_dst, val_dst

最后创建一个用于训练的 VOC 数据集实例,使用训练变换
创建一个用于验证的 VOC 数据集实例,使用验证变换

其中VOCSegmentation是我们自定义的类,在voc.py中

4、VOCSegmentation类

注意我们的数据和标签都是长宽一样的图像

在 “ 3、数据集读取 ” 中我们解析了main.py -------def get_dataset()函数****,这个函数主要进行数据增强操作,然后调用了VOCSegmentation类进行数据集的制作,VOCSegmentation类继承torch.utils.data.Dataset

VOCSegmentation类主要分为四个部分:

  1. def __init__(self, ...):构造函数,数据从哪里去取,数据的基本介绍、基本定义,最后返回的是所有数据和所有标签的路径
  2. def __getitem__(self, index):获取单个数据项
  3. def __len__(self):返回数据集中的元素总数
  4. def decode_target(cls, target):解码或处理数据集中的标签或注释

这个 VOCSegmentation 类是数据加载和预处理的核心部分,为深度学习模型提供了所需的数据,能够与PyTorch的 DataLoader无缝集成。

4.1 init构造函数

    def __init__(self, root, year='2012', image_set='train', download=False, transform=None):

        is_aug=False
        if year=='2012_aug':
            is_aug = True
            year = '2012'

首先判断一下是否读取的是voc2012数据集

        self.root = os.path.expanduser(root)
        self.year = year
        self.url = DATASET_YEAR_DICT[year]['url']
        self.filename = DATASET_YEAR_DICT[year]['filename']
        self.md5 = DATASET_YEAR_DICT[year]['md5']
        self.transform = transform
        self.image_set = image_set
        base_dir = DATASET_YEAR_DICT[year]['base_dir']
        voc_root = os.path.join(self.root, base_dir)
        image_dir = os.path.join(voc_root, 'JPEGImages')
  1. 数据集文件夹的存放路径
  2. voc数据集的年份
  3. 从定义好的字典中取出对应voc数据集年份的下载链接
  4. 从定义好…文件名
  5. 从定义好…md5哈希值
  6. 是否进行数据增强的布尔参数
  7. 数据集是训练集
  8. 从定义好…取出数据对应的路径
  9. 把数据集存放路径 连接上 数据对应的路径
  10. 再在9的路径中后面取出 JPEGImages文件夹,现在image_dir 的路径就是所有训练数据对应的路径
        if download:
            download_extract(self.url, self.root, self.filename, self.md5)
        if not os.path.isdir(voc_root):
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')
  1. 是否需要下载数据集
  2. 需要,download_extract函数调用下载链接、路径、文件名、md5哈希值四个参数执行下载数据操作,download_extract函数:
    def download_extract(url, root, filename, md5):
        download_url(url, root, filename, md5)
        with tarfile.open(os.path.join(root, filename), "r") as tar:
            tar.extractall(path=root)
    
  3. 如果路径不存在,弹出报错信息
        if is_aug and image_set=='train':
            mask_dir = os.path.join(voc_root, 'SegmentationClassAug')
            assert os.path.exists(mask_dir), "SegmentationClassAug not found, please refer to README.md and prepare it manually"
            split_f = os.path.join( self.root, 'train_aug.txt')#'./datasets/data/train_aug.txt'
        else:
            mask_dir = os.path.join(voc_root, 'SegmentationClass')
            splits_dir = os.path.join(voc_root, 'ImageSets/Segmentation')
            split_f = os.path.join(splits_dir, image_set.rstrip('\n') + '.txt')
        if not os.path.exists(split_f):
            raise ValueError(
                'Wrong image_set entered! Please use image_set="train" '
                'or image_set="trainval" or image_set="val"')
  1. 判断是否使用了数据增强处理后的数据,如果是则:
  2. 选择数据增强处理后的数据的文件夹
  3. 检查 mask_dir 是否真的存在,如果不存在,程序将抛出一个异常,这意味着用户需要按照 README.md 的指示手动准备这个目录
  4. 设置变量 split_f 为包含增强训练数据集图像列表的文件的路径
  5. 如果不是,则:
  6. 设置不同的 mask_dirsplit_f,用于标准(非增强)数据集的训练、验证
  7. if not os.path.exists(split_f): 检查 split_f 是否存在,不存在,将抛出一个 ValueError 异常
        with open(os.path.join(split_f), "r") as f:
            file_names = [x.strip() for x in f.readlines()]
        
        self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
        self.masks = [os.path.join(mask_dir, x + ".png") for x in file_names]
        assert (len(self.images) == len(self.masks))
  1. 从JPEGImages文件夹中,读取所有的文件名,保存所有的文件名为一个列表
  2. 按照列表的文件名给每个文件名加上前面的路径和文件后缀名
  3. 分别对数据和标签都进行这个操作
  4. 最后判断一下数据和标签的数量是否一致否则返回异常

4.2 getitem函数

    def __getitem__(self, index):
        img = Image.open(self.images[index]).convert('RGB')
        target = Image.open(self.masks[index])
        if self.transform is not None:
            img, target = self.transform(img, target
        return img, target
  1. 用Pillow的Image包,结合init构造函数返回的所有的图像数据路径,读取打开图像数据
  2. 同样的方法读取打开图像标签数据
  3. 是否进行了图像增强操作
  4. 如果是,则从图像增强方法中得到图像增强处理后的数据
  5. 返回数据

4.3 len函数

    def __len__(self):
        return len(self.images)

返回样本数量

4.4 decode_target函数

    def decode_target(cls, mask):
        """decode semantic mask to RGB image"""
        return cls.cmap[mask]

deeplab系列算法概述
deeplabV3+ VOC分割实战1
deeplabV3+ VOC分割实战2
deeplabV3+ VOC分割实战3
deeplabV3+ VOC分割实战4
deeplabV3+ VOC分割实战5

你可能感兴趣的:(图像分割实战,计算机视觉,人工智能,语义分割,deeplab,深度学习,pytorch)