mmdetection-报错RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor)

我在使用mmdetection修改neck.fpn的时候,遇到了以下报错:

1、RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same, but I set model and data to cuda

2、NotImplementedError

网上很多建议第1个报错将网络和模型都部署到GPU,但是我的问题不在这,后来经过一番检索(StackOverflow网友建议)发现,应该是nn.ModuleList与nn.Sequential的使用问题,mmdetetion中,可能是由于注册方式的原因(具体原因没有深入了解),在修改网络的时候,需要尽量避免使用nn.Sequential,应当优先使用nn.ModuleList。同时使用nn.ModuleList后,需要在forward中实现,具体差别和使用方法可参考文章——详解PyTorch中的ModuleList和Sequential,nn.ModuleList实现如下,我直接复制过来

class net_modlist(nn.Module):
    def __init__(self):
        super(net_modlist, self).__init__()
        self.modlist = nn.ModuleList([
                       nn.Conv2d(1, 20, 5),
                       nn.ReLU(),
                        nn.Conv2d(20, 64, 5),
                        nn.ReLU()
                        ])

    def forward(self, x):
        for m in self.modlist:
            x = m(x)
        return x

net_modlist = net_modlist()
print(net_modlist)
#net_modlist(
#  (modlist): ModuleList(
#    (0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
#    (1): ReLU()
#    (2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
#    (3): ReLU()
#  )
#)

for param in net_modlist.parameters():
    print(type(param.data), param.size())
# torch.Size([20, 1, 5, 5])
# torch.Size([20])
# torch.Size([64, 20, 5, 5])
# torch.Size([64]

我遇到的具体问题是:修改fpn的时候,给fpn增加其他结构,如果使用了nn.Sequential,可能会出现以下错误:RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same, but I set model and data to cuda,如果没有在forward中正确实现nn.ModuleList,就会出现错误:NotImplementedError

PS:本来写在其他地方当作踩坑记录汇总,不过感觉这样写会更突出在mmdetection中遇到RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same, but I set model and data to cuda一点,希望能帮到遇到同样问题的朋友们,经常检索找不到还是比较烦人。此外,也可能可以使用nn.Sequential,只是我的使用方法不对,欢迎大家指正。

参考:

1、详解PyTorch中的ModuleList和Sequential
2、StackOverflow关于RuntimeError的解答

你可能感兴趣的:(深度学习编程使用记录,pytorch,深度学习,人工智能)