pytorch-semseg源码解析cityscapes_loader.py

关于[::-1],[:, ::-1],[:, :, ::-1]的区别:

from PIL import Image  
import imageio
import numpy as np
# A 4x4x3 array holding 0..47, so each element's original position is easy to trace.
img = np.arange(48).reshape(4, 4, 3)

# Reverse exactly one axis per view: axis 0, then axis 1, then axis 2.
reversed_views = (
    ("no.1", img[::-1]),
    ("no.2", img[:, ::-1]),
    ("no.3", img[:, :, ::-1]),
)
for tag, view in reversed_views:
    print(tag, view)

输出长这样:

no.1 [[[36 37 38]
  [39 40 41]
  [42 43 44]
  [45 46 47]]

 [[24 25 26]
  [27 28 29]
  [30 31 32]
  [33 34 35]]

 [[12 13 14]
  [15 16 17]
  [18 19 20]
  [21 22 23]]

 [[ 0  1  2]
  [ 3  4  5]
  [ 6  7  8]
  [ 9 10 11]]]
no.2 [[[ 9 10 11]
  [ 6  7  8]
  [ 3  4  5]
  [ 0  1  2]]

 [[21 22 23]
  [18 19 20]
  [15 16 17]
  [12 13 14]]

 [[33 34 35]
  [30 31 32]
  [27 28 29]
  [24 25 26]]

 [[45 46 47]
  [42 43 44]
  [39 40 41]
  [36 37 38]]]
no.3 [[[ 2  1  0]
  [ 5  4  3]
  [ 8  7  6]
  [11 10  9]]

 [[14 13 12]
  [17 16 15]
  [20 19 18]
  [23 22 21]]

 [[26 25 24]
  [29 28 27]
  [32 31 30]
  [35 34 33]]

 [[38 37 36]
  [41 40 39]
  [44 43 42]
  [47 46 45]]]

可以看出,[::-1]是将4*4*3数组的第0维(即第一个4)逆序读取:体现在输出里,就是4个4*3小矩阵的整体排列顺序被倒置了,即第一个4*3矩阵被排到了最后,而每个小矩阵内部保持不变;

[:, ::-1]是将第1维(即第二个4)逆序读取:体现在输出里,就是每一个4*3小矩阵内部4行的顺序被倒置了(每一行内3个数的顺序不变),而各个4*3矩阵之间的顺序没有变化;

[:, :, ::-1]是将第2维(即最后的3)逆序读取:体现在输出里,就是每一个4*3矩阵中每一行的3个数顺序被倒置了,而各行之间以及各个4*3矩阵之间的顺序都保持不变。

import os
import torch
import numpy as np
import scipy.misc as m

from torch.utils import data

from ptsemseg.utils import recursive_glob
from ptsemseg.augmentations import Compose, RandomHorizontallyFlip, RandomRotate, Scale


