Error(s) in loading state_dict for XXX Unexpected key(s) in state_dict, 找不到num_batches_tracked

今天在训练的时候发现加载模型的时候提示找不到num_batches_tracked,感到奇怪,因为之前已经成功训练过一次了怎么这次就报错了呢,后来发现,第一次训练的时候我用的是0.4.0的pytorch,这次用的是1.0的Pytorch,因为torch的版本不一样引起的问题

KeyError: 'unexpected key "module.bn1.num_batches_tracked" in state_dict'

得到类似这样的报错

以下参考自这篇文章 https://zhuanlan.zhihu.com/p/91485607

经过研究发现,在pytorch 0.4.1及后面的版本里,BatchNorm层新增了num_batches_tracked参数,用来统计训练时的forward过的batch数目,源码如下(pytorch0.4.1): 

    if self.training and self.track_running_stats:
        self.num_batches_tracked += 1
        if self.momentum is None:  # use cumulative moving average
            exponential_average_factor = 1.0 / self.num_batches_tracked.item()
        else:  # use exponential moving average
            exponential_average_factor = self.momentum

知道原因就知道怎么处理了,我自己的模型里没有num_batches_tracked这个键,要把我预训练模型里的这个键给剔除掉

Error(s) in loading state_dict for XXX Unexpected key(s) in state_dict, 找不到num_batches_tracked_第1张图片

这是我对我文件里做的修改,注释掉的那行是原来的代码,可以对比一下 新增加的三行和原来的这行,就是简单的做了一个字典删除

你可能感兴趣的:(pytorch)