A Jianshu post covers some common PyTorch errors: https://www.jianshu.com/p/1fa86e060e5a
Problem 1:
Unexpected key(s) in state_dict: "bn1.num_batches_tracked", "layer1.0.bn1.num_batches_tracked", "layer1.0.bn2.num_batches_tracked", "layer1.0.bn3.num_batches_tracked", "layer1.0.downsample.1.num_batches_tracked", "layer1.1.bn1.num_batches_tracked", "layer1.1.bn2.num_batches_tracked", "layer1.1.bn3.num_batches_tracked", "layer1.2.bn1.num_batches_tracked", "layer1.2.bn2.num_batches_tracked", "layer1.2.bn3.num_batches_tracked", "layer2.0.bn1.num_batches_tracked", "layer2.0.bn2.num_batches_tracked", "layer2.0.bn3.num_batches_tracked", "layer2.0.downsample.1.num_batches_tracked", "layer2.1.bn1.num_batches_tracked", "layer2.1.bn2.num_batches_tracked", "layer2.1.bn3.num_batches_tracked", "layer2.2.bn1.num_batches_tracked", "layer2.2.bn2.num_batches_tracked", "layer2.2.bn3.num_batches_tracked", "layer2.3.bn1.num_batches_tracked", "layer2.3.bn2.num_batches_tracked", "layer2.3.bn3.num_batches_tracked", "layer3.0.bn1.num_batches_tracked", "layer3.0.bn2.num_batches_tracked", "layer3.0.bn3.num_batches_tracked", "layer3.0.downsample.1.num_batches_tracked", "layer3.1.bn1.num_batches_tracked", "layer3.1.bn2.num_batches_tracked", "layer3.1.bn3.num_batches_tracked", "layer3.2.bn1.num_batches_tracked", "layer3.2.bn2.num_batches_tracked", "layer3.2.bn3.num_batches_tracked", "layer3.3.bn1.num_batches_tracked", "layer3.3.bn2.num_batches_tracked", "layer3.3.bn3.num_batches_tracked", "layer3.4.bn1.num_batches_tracked", "layer3.4.bn2.num_batches_tracked", "layer3.4.bn3.num_batches_tracked", "layer3.5.bn1.num_batches_tracked", "layer3.5.bn2.num_batches_tracked", "layer3.5.bn3.num_batches_tracked", "layer4.0.bn1.num_batches_tracked", "layer4.0.bn2.num_batches_tracked", "layer4.0.bn3.num_batches_tracked", "layer4.0.downsample.1.num_batches_tracked", "layer4.1.bn1.num_batches_tracked", "layer4.1.bn2.num_batches_tracked", "layer4.1.bn3.num_batches_tracked", "layer4.2.bn1.num_batches_tracked", "layer4.2.bn2.num_batches_tracked", "layer4.2.bn3.num_batches_tracked", 
"seg_up_1.conv.1.num_batches_tracked", "seg_up_2.conv.1.num_batches_tracked", "seg_up_3.conv.1.num_batches_tracked".
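Cause: the model was trained with PyTorch 0.4.1 but inference runs on PyTorch 0.4.0. PyTorch 0.4.1 added a num_batches_tracked buffer to the BatchNorm layers, so a checkpoint saved by 0.4.1 contains keys that 0.4.0 does not expect.
Where to change: in torch/nn/modules/batchnorm.py of the 0.4.0 installation, around line 28, add the two lines marked "# fix":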
if self.track_running_stats:
    self.register_buffer('running_mean', torch.zeros(num_features))
    self.register_buffer('running_var', torch.ones(num_features))
    self.register_buffer('num_batches_tracked', torch.LongTensor([0]))  # fix
else:
    self.register_parameter('running_mean', None)
    self.register_parameter('running_var', None)
    self.register_buffer('num_batches_tracked', None)  # fix
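If patching the installed PyTorch is undesirable, an alternative is to drop the extra keys from the checkpoint before loading it. The following is a minimal sketch and not part of the original fix; the helper name strip_num_batches_tracked and the checkpoint file name are hypothetical. It assumes only that the state dict is a mapping from key names to tensors.

```python
from collections import OrderedDict


def strip_num_batches_tracked(state_dict):
    """Return a copy of state_dict without the BatchNorm counters added in 0.4.1."""
    return OrderedDict(
        (key, value)
        for key, value in state_dict.items()
        if not key.endswith("num_batches_tracked")
    )


# Assumed usage on the 0.4.0 side (the file name is a placeholder):
#   state_dict = torch.load("checkpoint.pth")
#   model.load_state_dict(strip_num_batches_tracked(state_dict))
```

The original dict is left untouched, so the same checkpoint can still be loaded unchanged under 0.4.1.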
Problem 2: RuntimeError: randperm is only implemented for CPU
In torch/utils/data/sampler.py, around line 51, change __iter__ so that the permutation is created on the CPU explicitly:
def __iter__(self):
    cpu = torch.device('cpu')
    return iter(torch.randperm(len(self.data_source), device=cpu).tolist())
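The same effect can probably be had without editing the installed package, by passing a custom sampler to the DataLoader. The sketch below is an alternative and not the fix above: it replaces torch.randperm with Python's random.shuffle, so no tensor is created at all. The class name CpuRandomSampler is hypothetical, and some PyTorch versions may require samplers to subclass torch.utils.data.sampler.Sampler, in which case that base class would need to be added.

```python
import random


class CpuRandomSampler:
    """Yield dataset indices in random order using only Python's random module."""

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        indices = list(range(len(self.data_source)))
        random.shuffle(indices)  # shuffles in place, never touches the GPU
        return iter(indices)

    def __len__(self):
        return len(self.data_source)


# Assumed usage (dataset and batch size are placeholders):
#   loader = DataLoader(dataset, batch_size=32, sampler=CpuRandomSampler(dataset))
```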
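Problem 3:
Cause: a PyTorch version issue. PyTorch 0.4.1 does not raise the error, but 0.4.0 does.
Solution:
https://github.com/amdegroot/ssd.pytorch/issues/109
The problem is in the optimizer.
Where the optimizer is constructed, just change net.parameters() to filter(lambda p: p.requires_grad, net.parameters()), so that only parameters with requires_grad set to True are passed to it.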
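The filtering step can be wrapped in a small helper, sketched here under assumptions: the name trainable_parameters is hypothetical, and the optimizer call in the comment uses a placeholder learning rate. The helper only relies on each parameter exposing a requires_grad attribute.

```python
def trainable_parameters(parameters):
    """Keep only the parameters whose requires_grad flag is True."""
    return [p for p in parameters if p.requires_grad]


# Assumed usage with a model named `net` (the learning rate is a placeholder):
#   optimizer = torch.optim.SGD(trainable_parameters(net.parameters()), lr=1e-3)
```

Returning a list rather than a filter iterator lets the result be checked for emptiness or reused, whereas an iterator is consumed after a single pass.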