Resnet 迁移学习记录

在实际应用中,cnn网络的训练是很繁琐且浪费时间的,这时候我们一般会去选择加载网上已经训练得很完善的网络作为自己的cnn网络层,下面例子为使用Resnet预训练模型来做自己的图片分类:

# 网络定义 
class Resnet(nn.Module):

    def __init__(self):
        super(Resnet, self).__init__()
        pretrained_net = torchvision.models.resnet18(pretrained=True)
        model =nn.Sequential(*list(pretrained_net.children())[:-1])
        self.model = model
        self.Linear = nn. Linear(in_features=512, out_features=10, bias=True)
    def forward(self, x):
        x=self.model(x)
        # 这里有个bug,在下载的预训练网络最后一层中,只显示了线性层,但是如果你直接添加一个线性层,会报错,原因为维度的不一致,需要view到适配维度。
        x = x.view(-1, 512)
        x=self.Linear(x)
        return x
X = torch.rand(size=(1, 3, 224, 224))
model=Resnet()
print(model)
print(model(X).shape)

然后进行训练,对比于之前的自己构建的网络重新训练来看,会发现收敛特别快,且很容易得到自己想要的ACC。

[1,   500] loss: 1.301
train_correct=
0.582
train time: 26.080318927764893
Accuracy on test set: 74.03  %
[1,  1000] loss: 0.772
train_correct=
0.7533333333333333
train time: 93.63580584526062
Accuracy on test set: 75.72  %
[1,  1500] loss: 0.697
train_correct=
0.7773333333333333
train time: 159.9850172996521
Accuracy on test set: 78.19  %

这里只跑了1500x3张图,连1/10个epoch都没跑到,但效果已经很强大了,且收敛得非常快速。

你可能感兴趣的:(pytorch深度学习,迁移学习,深度学习,pytorch)