Project I'm using: https://github.com/amdegroot/ssd.pytorch
Problem description:
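I ran into this problem while using SSD for object detection. I need to detect 5 object classes, so I set the num_classes parameter in data/config.py to 5. After a lot of searching, I found a detail I had overlooked: the class count should be 5 + 1, where the extra 1 is the background class.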
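A minimal sketch of why the background class changes the network shape: in ssd.pytorch each confidence head conf.k is a 3x3 convolution whose output channel count is (anchors per location) × num_classes. The anchor counts [4, 6, 6, 6, 4, 4] below are the SSD300 defaults; treat them as an assumption if you changed the prior-box configuration.

```python
# Sketch: how num_classes determines the width of SSD's confidence heads.
# Anchors per feature-map location for SSD300 (assumed default setup).
ANCHORS_PER_LOCATION = [4, 6, 6, 6, 4, 4]

NUM_FOREGROUND_CLASSES = 5                  # the 5 object classes to detect
NUM_CLASSES = NUM_FOREGROUND_CLASSES + 1    # +1 for the background class


def conf_out_channels(num_classes, anchors=ANCHORS_PER_LOCATION):
    """Output channels of each conf.k conv layer: anchors * num_classes."""
    return [a * num_classes for a in anchors]


print(conf_out_channels(NUM_CLASSES))  # [24, 36, 36, 36, 24, 24]
print(conf_out_channels(5))            # [20, 30, 30, 30, 20, 20]
```

With 5 + 1 = 6 classes the heads have 24 and 36 channels, which are exactly the checkpoint shapes in the error log; with num_classes = 5 they shrink to 20 and 30.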
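The other cause was that my label indices did not start from 0. The foreground labels should be numbered 0 to 4; the loss shifts them by 1 internally so that index 0 means background.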
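A sketch of 0-based labeling, using hypothetical class names. In ssd.pytorch the VOC annotation transform numbers the foreground classes from 0, and the matching step in the loss (match in layers/box_utils.py) adds 1 to each matched label so that 0 can stand for background. If labels start at 1 instead, the last class lands one past the highest valid index. The helper to_conf_target is hypothetical and only mimics that shift with an explicit range check.

```python
# Sketch: 0-based foreground labels plus the +1 shift that reserves 0 for background.
CLASSES = ("cat", "dog", "car", "bike", "person")  # hypothetical class names
NUM_CLASSES = len(CLASSES) + 1                     # +1 for background

# Foreground classes get indices 0..len(CLASSES)-1.
CLASS_TO_IND = {name: i for i, name in enumerate(CLASSES)}


def to_conf_target(label):
    """Hypothetical helper: shift a 0-based object label by +1 (0 = background).

    Raises if the shifted target falls outside 1..NUM_CLASSES-1, which is
    what happens when the labels were numbered from 1 instead of 0.
    """
    target = label + 1
    if not 1 <= target <= NUM_CLASSES - 1:
        raise ValueError(f"label {label} out of range for {NUM_CLASSES} classes")
    return target


print([to_conf_target(CLASS_TO_IND[n]) for n in CLASSES])  # [1, 2, 3, 4, 5]
```

A 1-based label of 5 for the last class would become target 6, which does not exist in a 6-way classifier (valid indices 0 to 5).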
Error when loading the model:
RuntimeError: Error(s) in loading state_dict for SSD:
size mismatch for conf.0.weight: copying a param with shape torch.Size([24, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([20, 512, 3, 3]).
size mismatch for conf.0.bias: copying a param with shape torch.Size([24]) from checkpoint, the shape in current model is torch.Size([20]).
size mismatch for conf.1.weight: copying a param with shape torch.Size([36, 1024, 3, 3]) from checkpoint, the shape in current model is torch.Size([30, 1024, 3, 3]).
size mismatch for conf.1.bias: copying a param with shape torch.Size([36]) from checkpoint, the shape in current model is torch.Size([30]).
size mismatch for conf.2.weight: copying a param with shape torch.Size([36, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([30, 512, 3, 3]).
size mismatch for conf.2.bias: copying a param with shape torch.Size([36]) from checkpoint, the shape in current model is torch.Size([30]).
size mismatch for conf.3.weight: copying a param with shape torch.Size([36, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([30, 256, 3, 3]).
size mismatch for conf.3.bias: copying a param with shape torch.Size([36]) from checkpoint, the shape in current model is torch.Size([30]).
size mismatch for conf.4.weight: copying a param with shape torch.Size([24, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([20, 256, 3, 3]).
size mismatch for conf.4.bias: copying a param with shape torch.Size([24]) from checkpoint, the shape in current model is torch.Size([20]).
size mismatch for conf.5.weight: copying a param with shape torch.Size([24, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([20, 256, 3, 3]).
size mismatch for conf.5.bias: copying a param with shape torch.Size([24]) from checkpoint, the shape in current model is torch.Size([20]).
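The cause is again the class count: num_classes when loading must equal the value used during training (background included). Each conf.k layer has anchors × num_classes output channels, so the checkpoint's 24/36 channels correspond to 6 classes, while the current model's 20/30 correspond to 5.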
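A sketch for recovering the num_classes a checkpoint was trained with, by dividing each conf.k.weight output channel count by its anchor count. Shapes are plain tuples here so the logic runs without PyTorch; with a real checkpoint you could build the dict with `{k: tuple(v.shape) for k, v in torch.load(path, map_location="cpu").items()}`. The SSD300 anchor counts are again an assumption, and keys are assumed to have no "module." prefix (as in the error log above).

```python
# Sketch: infer num_classes (background included) from conf.k.weight shapes.
ANCHORS_PER_LOCATION = [4, 6, 6, 6, 4, 4]  # SSD300 anchors per location (assumed)


def infer_num_classes(shapes, anchors=ANCHORS_PER_LOCATION):
    """Return the num_classes implied by the conf.k.weight shapes."""
    found = set()
    for k, a in enumerate(anchors):
        out_channels = shapes[f"conf.{k}.weight"][0]
        if out_channels % a:
            raise ValueError(f"conf.{k}: {out_channels} channels not divisible by {a}")
        found.add(out_channels // a)
    if len(found) != 1:
        raise ValueError(f"inconsistent class counts across conf layers: {found}")
    return found.pop()


# Shapes copied from the checkpoint side of the error log above.
checkpoint_shapes = {
    "conf.0.weight": (24, 512, 3, 3),
    "conf.1.weight": (36, 1024, 3, 3),
    "conf.2.weight": (36, 512, 3, 3),
    "conf.3.weight": (36, 256, 3, 3),
    "conf.4.weight": (24, 256, 3, 3),
    "conf.5.weight": (24, 256, 3, 3),
}
print(infer_num_classes(checkpoint_shapes))  # 6
```

Once the count is known, build the network with the same value before calling load_state_dict; in ssd.pytorch that would be, for this checkpoint, `build_ssd('test', 300, 6)`.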