pytorch-deeplab-xception project notes: confusion matrix and pre-softmax class settings, center cropping during data loading, and saving segmentation maps via TensorboardSummary

On choosing the number of classes for training and testing: first, the number of output channels the network produces before the softmax:

# Define model; num_classes sets the channel count of the final classifier head
model = DeepLab(num_classes=self.nclass,
                backbone=args.backbone,
                output_stride=args.out_stride,
                sync_bn=False,
                freeze_bn=False)

Next, the number of classes used by the confusion matrix:

Both are driven by the same value, which is set in pytorch-deeplab-xception/dataloaders/__init__.py:

def make_data_loader(args, **kwargs):

    if args.dataset == 'pascal':
        train_set = pascal.VOCSegmentation(args, split='train')
        val_set = pascal.VOCSegmentation(args, split='val')
        if args.use_sbd:
            sbd_train = sbd.SBDSegmentation(args, split=['train', 'val'])
            train_set = combine_dbs.CombineDBs([train_set, sbd_train], excluded=[val_set])

        num_class = train_set.NUM_CLASSES
        #num_class = 2
        train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs)
        val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, **kwargs)
        test_loader = None

        return train_loader, val_loader, test_loader, num_class

    elif args.dataset == 'cityscapes':
        train_set = cityscapes.CityscapesSegmentation(args, split='train')
        val_set = cityscapes.CityscapesSegmentation(args, split='val')
        test_set = cityscapes.CityscapesSegmentation(args, split='test')
        num_class = train_set.NUM_CLASSES
        train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs)
        val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, **kwargs)
        test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False, **kwargs)

        return train_loader, val_loader, test_loader, num_class

    elif args.dataset == 'coco':
        train_set = coco.COCOSegmentation(args, split='train')
        val_set = coco.COCOSegmentation(args, split='val')
        num_class = train_set.NUM_CLASSES
        train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs)
        val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, **kwargs)
        test_loader = None
        return train_loader, val_loader, test_loader, num_class

    else:
        raise NotImplementedError

Here, num_class = train_set.NUM_CLASSES determines the number of classes used for the confusion matrix (and, returned as self.nclass, for the model head).
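Concretely, train.py passes this value to the Evaluator, which allocates the confusion matrix from it (excerpts as in the repo, to the best of my reading):

# train.py (Trainer.__init__)
self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)
self.evaluator = Evaluator(self.nclass)

# utils/metrics.py
class Evaluator(object):
    def __init__(self, num_class):
        self.num_class = num_class
        self.confusion_matrix = np.zeros((self.num_class,) * 2)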

If you train on your own categories, you must change this to the number of foreground classes plus one for the background:

num_class = (num_foreground_classes + 1)

My own dataset trained fine with the setup below, but when I trained with my own COCO-format data I ran into this error:

RuntimeError: CUDA error: device-side assert triggered

When PyTorch raises this error, it means some labels fall outside [0, num_classes), a left-closed, right-open interval. For example, with num_class = 3, any label of -1, 3, 4, 5, and so on will trigger it.

After digging in, I found labels as high as 15 in my masks, even though I had set num_class = 2.
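You can catch bad masks on the CPU before the CUDA assert ever fires. A minimal sketch (check_mask_labels is my own hypothetical helper, not part of the repo):

import numpy as np
from PIL import Image

def check_mask_labels(mask_path, num_class):
    """Verify every label in a mask image lies in [0, num_class)."""
    labels = np.unique(np.array(Image.open(mask_path)))
    bad = labels[(labels < 0) | (labels >= num_class)]
    if bad.size:
        print(mask_path, "has out-of-range labels:", bad.tolist())
    return bad.size == 0

Running this over the generated masks would have flagged the stray label 15 immediately.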

It turned out that in pytorch-deeplab-xception/dataloaders/datasets/coco.py, the author did not want to segment all 80 COCO categories, only the same 21 classes as PASCAL VOC, so he set up an index mapping:

CAT_LIST = [0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4,
     1, 64, 20, 63, 7, 72]
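The list reads as CAT_LIST[voc_index] == coco_category_id, so _gen_seg_mask below writes PASCAL-style class indices into the mask. A small check of this reading:

CAT_LIST = [0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4,
            1, 64, 20, 63, 7, 72]

# COCO category_id 1 is "person", which is PASCAL VOC class 15:
assert CAT_LIST.index(1) == 15
# COCO category_id 5 is "airplane" (VOC "aeroplane", class 1):
assert CAT_LIST.index(5) == 1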

and a function that skips any COCO annotation whose category is not in the list:

    def _gen_seg_mask(self, target, h, w):
        mask = np.zeros((h, w), dtype=np.uint8)
        coco_mask = self.coco_mask
        for instance in target:
            rle = coco_mask.frPyObjects(instance['segmentation'], h, w)
            m = coco_mask.decode(rle)
            cat = instance['category_id']
            if cat in self.CAT_LIST:
                c = self.CAT_LIST.index(cat)  # VOC-style class index
            else:
                continue  # category not in the 21-class subset: skip it
            if len(m.shape) < 3:
                mask[:, :] += (mask == 0) * (m * c)
            else:
                mask[:, :] += (mask == 0) * (((np.sum(m, axis=2)) > 0) * c).astype(np.uint8)
        return mask


VOC:

(image: the 21 PASCAL VOC classes: background, aeroplane, bicycle, bird, boat, bottle, bus, car, cat, chair, cow, diningtable, dog, horse, motorbike, person, pottedplant, sheep, sofa, train, tvmonitor)

COCO:

person bicycle car motorbike aeroplane bus train
truck boat traffic_light fire_hydrant stop_sign parking_meter bench
bird cat dog horse sheep cow elephant
bear zebra giraffe backpack umbrella handbag tie
suitcase frisbee skis snowboard sports_ball kite baseball_bat 
baseball_glove skateboard surfboard tennis_racket bottle wine_glass cup
fork knife spoon bowl banana apple sandwich
orange broccoli carrot hot_dog pizza donut cake
chair sofa pottedplant bed diningtable toilet tvmonitor 
laptop mouse remote keyboard cell_phone microwave oven 
toaster sink refrigerator book clock vase scissors 
teddy_bear hair_drier toothbrush


If you use your own dataset, this must be adapted. Mine has a single foreground class, so:

CAT_LIST = [0, 1]

(NUM_CLASSES in the same coco.py then has to match: 2, for background plus one class.)

Another topic: saving segmentation results to TensorBoard:

pytorch-deeplab-xception/train.py 

from utils.summaries import TensorboardSummary

# In Trainer.__init__:
self.summary = TensorboardSummary(self.saver.experiment_dir)
self.writer = self.summary.create_summary()

# In Trainer.training, inside the batch loop:
self.writer.add_scalar('train/total_loss_iter', loss.item(), i + num_img_tr * epoch)

# Show 10 * 3 inference results each epoch
if i % (num_img_tr // 10) == 0:
    global_step = i + num_img_tr * epoch
    self.summary.visualize_image(self.writer, self.args.dataset, image, target, output, global_step)

# After the loop:
self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)


# In Trainer.validation:
self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
self.writer.add_scalar('val/mIoU', mIoU, epoch)
self.writer.add_scalar('val/Acc', Acc, epoch)
self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
self.writer.add_scalar('val/fwIoU', FWIoU, epoch)

This is implemented in pytorch-deeplab-xception/utils/summaries.py:

import os
import torch
from torchvision.utils import make_grid
from tensorboardX import SummaryWriter
from dataloaders.utils import decode_seg_map_sequence

class TensorboardSummary(object):
    def __init__(self, directory):
        self.directory = directory

    def create_summary(self):
        writer = SummaryWriter(log_dir=os.path.join(self.directory))
        return writer

    def visualize_image(self, writer, dataset, image, target, output, global_step):
        grid_image = make_grid(image[:3].clone().cpu().data, 3, normalize=True)
        writer.add_image('Image', grid_image, global_step)
        grid_image = make_grid(decode_seg_map_sequence(torch.max(output[:3], 1)[1].detach().cpu().numpy(),
                                                       dataset=dataset), 3, normalize=False, range=(0, 255))
        writer.add_image('Predicted label', grid_image, global_step)
        grid_image = make_grid(decode_seg_map_sequence(torch.squeeze(target[:3], 1).detach().cpu().numpy(),
                                                       dataset=dataset), 3, normalize=False, range=(0, 255))
        writer.add_image('Groundtruth label', grid_image, global_step)

This visualization class relies on decode_seg_map_sequence from pytorch-deeplab-xception/dataloaders/utils.py. (Aside: make_grid's range keyword only takes effect when normalize=True, and newer torchvision releases renamed it to value_range, so on recent versions the argument can simply be dropped here.)

import matplotlib.pyplot as plt
import numpy as np
import torch

def decode_seg_map_sequence(label_masks, dataset='pascal'):
    rgb_masks = []
    for label_mask in label_masks:
        rgb_mask = decode_segmap(label_mask, dataset)
        rgb_masks.append(rgb_mask)
    rgb_masks = torch.from_numpy(np.array(rgb_masks).transpose([0, 3, 1, 2]))
    return rgb_masks


def decode_segmap(label_mask, dataset, plot=False):
    """Decode segmentation class labels into a color image
    Args:
        label_mask (np.ndarray): an (M,N) array of integer values denoting
          the class label at each spatial location.
        plot (bool, optional): whether to show the resulting color image
          in a figure.
    Returns:
        (np.ndarray, optional): the resulting decoded color image.
    """
    if dataset == 'pascal' or dataset == 'coco':
        n_classes = 21
        label_colours = get_pascal_labels()
    elif dataset == 'cityscapes':
        n_classes = 19
        label_colours = get_cityscapes_labels()
    else:
        raise NotImplementedError

    r = label_mask.copy()
    g = label_mask.copy()
    b = label_mask.copy()
    for ll in range(0, n_classes):
        r[label_mask == ll] = label_colours[ll, 0]
        g[label_mask == ll] = label_colours[ll, 1]
        b[label_mask == ll] = label_colours[ll, 2]
    rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3))
    rgb[:, :, 0] = r / 255.0
    rgb[:, :, 1] = g / 255.0
    rgb[:, :, 2] = b / 255.0
    if plot:
        plt.imshow(rgb)
        plt.show()
    else:
        return rgb
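For example, feeding a tiny hand-made mask through decode_segmap (the mask values here are picked by me for illustration):

import numpy as np

# 2x2 label mask: background, aeroplane, bicycle, person
mask = np.array([[0, 1], [2, 15]])
rgb = decode_segmap(mask, dataset='pascal')
print(rgb.shape)         # (2, 2, 3), float values in [0, 1]
print(rgb[0, 1] * 255)   # [128. 0. 0.], the aeroplane colour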


def encode_segmap(mask):
    """Encode segmentation label images as pascal classes
    Args:
        mask (np.ndarray): raw segmentation label image of dimension
          (M, N, 3), in which the Pascal classes are encoded as colours.
    Returns:
        (np.ndarray): class map with dimensions (M,N), where the value at
        a given location is the integer denoting the class index.
    """
    mask = mask.astype(int)
    label_mask = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.int16)
    for ii, label in enumerate(get_pascal_labels()):
        label_mask[np.where(np.all(mask == label, axis=-1))[:2]] = ii
    label_mask = label_mask.astype(int)
    return label_mask


def get_cityscapes_labels():
    return np.array([
        [128, 64, 128],
        [244, 35, 232],
        [70, 70, 70],
        [102, 102, 156],
        [190, 153, 153],
        [153, 153, 153],
        [250, 170, 30],
        [220, 220, 0],
        [107, 142, 35],
        [152, 251, 152],
        [0, 130, 180],
        [220, 20, 60],
        [255, 0, 0],
        [0, 0, 142],
        [0, 0, 70],
        [0, 60, 100],
        [0, 80, 100],
        [0, 0, 230],
        [119, 11, 32]])


def get_pascal_labels():
    """Load the mapping that associates pascal classes with label colors
    Returns:
        np.ndarray with dimensions (21, 3)
    """
    return np.asarray([[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
                       [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
                       [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
                       [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
                       [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
                       [0, 64, 128]])
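Incidentally, this hard-coded table is the standard VOC palette, which can also be generated with the usual bit-interleaving trick; a sketch (voc_colormap is my own helper, not in the repo):

import numpy as np

def voc_colormap(n=21):
    """Regenerate the PASCAL VOC palette; matches get_pascal_labels()."""
    cmap = np.zeros((n, 3), dtype=np.uint8)
    for i in range(n):
        c, r, g, b = i, 0, 0, 0
        for j in range(8):
            # spread bits 0/1/2 of the class id into r/g/b, MSB first
            r |= ((c >> 0) & 1) << (7 - j)
            g |= ((c >> 1) & 1) << (7 - j)
            b |= ((c >> 2) & 1) << (7 - j)
            c >>= 3
        cmap[i] = [r, g, b]
    return cmap

assert (voc_colormap() == get_pascal_labels()).all()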

PDF downloads:

http://web.eng.tau.ac.il/deep_learn/

http://web.eng.tau.ac.il/deep_learn/wp-content/uploads/2017/12/Rethinking-Atrous-Convolution-for-Semantic-Image-Segmentation-1.pdf

A code walkthrough of pytorch-deeplab-xception:

https://www.jianshu.com/p/026c5d78d3b1

During training this project offers a choice of cropping: transform_val applies a fixed-scale center crop, while transform_tr applies a random scale-and-crop:

    def transform_val(self, sample):
        composed_transforms = transforms.Compose([
            tr.FixScaleCrop(crop_size=self.args.crop_size),
            tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            tr.ToTensor()
            ])
        return composed_transforms(sample)

    def transform_tr(self, sample):
        composed_transforms = transforms.Compose([
            tr.RandomHorizontalFlip(),
            tr.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size),
            tr.RandomGaussianBlur(),
            tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            tr.ToTensor()])
        return composed_transforms(sample)

class FixScaleCrop(object):
    def __init__(self, crop_size):
        self.crop_size = crop_size

    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']
        w, h = img.size
        if w > h:
            oh = self.crop_size
            ow = int(1.0 * w * oh / h)
        else:
            ow = self.crop_size
            oh = int(1.0 * h * ow / w)
        img = img.resize((ow, oh), Image.BILINEAR)
        mask = mask.resize((ow, oh), Image.NEAREST)
        # center crop
        w, h = img.size
        x1 = int(round((w - self.crop_size) / 2.))
        y1 = int(round((h - self.crop_size) / 2.))
        img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
        mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))

        return {'image': img,
                'label': mask}

class FixedResize(object):
    def __init__(self, size):
        self.size = (size, size)  # size: (h, w)

    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']

        assert img.size == mask.size

        img = img.resize(self.size, Image.BILINEAR)
        mask = mask.resize(self.size, Image.NEAREST)

        return {'image': img,
                'label': mask}
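A quick usage sketch of FixScaleCrop on a single sample (file names are placeholders; FixScaleCrop as defined above):

from PIL import Image

sample = {'image': Image.open('demo.jpg').convert('RGB'),
          'label': Image.open('demo_mask.png')}
out = FixScaleCrop(crop_size=513)(sample)
print(out['image'].size, out['label'].size)  # (513, 513) (513, 513)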

If the training images are 4000x2000 and the crop size is 513, FixScaleCrop first scales the short side to 513 (giving 1026x513) and then center-crops a 513x513 window, so the network only sees the central region around the annotation. At test time, resizing a full 4000x2000 image straight down to 513x513 no longer matches the scale and aspect ratio seen in training, so results are not necessarily good.


Testing: a save_image helper adapted from this issue:

https://github.com/jfzhang95/pytorch-deeplab-xception/issues/122

    def save_image(self, array, id, op, oriimg=None):
        # array: HxW label map; op == 0 -> prediction, otherwise ground truth
        text = 'gt'
        if op == 0:
            text = 'pred'
        file_name = str(id) + '_' + text + '.png'

        ori_name = str(id) + '_' + 'vis' + '.png'
        # 513*513
        r = array.copy()
        g = array.copy()
        b = array.copy()

        if oriimg is True:
            # Re-load the original image by id and repeat the FixScaleCrop
            # preprocessing so it lines up with the 513x513 mask
            oneimgpath = str(id) + '.jpg'
            from mypath import Path
            # JPEGImages_gray
            oneimg = Image.open(os.path.join(Path.db_root_dir('coco'), "images/test2017", oneimgpath)).convert('RGB')

            crop_size = 513
            w, h = oneimg.size
            if w > h:
                oh = crop_size
                ow = int(1.0 * w * oh / h)
            else:
                ow = crop_size
                oh = int(1.0 * h * ow / w)
            img = oneimg.resize((ow, oh), Image.BILINEAR)
            # center crop
            w, h = img.size
            x1 = int(round((w - crop_size) / 2.))
            y1 = int(round((h - crop_size) / 2.))
            oneimg = img.crop((x1, y1, x1 + crop_size, y1 + crop_size))

        # Colorize the label map class by class
        for i in range(self.nclass):
            r[array == i] = self.color_map[i][0]
            g[array == i] = self.color_map[i][1]
            b[array == i] = self.color_map[i][2]

        # 513*513*3
        rgb = np.dstack((r, g, b))
        hh, ww, _ = rgb.shape

        if oriimg is True:
            # Note: PIL's resize takes (width, height); this only works because
            # hh == ww == 513 here. ANTIALIAS is Image.LANCZOS in newer Pillow.
            oneimg = oneimg.resize((hh, ww), Image.ANTIALIAS)

        save_img = Image.fromarray(rgb.astype('uint8'))
        save_img.save(self.args.save_path + os.sep + file_name)

        if oriimg is True:
            # Paint the predicted classes onto the original image pixel by pixel
            oneimg = np.array(oneimg)
            for i in range(self.nclass):
                if i != 0:
                    index = np.argwhere(array == i)
                    for key in index:
                        oneimg[key[0]][key[1]][0] = self.color_map[i][0]
                        oneimg[key[0]][key[1]][1] = self.color_map[i][1]
                        oneimg[key[0]][key[1]][2] = self.color_map[i][2]
            oneimg = Image.fromarray(oneimg.astype('uint8'))
            oneimg.save(self.args.save_path + os.sep + ori_name, quality=100)

Using the code above I hit a problem: the id may not correspond to the actual image file. So I switched to an approach that takes the original image straight from the data-loader sample:

    def test(self):
        self.model.eval()
        self.evaluator.reset()
        num = len(self.test_loader)
        for i, sample in enumerate(self.test_loader):
            image, target = sample['image'], sample['label']
            print(i, "/", num)

            # Time the forward pass; both synchronize() calls must bracket it,
            # otherwise the measured time excludes pending GPU work
            torch.cuda.synchronize()
            start = time.time()
            with torch.no_grad():
                output = self.model(image)
            torch.cuda.synchronize()
            end = time.time()
            times = (end - start) * 1000
            print(times, "ms")

            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)

            # 'ori_image' is assumed to be the untransformed image carried
            # through the data-loader sample
            self.save_image(pred[0], i, 0, True, sample['ori_image'])
            self.save_image(target[0], i, 1, None, sample['ori_image'])
            self.evaluator.add_batch(target, pred)

    def save_image(self, array, id, op, oriimg=None, image111=None):
        import cv2
        text = 'gt'
        if op == 0:
            text = 'pred'
        file_name = str(id) + '_' + text + '.png'

        drow_ori_name = str(id) + '_' + 'vis' + '.png'

        # 513*513
        r = array.copy()
        g = array.copy()
        b = array.copy()

        if oriimg is True:
            # Take the original image straight from the loader sample instead of
            # re-opening it by id (CHW tensor -> HWC numpy array)
            image111 = image111.data.cpu().numpy()
            image111 = image111[0, :]
            image111 = image111.transpose(1, 2, 0)
            oneimg = image111

        # Channels are filled in BGR order because cv2.imwrite expects BGR
        for i in range(self.nclass):
            r[array == i] = self.color_map[i][2]
            g[array == i] = self.color_map[i][1]
            b[array == i] = self.color_map[i][0]

        # 513*513*3
        rgb = np.dstack((r, g, b))
        hh, ww, _ = rgb.shape

        # ---- gt ---- pred
        cv2.imwrite(self.args.save_path + os.sep + file_name, rgb)

        if oriimg is True:
            # Blend each non-background class 50/50 into the original pixels
            for i in range(self.nclass):
                if i != 0:
                    index = np.argwhere(array == i)
                    for key in index:
                        oneimg[key[0]][key[1]][0] = oneimg[key[0]][key[1]][0] * 0.5 + self.color_map[i][0] * 0.5
                        oneimg[key[0]][key[1]][1] = oneimg[key[0]][key[1]][1] * 0.5 + self.color_map[i][1] * 0.5
                        oneimg[key[0]][key[1]][2] = oneimg[key[0]][key[1]][2] * 0.5 + self.color_map[i][2] * 0.5

            # Visualization: convert RGB -> BGR for OpenCV before saving
            oneimg = cv2.cvtColor(oneimg, cv2.COLOR_RGB2BGR)
            cv2.imwrite(self.args.save_path + os.sep + drow_ori_name, oneimg)
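The per-pixel np.argwhere loop above is very slow on 513x513 masks; the same 50/50 blend can be done with boolean masks in a few vectorized operations. A sketch under the same assumptions (blend_overlay is my own helper, not part of the repo):

import numpy as np

def blend_overlay(img, label_map, color_map, alpha=0.5):
    """img: HxWx3 array, label_map: HxW int array, color_map: list of [R, G, B]."""
    out = img.astype(np.float32)
    for c in range(1, len(color_map)):   # skip background class 0
        m = label_map == c
        out[m] = out[m] * (1 - alpha) + np.asarray(color_map[c], np.float32) * alpha
    return out.astype(np.uint8)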


To freeze BatchNorm layers during fine-tuning, switch them to eval mode inside training() so their running statistics stop updating:

    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        for m in self.model.modules():
            if isinstance(m, SynchronizedBatchNorm2d):
                m.eval()
            elif isinstance(m, nn.BatchNorm2d):
                m.eval()
        tbar = tqdm(self.train_loader)


or, gating the same loop behind the freeze_bn flag:

    def training(self, epoch):
        train_loss = 0.0
        self.model.train()

        if self.args.freeze_bn:
            for m in self.model.modules():
                if isinstance(m, SynchronizedBatchNorm2d):
                    m.eval()
                elif isinstance(m, nn.BatchNorm2d):
                    m.eval()

        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
For a custom dataset, the simplest override is right in make_data_loader:

def make_data_loader(args, **kwargs):

    if args.dataset == 'pascal':
        train_set = pascal.VOCSegmentation(args, split='train')
        val_set = pascal.VOCSegmentation(args, split='val')
        if args.use_sbd:
            sbd_train = sbd.SBDSegmentation(args, split=['train', 'val'])
            train_set = combine_dbs.CombineDBs([train_set, sbd_train], excluded=[val_set])

        num_class = train_set.NUM_CLASSES
        num_class = (your_num_classes + 1)  # override: foreground classes + background


To pin the visible GPUs and resume training from a checkpoint:

import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

--resume ./run/pascal/deeplab-resnet/experiment_1/checkpoint.pth.tar  

https://blog.csdn.net/goodxin_ie/article/details/89645358

During training you may also see this deprecation warning:

/root/train/results/ynh_copy/anaconda3_py3.6_torch0.4.1/lib/python3.6/site-packages/torch/nn/functional.py:52: UserWarning: size_average and reduce args will be deprecated, please use reduction='elementwise_mean' instead.

cross_entropy(input, target, weight=None, size_average=None,
              ignore_index=-100, reduce=None, reduction='elementwise_mean')

  • size_average (deprecated, may be removed in later versions): whether the loss is averaged over the batch, i.e. divided by N. Defaults to True.
  • reduce (deprecated, may be removed in later versions): conflicts with size_average; when reduce=False, size_average has no effect. Defaults to True. With reduce=False, the loss is computed per sample and returned with shape [N], one value per sample. With reduce=True, size_average decides whether the N per-sample losses are summed or averaged, and a single scalar is returned.
  • reduction: the replacement for size_average and reduce in newer versions. It has three options: 'elementwise_mean' (the default; average the N losses, equivalent to reduce=True, size_average=True), 'sum' (sum the N losses, equivalent to reduce=True, size_average=False), and 'none' (return the N per-sample losses, equivalent to reduce=False).

reduction = 'elementwise_mean'
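A minimal demonstration of the three reduction modes (using the modern names 'none' / 'sum' / 'mean'):

import torch
import torch.nn.functional as F

logits = torch.randn(4, 3)              # N = 4 samples, 3 classes
target = torch.tensor([0, 2, 1, 0])

per_sample = F.cross_entropy(logits, target, reduction='none')  # shape [4]
summed = F.cross_entropy(logits, target, reduction='sum')       # scalar
mean = F.cross_entropy(logits, target, reduction='mean')        # scalar
assert torch.isclose(mean, per_sample.mean())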

In pytorch-deeplab-xception/utils/loss.py, replace the deprecated argument:

# before
criterion = nn.CrossEntropyLoss(weight=self.weight, ignore_index=self.ignore_index,
                                size_average=self.size_average)
# after
criterion = nn.CrossEntropyLoss(weight=self.weight, ignore_index=self.ignore_index,
                                reduction='elementwise_mean')
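Note that in PyTorch 1.0 and later, 'elementwise_mean' was itself deprecated in favor of 'mean', so a more future-proof variant (my suggestion, not from the repo) is:

criterion = nn.CrossEntropyLoss(weight=self.weight, ignore_index=self.ignore_index,
                                reduction='mean')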

PyTorch's built-in DeepLabV3:

import torch
#https://pytorch.org/hub/pytorch_vision_deeplabv3_resnet101/
#torch.hub.set_dir("D:\\code\\python\\deeplabv3\\v3temp")
#model = torch.hub.load('pytorch/vision', 'deeplabv3_resnet101', pretrained=True)
import segmentation_model as segmod
model = segmod.deeplabv3_resnet101(pretrained=True, num_classes=21)
#model = segmod.fcn_resnet101(pretrained=True, num_classes=21)
model.eval()

from PIL import Image
from torchvision import transforms
input_image = Image.open("./dog.jpg")
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

input_tensor = preprocess(input_image)
input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model

# move the input and model to GPU for speed if available
if torch.cuda.is_available():
    input_batch = input_batch.to('cuda')
    model.to('cuda')

with torch.no_grad():
    output = model(input_batch)['out'][0]
output_predictions = output.argmax(0)

palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
colors = torch.as_tensor([i for i in range(21)])[:, None] * palette
colors = (colors % 255).numpy().astype("uint8")

# plot the semantic segmentation predictions of 21 classes in each color
r = Image.fromarray(output_predictions.byte().cpu().numpy()).resize(input_image.size)
r.putpalette(colors)
r.save("./draw_dog.png", quality=100)
#import matplotlib.pyplot as plt
#plt.imshow(r)
#plt.show()
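For reference, the 21 output channels of the torchvision model follow the PASCAL VOC class order, so predicted indices can be mapped back to names (a small addition of mine to the script above):

VOC_CLASSES = ['background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
               'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog',
               'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa',
               'train', 'tvmonitor']
found = torch.unique(output_predictions).cpu().tolist()
print([VOC_CLASSES[i] for i in found])   # class names present in the prediction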

