图像分割UNet (3) : 自定义数据集读取

DRIVE数据集

下载地址

  • 官网地址: https://drive.grand-challenge.org/,官网需要注册才可下载
  • 百度云链接: https://pan.baidu.com/s/1Tjkrx2B9FgoJk0KviA-rDw 密码: 8no8

数据集结构

  ── training:              训练数据集
  		  ├──1st_manual     标注图片
  		  ├──images         原始图片
  		  ├──mask           感兴趣区域
  		  
  ── test:                  测试数据集
   		  ├──1st_manual     标注图片
   		  ├──2st_manual     标注图片
  		  ├──images         原始图片
  		  ├──mask           感兴趣区域
  • mask 感兴趣区域,其中感兴趣区域像素值为255,不感兴趣区域像素值为0
  • manual 是人手工标注好的Ground Truth,将需要分割的血管都标注出来了,对我们需要检测的目标用白色显示,对目标用黑色显示

自定义数据集

自定义数据集,首先需要继承torch.utils.data下的Dataset,并且重写__init__ ,__getitem__,以及__len__这3个方法

完整代码如下:

import os
from PIL import Image
import numpy as np
from torch.utils.data import Dataset


class DriveDataset(Dataset):
    def __init__(self, root: str, train: bool, transforms=None):
        super(DriveDataset, self).__init__()
        self.flag = "training" if train else "test"
        data_root = os.path.join(root, "DRIVE", self.flag)
        assert os.path.exists(data_root), f"path '{data_root}' does not exists."
        self.transforms = transforms
        img_names = [i for i in os.listdir(os.path.join(data_root, "images")) if i.endswith(".tif")]
        self.img_list = [os.path.join(data_root, "images", i) for i in img_names]
        self.manual = [os.path.join(data_root, "1st_manual", i.split("_")[0] + "_manual1.gif")
                       for i in img_names]
        # check files
        for i in self.manual:
            if os.path.exists(i) is False:
                raise FileNotFoundError(f"file {i} does not exists.")

        self.roi_mask = [os.path.join(data_root, "mask", i.split("_")[0] + f"_{self.flag}_mask.gif")
                         for i in img_names]
        # check files
        for i in self.roi_mask:
            if os.path.exists(i) is False:
                raise FileNotFoundError(f"file {i} does not exists.")

    def __getitem__(self, idx):
        img = Image.open(self.img_list[idx]).convert('RGB')
        manual = Image.open(self.manual[idx]).convert('L')
        manual = np.array(manual) / 255
        roi_mask = Image.open(self.roi_mask[idx]).convert('L')
        roi_mask = 255 - np.array(roi_mask)
        mask = np.clip(manual + roi_mask, a_min=0, a_max=255)

        # 这里转回PIL的原因是,transforms中是对PIL数据进行处理
        mask = Image.fromarray(mask)

        if self.transforms is not None:
            img, mask = self.transforms(img, mask)

        return img, mask

    def __len__(self):
        return len(self.img_list)

    @staticmethod
    def collate_fn(batch):
        images, targets = list(zip(*batch))
        batched_imgs = cat_list(images, fill_value=0)
        batched_targets = cat_list(targets, fill_value=255)
        return batched_imgs, batched_target
        
 def cat_list(images, fill_value=0):
    max_size = tuple(max(s) for s in zip(*[img.shape for img in images]))
    batch_shape = (len(images),) + max_size
    batched_imgs = images[0].new(*batch_shape).fill_(fill_value)
    for img, pad_img in zip(images, batched_imgs):
        pad_img[..., :img.shape[-2], :img.shape[-1]].copy_(img)
    return batched_imgs
  • init 方法中的参数root 是数据集Drive所在的根目录;
    train 是个bool数据,如果为true的话载入training下的训练数据,否则载入test文件夹下的测试数据; transforms表示针对数据集定义的数据预处理方式
  • getitem方法,需要传入id参数,并且返回对应的数据img和标签mask ,这里的mask,和数据文件中mask是不一样的,这里的maskGround Truth .
  • 通过np.array(manual) / 255将目标设为1,背景设为0 . 在语义分割任务中, 目标的像素值从1开始标记,背景设为0
  • 通过roi_mask = 255 - np.array(roi_mask),感兴趣的区域为0,不敢兴趣的像素值为255. 这样做目的是在我们构建最终的mask的时候,让不感兴趣的区域变为255. mask = np.clip(manual + roi_mask, a_min=0, a_max=255) ,这样在计算损失的时候将像素值为255的区域直接忽略掉.

你可能感兴趣的:(图像分割,深度学习,python,计算机视觉)