最近训练自己的Landmark检测网络的时候,遇到的loss无法下降的问题,现在记录如下,一方面给自己留个记录,另一方面希望能够给大家一些参考。
主要就是使用的王井东老师团队的HRNet来跑我们自己做的数据集,不得不说HRNet的Pytorch训练流程对我们来说真的特别友好,有很多细节上的优化,一旦上手之后改很多东西都很方便。之前我已经在数据集上进行了验证,然后需要添加我们自己设计的模块,达到更好的效果,实现得特别快,跨年那天晚上写了五个小时就写完了,但是第二天兴致勃勃地一看,发现loss根本没有下降过,基本上如下图所示,我试了很多遍,但是就是死活不下降。(讲真我之前从没见过这么平的loss曲线)
从经验来说,一般Loss不下降的话,一般问题会出现在以下几个方面:
1、训练数据出错,并不是我们想要输入和输出的数据。
2、训练权重初始化问题(这种一般不会导致loss完全不下降,只会使拟合效果变差)
3、网络结构设计不合理导致梯度断掉。
以上三点是我首先想到的问题,所以着手开始逐步排查。
首先是训练数据,我将输入tensor转换成numpy格式然后保存成图片,发现每个输入都是对的,想要的标签也没问题,标签该有的峰值都有,就先排除了数据的问题。另外如上面说过的,预训练权重虽然很重要但是不会导致loss完全不下降。后来怀疑是梯度断掉了,但是我仔细检查了所有网络实现的模块,除了一个Tensor的Broadcast操作我不知道是否可导(如果有人知道是否可导的话请务必教教我,秋梨膏),其他的所有操作我敢保证一定是可导的。
这一来我就傻眼了,我目前能够想到的所有可能的问题所在我都排查过了,没有一个是符合现在这个情况的,但是有一点我可以说,肯定是我代码哪里出了问题,有哪个地方没有注意到。
因此,又仔细将自己的代码和原版HRNet进行一行一行的对比,终于有几行引起了我的注意。
Loss不下降的版本中,我是这样实现的(经过简化后):
for i, (target, input, meta) in train_loader:
model1_output = model1(input)
output = model2(input, model1_output)
loss = criterion(output, target)
loss.backward()
for optimizer in optimizers:
optimizer.zero_grad()
optimizer.step()
现在是两个网络模块,所以用了两个optimizer优化器,然后为了方便所以把这个写成了一个循环,先优化一个再优化另一个,我上是觉得没什么毛病,所以根本没往这个上面想。后来一想,loss完全不下降,跳动幅度那么小,问题应该只可能出现在优化器上啊。只是当时完全没意识到这个问题。
后来发现,HRNet里面实现的如下所示:
for i, (target, input) in train_loader:
output = model(input)
loss = criterion(output, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
说到这一对比大家应该能够发现问题了,问题就出在optimizer.zero_grad()这一行上。optimizer.zero_grad()的作用我就不仔细说了,这篇博客已经写的挺清楚了,主要就是要清零优化器所有参数的累计梯度,否则这个mini-batch中的梯度会累计到下一个mini-batch中,自然会直接导致模型训练的失败,后面将自己的代码修改如下:
for i, (target, input, meta) in train_loader:
for optimizer in optimizers:
optimizer.zero_grad()
model1_output = model1(input)
output = model2(input, model1_output)
loss = criterion(output, target)
loss.backward()
for optimizer in optimizers:
optimizer.step()
问题解决!!!
总结一下,Pytorch loss不下降有可能是mini-batch中的梯度没有清零,每个batch的更新步骤一定要按照下述流程来:
for i, (target, input) in train_loader:
output = model(input)
loss = criterion(output, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
即先梯度清零,再进行loss回传,然后再进行优化。
这篇博文写的有点长,把问题发生过程记录下来了,希望对你有一定帮助。