在Pytorch框架下,在修改resnet进行finetune时遇到了这么一个错误:
TypeError: zip argument #1 must support iteration
参考了Bingoyear 老哥的博客,发现了问题所在,博客链接如下:
https://blog.csdn.net/angel_hben/article/details/105731015?utm_medium=distribute.pc_relevant.none-task-blog-baidujs-1
原来在我修改resnet时,将代码修改为:
class model_single(nn.Module):
def __init__(self, n_classes):
super(model_single, self).__init__()
self.basic_net = models.resnet18(pretrained=False)
num_features = self.resnet18.fc.in_features
self.basic_net.fc=nn.Linear(num_features, n_classes)
def forward(self, input):
feature = self.basic_net(input)
out = nn.Softmax(feature)
return out
正确的代码应该是:
class model_single(nn.Module):
def __init__(self,n_classes):
super(model_single, self).__init__()
self.basic_net = models.resnet18(pretrained=True)
num_features = self.basic_net.fc.in_features
self.basic_net.fc=nn.Sequential(
nn.Linear(num_features, n_classes),
nn.Softmax()
)
def forward(self, input):
out=self.basic_net(input)
return out
两者的区别在于,前者的返回是list,后者返回是tensor。
其原理是:在Pytorch中,nn.Sequentail()相当于一个有序的容器,神经网络模块将按照在传入构造器的顺序依次被添加到计算图中执行。