PyTorch Training Data and Testing, Complete Code (9) --- DeepLab v3+: Processing the Cityscapes Dataset

Below is the complete code:

import os
import torch
import numpy as np
from PIL import Image
from torch.utils import data
from dataloaders.utils import recursive_glob, decode_segmap
from mypath import Path


class CityscapesSegmentation(data.Dataset):

    def __init__(self, root=Path.db_root_dir('cityscapes'), split="train", transform=None):

        self.root = root
        self.split = split
        self.transform = transform
        self.files = {}
        self.n_classes = 19

        self.images_base = os.path.join(self.root, 'leftImg8bit', self.split)
        self.annotations_base = os.path.join(self.root, 'gtFine', self.split)

        self.files[split] = recursive_glob(rootdir=self.images_base, suffix='.png')

        self.void_classes = [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1]  # 16
        self.valid_classes = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33]  # 19
        self.class_names = ['unlabelled', 'road', 'sidewalk', 'building', 'wall', 'fence', \
                            'pole', 'traffic_light', 'traffic_sign', 'vegetation', 'terrain', \
                            'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train', \
                            'motorcycle', 'bicycle']  # 20

        self.ignore_index = 255
        self.class_map = dict(zip(self.valid_classes, range(self.n_classes)))

        if not self.files[split]:
            raise Exception("No files for split=[%s] found in %s" % (split, self.images_base))

        print("Found %d %s images" % (len(self.files[split]), split))

    def __len__(self):
        return len(self.files[self.split])

    def __getitem__(self, index):

        img_path = self.files[self.split][index].rstrip()
        lbl_path = os.path.join(self.annotations_base,
                                img_path.split(os.sep)[-2],  # second-to-last path component is the city name
                                os.path.basename(img_path)[:-15] + 'gtFine_labelIds.png')  # swap the 15-char 'leftImg8bit.png' suffix for the label suffix

        _img = Image.open(img_path).convert('RGB')
        _tmp = np.array(Image.open(lbl_path), dtype=np.uint8)
        _tmp = self.encode_segmap(_tmp)
        _target = Image.fromarray(_tmp)

        sample = {'image': _img, 'label': _target}

        if self.transform:  # apply data transformation / augmentation and convert to torch tensors
            sample = self.transform(sample)
        return sample

    def encode_segmap(self, mask):  # remap original label ids to train ids 0-18, with 255 as ignore
        # Map all void classes to the ignore index first
        for _voidc in self.void_classes:
            mask[mask == _voidc] = self.ignore_index  # unused classes are set to 255 (ignored in the loss)
        for _validc in self.valid_classes:
            mask[mask == _validc] = self.class_map[_validc]  # the 19 valid classes are encoded as 0 to 18
        return mask


if __name__ == '__main__':
    from dataloaders import custom_transforms as tr
    from dataloaders.utils import decode_segmap
    from torch.utils.data import DataLoader
    from torchvision import transforms
    import matplotlib.pyplot as plt  # to show image

    composed_transforms_tr = transforms.Compose([
        tr.RandomHorizontalFlip(),
        tr.RandomScale((0.5, 0.75)),
        tr.RandomCrop((512, 1024)),
        tr.RandomRotate(5),
        tr.ToTensor()])

    cityscapes_train = CityscapesSegmentation(split='train',
                                              transform=composed_transforms_tr)
    dataloader = DataLoader(cityscapes_train, batch_size=2, shuffle=True, num_workers=2)

    for ii, sample in enumerate(dataloader):
        for jj in range(sample["image"].size()[0]):
            img = sample['image'].numpy()  # from torch tensor to numpy, n x c x h x w
            gt = sample['label'].numpy()  # from torch tensor to numpy, n x 1 x h x w
            tmp = np.array(gt[jj]).astype(np.uint8)  # tmp.shape = 1 x h x w
            tmp = np.squeeze(tmp, axis=0)  # drop the channel axis: tmp.shape = h x w
            segmap = decode_segmap(tmp, dataset='cityscapes')
            img_tmp = np.transpose(img[jj], axes=[1, 2, 0]).astype(np.uint8)  # img_tmp=h x w x c
            plt.figure()
            plt.suptitle('display')
            plt.subplot(211)
            plt.imshow(img_tmp)
            plt.subplot(212)
            plt.imshow(segmap)

        if ii == 1:
            break
    plt.show(block=True)
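To make the label remapping concrete, here is a small standalone sketch of what encode_segmap does, using a made-up 2 x 3 patch of raw gtFine label ids (the values are arbitrary examples, not taken from a real label file):

import numpy as np

void_classes = [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1]
valid_classes = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33]
class_map = dict(zip(valid_classes, range(19)))

raw = np.array([[7, 8, 0],
                [26, 33, 14]], dtype=np.uint8)  # made-up raw label ids
mask = raw.copy()
for v in void_classes:
    mask[mask == v] = 255            # void classes -> ignore index 255
for v in valid_classes:
    mask[mask == v] = class_map[v]   # raw ids -> train ids 0-18
print(mask)
# [[  0   1 255]
#  [ 13  18 255]]

Note that the void pass runs first, and this order matters: the void ids overlap with the target train ids 0-18, so clearing them to 255 first prevents collisions when the valid pass writes the new train ids into the same array.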

For how the images are read below, you can refer to: https://blog.csdn.net/zz2230633069/article/details/84640867

self.files[split] = recursive_glob(rootdir=self.images_base, suffix='.png')
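recursive_glob itself is not shown in this file; here is a minimal sketch of what it presumably does, assuming it simply walks the directory tree and collects every file whose name ends with the given suffix (the name and signature come from the dataloaders.utils import above):

import os

def recursive_glob(rootdir='.', suffix=''):
    """Recursively collect all file paths under rootdir ending with suffix."""
    return [os.path.join(looproot, filename)
            for looproot, _, filenames in os.walk(rootdir)
            for filename in filenames
            if filename.endswith(suffix)]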

The transform pipeline is:

composed_transforms_tr = transforms.Compose([
        tr.RandomHorizontalFlip(),
        tr.RandomScale((0.5, 0.75)),
        tr.RandomCrop((512, 1024)),
        tr.RandomRotate(5),
        tr.ToTensor()])

The implementation code for the image transforms (or, put differently, the augmentations) above is as follows:

The first four transforms keep both the image and the label as PIL images (loaded as PIL.PngImagePlugin.PngImageFile): the pixel values and their type (uint8) do not change, and neither does the layout (the image is h x w x 3, the label is h x w); a quick sanity check follows the RandomRotate class below.

import random
import numbers

import numpy as np
import torch
from PIL import Image, ImageOps


class RandomHorizontalFlip(object):
    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']
        if random.random() < 0.5:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
            mask = mask.transpose(Image.FLIP_LEFT_RIGHT)

        return {'image': img,
                'label': mask}


class RandomScale(object):
    def __init__(self, limit):
        self.limit = limit

    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']
        assert img.size == mask.size

        scale = random.uniform(self.limit[0], self.limit[1])
        w = int(scale * img.size[0])
        h = int(scale * img.size[1])

        img, mask = img.resize((w, h), Image.BILINEAR), mask.resize((w, h), Image.NEAREST)

        return {'image': img, 'label': mask}


class RandomCrop(object):
    def __init__(self, size, padding=0):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size  # (h, w)
        self.padding = padding

    def __call__(self, sample):
        img, mask = sample['image'], sample['label']

        if self.padding > 0:
            img = ImageOps.expand(img, border=self.padding, fill=0)
            mask = ImageOps.expand(mask, border=self.padding, fill=0)  # note: pads the mask with 0, not the ignore index

        assert img.size == mask.size
        w, h = img.size
        th, tw = self.size  # target height and width
        if w == tw and h == th:
            return {'image': img,
                    'label': mask}
        if w < tw or h < th:
            # image smaller than the crop window: resize up to the target size instead of cropping
            img = img.resize((tw, th), Image.BILINEAR)
            mask = mask.resize((tw, th), Image.NEAREST)
            return {'image': img,
                    'label': mask}

        x1 = random.randint(0, w - tw)
        y1 = random.randint(0, h - th)
        img = img.crop((x1, y1, x1 + tw, y1 + th))
        mask = mask.crop((x1, y1, x1 + tw, y1 + th))

        return {'image': img,
                'label': mask}


class RandomRotate(object):
    def __init__(self, degree):
        self.degree = degree

    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']
        rotate_degree = random.random() * 2 * self.degree - self.degree
        img = img.rotate(rotate_degree, Image.BILINEAR)
        mask = mask.rotate(rotate_degree, Image.NEAREST)

        return {'image': img,
                'label': mask}
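As claimed above, the first four transforms hand PIL images straight through. A quick sanity check, chaining them on a hypothetical image/label pair (the file names below are placeholders, not real files):

from PIL import Image

img = Image.open('aachen_000000_000019_leftImg8bit.png')       # placeholder paths
mask = Image.open('aachen_000000_000019_gtFine_labelIds.png')
sample = {'image': img, 'label': mask}
for t in [RandomHorizontalFlip(), RandomScale((0.5, 0.75)),
          RandomCrop((512, 1024)), RandomRotate(5)]:
    sample = t(sample)
    print(type(sample['image']), sample['image'].size)  # still PIL images throughout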

class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, sample):
        # swap color axis because
        # numpy image: H x W x C
        # torch image: C X H X W
        img = np.array(sample['image']).astype(np.float32).transpose((2, 0, 1))
        mask = np.expand_dims(np.array(sample['label']).astype(np.float32), -1).transpose((2, 0, 1))
        mask[mask == 255] = 0  # note: ignore pixels (255) are folded into class 0 here

        img = torch.from_numpy(img).float()
        mask = torch.from_numpy(mask).float()

        return {'image': img,
                'label': mask}

At the fifth and final transform (ToTensor), the image is first converted from PIL.PngImagePlugin.PngImageFile to a numpy array, with the data type changing from uint8 to float32; the dimensions are then rearranged from (h x w x c) to (c x h x w); finally the numpy array is converted to a torch tensor, with the data type forced to torch.FloatTensor. The image is now a tensor that can be fed into the downstream deep network.

Correspondingly, the label is also converted from PIL.PngImagePlugin.PngImageFile to a numpy array, with the data type changing from uint8 to float32; a dimension is then added, going from (h x w) to (h x w x 1), which is transposed to (1 x h x w). The values in the mask are then processed: everything equal to 255 is reset to 0, so the mask now contains only the numbers 0-18. Finally the numpy array is converted to a torch tensor, again forced to torch.FloatTensor. The label is now a tensor that can be consumed by the network.
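A small standalone sketch of these shape and dtype changes on a dummy label (the 4 x 4 array below is made up for illustration):

import numpy as np
import torch

mask = np.array([[0, 5, 255, 18]] * 4, dtype=np.uint8)  # dummy h x w label with an ignore pixel
mask = np.expand_dims(mask.astype(np.float32), -1)      # h x w -> h x w x 1
mask = mask.transpose((2, 0, 1))                        # h x w x 1 -> 1 x h x w
mask[mask == 255] = 0                                   # ignore index folded into class 0
mask = torch.from_numpy(mask).float()
print(mask.shape, mask.dtype)                           # torch.Size([1, 4, 4]) torch.float32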

The code that turns the two tensors above back into images for display is as follows:

    for ii, sample in enumerate(dataloader):
        for jj in range(sample["image"].size()[0]):
            img = sample['image'].numpy()  # from torch convert to numpy n x 3 x h x w
            gt = sample['label'].numpy()  # from torch convert to numpy n x 1 x h x w
            tmp = np.array(gt[jj]).astype(np.uint8)  # tmp.shape = 1 x h x w
            tmp = np.squeeze(tmp, axis=0)  # drop the channel axis: tmp.shape = h x w
            segmap = decode_segmap(tmp, dataset='cityscapes') 
            img_tmp = np.transpose(img[jj], axes=[1, 2, 0]).astype(np.uint8)  # img_tmp=h x w x 3
            plt.figure()
            plt.suptitle('display')
            plt.subplot(211)
            plt.imshow(img_tmp)
            plt.subplot(212)
            plt.imshow(segmap)

        if ii == 1:
            break
    plt.show(block=True)

The code that decodes the label map (h x w) inside the loop is as follows:

Pixels of the same class are given that class's RGB value, and the three channel maps are then merged into a single color image.

segmap = decode_segmap(tmp, dataset='cityscapes')  # tmp.shape=h x w
import numpy as np
import matplotlib.pyplot as plt


def decode_segmap(label_mask, dataset, plot=False):
    """Decode segmentation class labels into a color image
    Args:
        label_mask (np.ndarray): an (M,N) array of integer values denoting
          the class label at each spatial location.
        dataset (str): which palette to use, 'pascal' or 'cityscapes'.
        plot (bool, optional): whether to show the resulting color image
          in a figure.
    Returns:
        (np.ndarray, optional): the resulting decoded color image.
    """
    if dataset == 'pascal':
        n_classes = 21
        label_colours = get_pascal_labels()
    elif dataset == 'cityscapes':
        n_classes = 19
        label_colours = get_cityscapes_labels()
    else:
        raise NotImplementedError

    r = label_mask.copy()  # h x w
    g = label_mask.copy()  # h x w
    b = label_mask.copy()  # h x w
    for ll in range(0, n_classes):
        r[label_mask == ll] = label_colours[ll, 0]
        g[label_mask == ll] = label_colours[ll, 1]
        b[label_mask == ll] = label_colours[ll, 2]
    rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3))  # initialize an h x w x 3 array
    rgb[:, :, 0] = r / 255.0
    rgb[:, :, 1] = g / 255.0
    rgb[:, :, 2] = b / 255.0
    if plot:
        plt.imshow(rgb)
        plt.show()
    else:
        return rgb

Below is the label_colours code mapping classes to colors; for details, see the Cityscapes label-color reference table: https://blog.csdn.net/zz2230633069/article/details/84591532

def get_cityscapes_labels():
    return np.array([
        # [  0,   0,   0],
        [128, 64, 128],
        [244, 35, 232],
        [70, 70, 70],
        [102, 102, 156],
        [190, 153, 153],
        [153, 153, 153],
        [250, 170, 30],
        [220, 220, 0],
        [107, 142, 35],
        [152, 251, 152],
        [0, 130, 180],
        [220, 20, 60],
        [255, 0, 0],
        [0, 0, 142],
        [0, 0, 70],
        [0, 60, 100],
        [0, 80, 100],
        [0, 0, 230],
        [119, 11, 32]])


def get_pascal_labels():
    """Load the mapping that associates pascal classes with label colors
    Returns:
        np.ndarray with dimensions (21, 3)
    """
    return np.asarray([[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
                       [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
                       [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
                       [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
                       [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
                       [0, 64, 128]])
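A quick end-to-end check of the decoding step on a tiny made-up mask (the 2 x 2 train ids below are arbitrary):

import numpy as np
import matplotlib.pyplot as plt

tiny_mask = np.array([[0, 1],
                      [13, 255]], dtype=np.uint8)  # arbitrary train ids; unmatched 255 stays 255 -> white after scaling
rgb = decode_segmap(tiny_mask, dataset='cityscapes')
print(rgb.shape)  # (2, 2, 3), float values in [0, 1]
plt.imshow(rgb)
plt.show()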

