Pytorch中Balance binary cross entropy损失函数的写法

balance binary cross entropy损失函数在分割任务中很有用,因为分割任务会遇到正负样本不均的问题,甚至在边缘的分割任务重,样本不均衡达到了很高的比例。我们先来了解原理,再了解具体如何编程。

原理

比如一个预测结果,记作 P ∈ R H × W P \in R^{H \times W} PRH×W,对应的label是R,尺寸一样。
R中的1和0,即正负样本比例很不协调。我们想给这两个类别一个权重系数,乘在每一个像素计算的loss上。
这个系数的算法是:
a p o s = n u m n e g / ( H × W ) a_{pos} = num_{neg} / (H \times W) apos=numneg/(H×W)
a n e g = n u m p o s / ( H × W ) a_{neg} = num_{pos} / (H \times W) aneg=numpos/(H×W)
正样本的系数是通过负例的数目占总数目的比例得到,负样本的系数是正例的数目占总数目的比例得到。
至于为啥计算正样本的系数需要用负样本的比例,那是因为正样本数目少,就给一个大的比例,增大一下正样本的梯度,不至于负样本的梯度占统治地位(dominating),避免网络倾向于把样本判断为负样本。

代码

def bce2d(pred, gt, reduction='mean'):
	pos = torch.eq(gt, 1).float()
	neg= torch.eq(gt, 0).float()
	num_pos = torch.sum(pos)
	num_neg = torch.sum(neg)
	num_total = num_pos + num_neg
	alpha_pos = num_neg / num_total
	alpha_neg = num_pos / num_total
	weights = alpha_pos * pos + alpha_neg * neg
	return F.binary_cross_entropy_with_logits(pred, target, weights, reduction = reduction)

Note

有些情况下,正样本实在太少,负样本是在太多。这样情况下,负样本的权重系数就会接近0,使得训练出来的网络仍然是biased。我们可以通过在负样本系数上乘以一个大于1的值。
a n e g = 1.1 × n u m p o s / ( H × W ) a_{neg} = 1.1 \times num_{pos} / (H \times W) aneg=1.1×numpos/(H×W)

你可能感兴趣的:(Pytorch)