class cityscapesLoader(data.Dataset):
    """Cityscapes dataset for semantic segmentation.

    Every ``.png`` found under ``<root>/leftImg8bit/<split>`` is paired with
    the ``*_gtFine_labelIds.png`` file of the same city folder under
    ``<root>/gtFine/<split>``. Raw label ids are re-encoded so that the 19
    ``valid_classes`` become ``0..18`` and every ``void_classes`` id becomes
    ``ignore_index``.

    NOTE(review): image reading/resizing goes through ``m.imread`` and
    ``m.imresize`` (``scipy.misc``, imported at file level). Recent SciPy
    releases no longer ship these functions - confirm the pinned SciPy
    version before relying on this loader.

    https://www.cityscapes-dataset.com
    Data is derived from CityScapes, and can be downloaded from here:
    https://www.cityscapes-dataset.com/downloads/
    Many Thanks to @fvisin for the loader repo:
    https://github.com/fvisin/dataset_loaders/blob/master/dataset_loaders/images/cityscapes.py
    """

    # One RGB colour per training id, in training-id order.
    colors = [  # [  0,   0,   0],
        [128, 64, 128],
        [244, 35, 232],
        [70, 70, 70],
        [102, 102, 156],
        [190, 153, 153],
        [153, 153, 153],
        [250, 170, 30],
        [220, 220, 0],
        [107, 142, 35],
        [152, 251, 152],
        [0, 130, 180],
        [220, 20, 60],
        [255, 0, 0],
        [0, 0, 142],
        [0, 0, 70],
        [0, 60, 100],
        [0, 80, 100],
        [0, 0, 230],
        [119, 11, 32],
    ]

    # training id (0..18) -> RGB colour used when rendering a mask
    label_colours = dict(zip(range(19), colors))

    mean_rgb = {
        "pascal": [103.939, 116.779, 123.68],
        "cityscapes": [0.0, 0.0, 0.0],
    }  # pascal mean for PSPNet and ICNet pre-trained model

    def __init__(
        self,
        root,
        split="train",
        is_transform=False,
        img_size=(512, 1024),
        augmentations=None,
        img_norm=True,
        version="cityscapes",
        test_mode=False,
    ):
        """Index the image files of one split.

        :param root: dataset root containing ``leftImg8bit`` and ``gtFine``.
        :param split: sub-folder to use (e.g. train / val / test).
        :param is_transform: when true, ``__getitem__`` applies ``transform``.
        :param img_size: target ``(height, width)``; a tuple or list is used
            as given, a single number is used for both dimensions.
        :param augmentations: optional callable ``(img, lbl) -> (img, lbl)``
            applied before ``transform``.
        :param img_norm: when true, ``transform`` divides the image by 255.
        :param version: key into ``mean_rgb`` selecting the mean to subtract.
        :param test_mode: accepted for interface compatibility; this class
            does not read it.
        :raises Exception: when no ``.png`` image is found for the split.
        """
        self.root = root
        self.split = split
        self.is_transform = is_transform
        self.augmentations = augmentations
        self.img_norm = img_norm
        self.n_classes = 19
        # Accept lists as well as tuples (sizes read from a config file are
        # usually lists); a bare number means a square target size.
        if isinstance(img_size, (tuple, list)):
            self.img_size = tuple(img_size)
        else:
            self.img_size = (img_size, img_size)
        self.mean = np.array(self.mean_rgb[version])
        self.files = {}

        # Image folder and ground-truth folder of the selected split.
        self.images_base = os.path.join(self.root, "leftImg8bit", self.split)
        self.annotations_base = os.path.join(self.root, "gtFine", self.split)

        # Recursively collect the images of this split.
        self.files[split] = recursive_glob(rootdir=self.images_base, suffix=".png")

        # Raw label ids that are not trained on.
        self.void_classes = [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1]
        # Raw label ids that are trained on; list position == training id.
        self.valid_classes = [
            7,
            8,
            11,
            12,
            13,
            17,
            19,
            20,
            21,
            22,
            23,
            24,
            25,
            26,
            27,
            28,
            31,
            32,
            33,
        ]
        self.class_names = [
            "unlabelled",
            "road",
            "sidewalk",
            "building",
            "wall",
            "fence",
            "pole",
            "traffic_light",
            "traffic_sign",
            "vegetation",
            "terrain",
            "sky",
            "person",
            "rider",
            "car",
            "truck",
            "bus",
            "train",
            "motorcycle",
            "bicycle",
        ]

        self.ignore_index = 250
        # raw label id -> training id (0..18)
        self.class_map = dict(zip(self.valid_classes, range(19)))

        if not self.files[split]:
            raise Exception("No files for split=[%s] found in %s" % (split, self.images_base))

        print("Found %d %s images" % (len(self.files[split]), split))

    def __len__(self):
        """Return the number of images indexed for the current split."""
        return len(self.files[self.split])

    def __getitem__(self, index):
        """Load one ``(image, label)`` pair.

        :param index: position in the file list of the current split.
        :returns: ``(img, lbl)`` after label encoding, optional
            augmentation and, if ``is_transform`` is set, ``transform``.
        """
        img_path = self.files[self.split][index].rstrip()
        # Same city folder; the trailing 15 characters of the file name
        # ("leftImg8bit.png") are swapped for the label-id suffix.
        lbl_path = os.path.join(
            self.annotations_base,
            img_path.split(os.sep)[-2],
            os.path.basename(img_path)[:-15] + "gtFine_labelIds.png",
        )

        img = m.imread(img_path)
        img = np.array(img, dtype=np.uint8)

        lbl = m.imread(lbl_path)
        lbl = self.encode_segmap(np.array(lbl, dtype=np.uint8))

        if self.augmentations is not None:
            img, lbl = self.augmentations(img, lbl)

        if self.is_transform:
            img, lbl = self.transform(img, lbl)

        return img, lbl

    def transform(self, img, lbl):
        """Resize, normalise and convert an image/label pair to tensors.

        :param img: image array; indexed as height x width x channel here.
        :param lbl: encoded label map for ``img``.
        :returns: ``(img, lbl)`` as a float tensor in channel-first layout
            and a long tensor, both resized to ``self.img_size``.
        :raises ValueError: if the resized label map holds a value that is
            neither ``ignore_index`` nor below ``n_classes``.
        """
        img = m.imresize(img, (self.img_size[0], self.img_size[1]))  # uint8 with RGB mode
        img = img[:, :, ::-1]  # RGB -> BGR: read the channel axis in reverse
        img = img.astype(np.float64)
        img -= self.mean
        if self.img_norm:
            # Resize scales images from 0 to 255, thus we need
            # to divide by 255.0
            img = img.astype(float) / 255.0
        # HWC -> CHW
        img = img.transpose(2, 0, 1)

        classes = np.unique(lbl)
        lbl = lbl.astype(float)
        # Resize the label map to the image size with nearest-neighbour
        # interpolation so that no new label values are invented.
        lbl = m.imresize(lbl, (self.img_size[0], self.img_size[1]), "nearest", mode="F")
        lbl = lbl.astype(int)

        # Compare the set of labels before and after resizing.
        # np.array_equal copes with arrays of different length, whereas an
        # elementwise ``==`` raises on shape mismatch in recent NumPy.
        if not np.array_equal(classes, np.unique(lbl)):
            print("WARN: resizing labels yielded fewer classes")

        if not np.all(np.unique(lbl[lbl != self.ignore_index]) < self.n_classes):
            print("after det", classes, np.unique(lbl))
            raise ValueError("Segmentation map contained invalid class values")

        # from_numpy wraps the array; .float()/.long() then fix the dtype.
        img = torch.from_numpy(img).float()
        lbl = torch.from_numpy(lbl).long()

        return img, lbl

    def decode_segmap(self, temp):
        """Render a 2-D map of training ids as an RGB image.

        :param temp: 2-D array of training ids.
        :returns: float array of shape ``temp.shape + (3,)``; pixels of
            class ``c`` hold ``label_colours[c] / 255``. Pixels whose value
            is not in ``range(n_classes)`` keep their raw value divided by
            255 in all three channels.
        """
        # One working copy per colour channel (R, G, B).
        channels = [temp.copy() for _ in range(3)]
        for class_id in range(self.n_classes):
            hit = temp == class_id
            for channel, value in zip(channels, self.label_colours[class_id]):
                channel[hit] = value

        rgb = np.zeros((temp.shape[0], temp.shape[1], 3))
        for idx, channel in enumerate(channels):
            rgb[:, :, idx] = channel / 255.0
        return rgb

    def encode_segmap(self, mask):
        """Convert raw label ids to training ids.

        ``mask`` is modified in place and also returned.

        :param mask: array of raw label ids.
        :returns: the same array with every id in ``void_classes`` replaced
            by ``ignore_index`` and every id in ``valid_classes`` replaced
            by its training id from ``class_map``.
        """
        # Void classes are sent to ignore_index (not to zero).
        for _voidc in self.void_classes:
            mask[mask == _voidc] = self.ignore_index
        for _validc in self.valid_classes:
            mask[mask == _validc] = self.class_map[_validc]
        return mask


