loss为nan模型输出nan

最近,在做剪枝,之前没有接触过,用了gate-decorator-pruning工程,是19年的文章了,号称对resnet剪枝效果很好,如下:
loss为nan模型输出nan_第1张图片
可以看到文章中提出的GBN,相比其他剪枝方法,在FLOPs降低最多的情况下,同时保持了最高的精度。
这里记录一个自己实验过程中的坑,害的我找这个bug找了一天多,呜呜呜
loss为nan模型输出nan_第2张图片
问题:模型运行中train阶段,网络的输出logits第一次是正常的 ,第二次就是nan, 再经过loss, accuracy = pack.criterion(logits, label)后,loss为nan, 从而模型的输出也一直为nan,
定位到原因,我拿一个训练好的模型来剪枝,backbone最后输出为128维的特征向量,没用最后的fc linear层,代码中是自己加了一个权重,具体如下:

if cfg.base.head_model != "":
        model_dict = torch.load(cfg.base.head_model, map_location='cpu' if not cfg.base.cuda else 'cuda')
        data = model_dict['ARC_HEAD']['weight'].cpu().data.numpy()
        weight = nn.Parameter(torch.Tensor(data), requires_grad=False)
else:
    weight = nn.Parameter(torch.Tensor(cfg.model.class_number, 512))
if embedding.is_cuda:
    weight = weight.cuda()

我没有包含fc层的模型,所以设置为空,走了else分支,

weight = nn.Parameter(torch.Tensor(cfg.model.class_number, 512))

问题就出在这,你可以自己声明一个tensor看看里面的值,

>>> torch.Tensor(5, 8)
tensor([[-1.1774e-37,  6.0536e-43, -1.1774e-37,  6.0536e-43,  4.2039e-45,
          0.0000e+00,  1.7937e-43,  0.0000e+00],
        [ 1.4360e+04,  7.1846e+22,  1.4601e-19,  1.7750e+28,  6.8608e+22,
          2.8183e+20, -1.1774e-37,  6.0536e-43],
        [-1.1774e-37,  6.0536e-43, -1.1774e-37,  6.0536e-43,  1.4013e-45,
          0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 7.3970e+20,  7.1833e+22,  1.8153e+31,  2.7372e+20,  5.5123e-11,
          4.6149e+24,  0.0000e+00,  0.0000e+00],
        [ 4.2039e-45,  0.0000e+00,  1.4026e-27,  4.5909e-41,  3.9821e-41,
          2.4428e-38,  1.0842e-19,  1.6930e+22]])

里面的值是极小或极大值,从而在训练及反向传播的时候,会出现类似梯度消失或梯度爆炸的情形,都会导致网络输出为nan, 进行如下修改,问题解决

weight = nn.Parameter(torch.randn([cfg.model.class_number, 128]))
>>> torch.randn([5,8])
tensor([[-0.4336,  0.3006,  1.4319, -0.4194, -0.7426, -0.2283, -0.2755,  0.3634],
        [ 0.9681,  0.5644,  0.3170, -0.9134, -1.7536, -0.0589,  0.4907,  1.3428],
        [ 1.0248,  1.2903,  0.3210,  1.9144,  0.0591, -0.5614,  1.7932, -1.0874],
        [ 0.7404, -1.1362, -1.1224, -1.1677, -0.2877,  1.5038, -0.0281, -0.9513],
        [ 0.3340, -0.1252,  1.2106, -1.4836, -1.3784,  0.8065, -0.0257,  1.9197]])

你可能感兴趣的:(深度学习,日常bug系列,深度学习,pytorch,神经网络)