在训练时候loss出现负值,就立马停下来分析一下原因在哪。最有可能是损失函数出现问题,开始只使用交叉熵损失时没有出现过,在加上了dice loss时就出现了问题。于是就去dice loss中寻找原因。
1:首先需要明白语义分割的GT,每一个像素点的值就是像素的类别。
# -*- coding: utf-8 -*-
import numpy as np
from torchvision import transforms
import torch
from PIL import Image
img = Image.open('C:/Users/翰墨大人/Desktop/0003_lable.png') #图像所在位置
img1 = np.array(img)
img1 = torch.from_numpy(img1).type(torch.FloatTensor)
# trans = transforms.ToTensor()
# img1 = trans(img1)
a = torch.unique(img1) # 查看图片内的像素值
print(a)
print(img.mode) # 查看图片模式
打印结果:
tensor([ 0., 1., 5., 7., 8., 26., 29., 38., 40.])
P
原图一共有四十个类别,在003_lable这张GT上只出现了上述的类别。剩下的像素点和上述像素点是重复的。
注意:
这里将np格式转换为tensor格式时候,不能使用transforms.ToTensor(),他会将像素值发生改变。这样新的像素值点和类别就不是一一对应的。
tensor([0.0000, 0.0039, 0.0196, 0.0275, 0.0314, 0.1020, 0.1137, 0.1490, 0.1569])
P
2:语义分割的GT一共有四十类,没有通道,而在模型中pred的输出为一个通道。
在计算二分类dice loss时候,首先要将pred进行sigmoid。GT是四十个类别,如果是二分类的话那么标签就必须是[0,1]。而loss为负值的原因就是没有将标签转换为[0,1]。
X是pred,Y是GT,分子就是X和Y进行矩阵的相乘再相加,当Y中含有大于1的类别,比如30,40的话,而X又是进过sigmoid之后再(0,1)之内,那么分子除以分母的值就会大于1,造成dice loss就变成了负值。
如何将GT变为[0,1]呢?因为我需要的是对GT进行提边,使用一个拉普拉斯对GT进行卷积,然后再使用一个阈值,大于阈值为1,小于为0。为是边缘和不是边缘。GT就又灰度图像转换为了二值图像。
l = F.conv2d(x,sobel,padding=1,stride=1)
print(l.shape)
ll = torch.unique(l) # 查看图片内的像素值
print(ll)
l[l>0.1]=1
l[l<0.1]=0
l_ = torch.unique(l) # 查看图片内的像素值
print(l_)
结果:
对于多分类的dice loss:
GT需要进行one-hot编码,这样一个通道,像素点为(0-40)的GT就会变为四十个通道,每个通道像素点为(0,1),每个通道都可以看做一个二分类问题,属于该类别和不属于该类别。而pred的输出通道也为40,通过计算pred的每个通道和GT的每个通道的loss,最后求均值得到总loss。
3:代码代码参考
单分类代码:
import torch
import torch.nn as nn
class BinaryDiceLoss(nn.Model):
def __init__(self):
super(BinaryDiceLoss, self).__init__()
def forward(self, input, targets):
# 获取每个批次的大小 N
N = targets.size()[0]
# 平滑变量
smooth = 1
# 将宽高 reshape 到同一纬度
input_flat = input.view(N, -1)
targets_flat = targets.view(N, -1)
# 计算交集
intersection = input_flat * targets_flat
N_dice_eff = (2 * intersection.sum(1) + smooth) / (input_flat.sum(1) + targets_flat.sum(1) + smooth)
# 计算一个批次中平均每张图的损失
loss = 1 - dice_eff.sum() / N
return loss
多分类代码:
import torch
import torch.nn as nn
class MultiClassDiceLoss(nn.Module):
def __init__(self, weight=None, ignore_index=None, **kwargs):
super(MultiClassDiceLoss, self).__init__()
self.weight = weight
self.ignore_index = ignore_index
self.kwargs = kwargs
def forward(self, input, target):
"""
input tesor of shape = (N, C, H, W)
target tensor of shape = (N, H, W)
"""
# 先将 target 进行 one-hot 处理,转换为 (N, C, H, W)
nclass = input.shape[1]
target = one_hot(target.long(), nclass)
assert input.shape == target.shape, "predict & target shape do not match"
binaryDiceLoss = BinaryDiceLoss()
total_loss = 0
# 归一化输出
logits = F.softmax(input, dim=1)
C = target.shape[1]
# 遍历 channel,得到每个类别的二分类 DiceLoss
for i in range(C):
dice_loss = binaryDiceLoss(logits[:, i], target[:, i])
total_loss += dice_loss
# 每个类别的平均 dice_loss
return total_loss / C