if __name__ == "__main__":
    # Visual check: show each batch as image / colour-coded label pairs.
    # Press Enter in the console for the next batch, type "ex" to stop.
    import matplotlib.pyplot as plt

    augmentations = Compose([Scale(2048), RandomRotate(10), RandomHorizontallyFlip(0.5)])

    local_path = "/datasets01/cityscapes/112817/"
    dst = cityscapesLoader(local_path, is_transform=True, augmentations=augmentations)
    bs = 4
    trainloader = data.DataLoader(dst, batch_size=bs, num_workers=0)
    for imgs, labels in trainloader:
        # The loader yields (image batch, label batch).
        imgs = imgs.numpy()[:, ::-1, :, :]  # NCHW: flip channels, BGR -> RGB
        imgs = np.transpose(imgs, [0, 2, 3, 1])  # NCHW -> NHWC for imshow
        labels = labels.numpy()

        # The final batch may hold fewer than `bs` samples, so size the grid
        # from the batch itself; squeeze=False keeps axarr 2-D even for one row.
        n_shown = imgs.shape[0]
        f, axarr = plt.subplots(n_shown, 2, squeeze=False)
        for j in range(n_shown):
            axarr[j][0].imshow(imgs[j])
            axarr[j][1].imshow(dst.decode_segmap(labels[j]))
        plt.show()
        a = input()
        if a == "ex":
            break
        plt.close()

 

你可能感兴趣的:(图像分割代码,深度学习,python,pytorch,机器学习)