centernet损失函数修改记录

想将centernet中w,h的loss修正关联起来,采用iouloss进行优化,于是需修改loss函数,注意事项如下:

(1)loss函数本身到网络的输出层存在一个函数,这个函数求导,再逐层链式向前求导,完成训练。所以,自己写出的损失函数到pred一定要明确可导,不能乱写。

(2)写损失函数尽量使用torch自带的函数和+,-,*,/号进行操作,避免新建变量和for循环等操作,容易造成梯度无法反向传播。常用torch函数:

torch.clamp(),截断操作,限定最大最小值

torch.min, max

新建tensor需注意用原有输入tensor赋值方式:a=pred.new_tensor([0])

(3)正常操作流程如下:loss函数继承nn.module模块,编写forward过程。

(4)即便所有流程都编写正确,由于loss中存在的各种运算,还会出现计算出的梯度和weight不在一个量级,导致weight无法更新问题。比如,weight在1e-4量级,计算出的梯度在1e-10量级,这样会导致你训练几千次,weight也不会更新,进而loss没有任何改变。

(5)针对(4)中的问题,mmdet中的loss函数模板会有两个weight,一个weight为标量,可线性扩大缩小loss的数值;第二weight为tensor,与输出loss的维度一致,可精确控制每一个tensor元素对应loss的权重。

(6)即使通过(5)中操作,将loss放大较大范围,是的weight对应的grad在相同量级,可进行正常更新,还会存在问题:即损失函数设计不合理,weight较小时导致grad很小,梯度不更新,weight较大时,导致grad更新又跳变,训练发散。  总结就是:loss函数本身设计有问题,是无法通过简单的weight权重调整解决的。得尝试更换或重新设计loss函数。

你可能感兴趣的:(mmdetection,python,损失函数,pytorch,深度学习,神经网络)