How to handle pixels labeled 255 in semantic segmentation

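Cityscapes, a commonly used semantic segmentation dataset, assigns the label 255 to pixels that should not be used. Beginners are often unsure what to do when they run into label 255 during training or evaluation. The answer is to ignore those pixels.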

Handling the ignore label when computing the training loss

import torch
import torch.nn.functional as F
from torch import nn

class CrossEntropy2d(nn.Module):
    def __init__(self, ignore_label=255):
        super().__init__()
        self.ignore_label = ignore_label

    def forward(self, predict, target):
        """
        :param predict: [batch, num_class, height, width]
        :param target: [batch, height, width]
        :return: entropy loss
        """
        target_mask = target != self.ignore_label  # [batch, height, width], True for pixels that should be trained on
        target = target[target_mask]  # [num_pixels]
        predict = predict.permute(0, 2, 3, 1)  # [batch, height, width, num_class]
        predict = predict[target_mask]  # [num_pixels, num_class]
        loss = F.cross_entropy(predict, target)
        return loss

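The core of the code above is to use indexing to pick out the pixels that should be trained on, and to compute the cross-entropy loss on those pixels only.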
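As a cross-check, PyTorch's built-in cross-entropy also accepts an `ignore_index` argument, which skips the given label and averages over the remaining pixels. The sketch below compares it with the manual masking described above. The helper name `masked_cross_entropy` and the random tensors are made up for illustration.

```python
import torch
import torch.nn.functional as F


def masked_cross_entropy(predict, target, ignore_label=255):
    """Manual masking by indexing (hypothetical helper name)."""
    mask = target != ignore_label                # [batch, height, width]
    logits = predict.permute(0, 2, 3, 1)[mask]   # [num_pixels, num_class]
    return F.cross_entropy(logits, target[mask])


torch.manual_seed(0)
predict = torch.randn(2, 5, 4, 4)          # illustrative scores for 5 classes
target = torch.randint(0, 5, (2, 4, 4))    # illustrative labels
target[0, 0, :2] = 255                     # mark two pixels as ignored

manual_loss = masked_cross_entropy(predict, target)
builtin_loss = F.cross_entropy(predict, target, ignore_index=255)
print(manual_loss.item(), builtin_loss.item())
```

Both values should agree up to floating-point error, so in practice `nn.CrossEntropyLoss(ignore_index=255)` can stand in for the hand-written module. Note that if every pixel in a batch is ignored, the mean is taken over zero pixels and the loss is NaN.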

Computing pixel accuracy and mean IoU during evaluation

def eval_metrics(predict, target, ignore_label=255):
    # Preprocessing: drop the pixels whose label is ignore_label
    target_mask = (target != ignore_label)  # [batch, height, width], True for pixels to evaluate
    target = target[target_mask]  # [num_pixels]
    batch, num_class, height, width = predict.size()
    predict = predict.permute(0, 2, 3, 1)  # [batch, height, width, num_class]

    # Pixel accuracy
    predict = predict[target_mask]  # [num_pixels, num_class]
    predict = predict.argmax(dim=1)  # [num_pixels]
    num_pixels = target.numel()
    correct = (predict == target).sum()
    pixel_acc = correct.float() / num_pixels

    # Mean IoU over classes
    # Shift labels to 1..num_class so that 0 can mark wrongly predicted pixels
    predict = predict + 1
    target = target + 1
    intersection = predict * (predict == target).long()
    area_inter = torch.histc(intersection.float(), bins=num_class, max=num_class, min=1)
    area_pred = torch.histc(predict.float(), bins=num_class, max=num_class, min=1)
    area_label = torch.histc(target.float(), bins=num_class, max=num_class, min=1)
    area_union = area_pred + area_label - area_inter
    # Average the per-class IoU over classes that occur in the prediction or the label
    valid = area_union > 0
    mIoU = (area_inter[valid] / area_union[valid]).mean()
    return pixel_acc, mIoU
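
Another way to obtain the same two metrics is to build a confusion matrix over the non-ignored pixels with `torch.bincount`. The following is a minimal sketch of that approach, assuming the predictions are already class indices rather than scores. The function name `confusion_matrix_metrics` and the toy tensors are illustrative and not part of the code above.

```python
import torch


def confusion_matrix_metrics(pred_labels, target, num_class, ignore_label=255):
    """Pixel accuracy and mean IoU from a confusion matrix (illustrative helper)."""
    mask = target != ignore_label
    pred_labels, target = pred_labels[mask], target[mask]
    # Entry [i, j] counts pixels whose true class is i and predicted class is j
    index = target * num_class + pred_labels
    conf = torch.bincount(index, minlength=num_class ** 2).view(num_class, num_class)
    inter = conf.diag().float()
    union = (conf.sum(dim=0) + conf.sum(dim=1)).float() - inter
    pixel_acc = inter.sum() / conf.sum().clamp(min=1).float()
    present = union > 0  # skip classes absent from both prediction and label
    mean_iou = (inter[present] / union[present]).mean()
    return pixel_acc, mean_iou


# Toy example with 3 classes; the last pixel carries the ignore label
target = torch.tensor([[[0, 0, 1], [1, 2, 255]]])
pred_labels = torch.tensor([[[0, 1, 1], [1, 2, 0]]])
pixel_acc, mean_iou = confusion_matrix_metrics(pred_labels, target, num_class=3)
print(pixel_acc.item(), mean_iou.item())
```

In this toy example four of the five non-ignored pixels are correct, so the pixel accuracy is 0.8. The per-class IoUs are 1/2, 2/3 and 1, which gives a mean IoU of 13/18 (about 0.722). The prediction on the ignored pixel has no effect on either number.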
