pytorch 加载预训练模型

pytorch的torchvision中给出了很多经典的预训练模型,模型的参数和权重都是在ImageNet数据集上训练好的

加载模型
方法一:直接使用预训练模型中的参数

import torchvision.models as models
model = models.resnet18(pretrained = True) #pretrained设为True,表示使用在ImageNet上训练好的参数

方法二:使用本地磁盘上的参数(直接下载的pth文件或者是在自己数据集上训练好的参数)

import torchvision.models as models
model = models.resnet18(pretrained = False) #pretrained设为False
state_dict = torch.load('resnet18.pth') #使用本地磁盘上的模型参数文件
model.load_state_dict(state_dict) #把读入的模型参数加载到模型中

修改模型
因为预训练模型是在ImageNet数据集上训练的,而ImageNet一共有1000个类别,如果我们要训练的数据集只有20个类别,这时就需要修改模型的全连接层

import torchvision.models as models
model = models.resnet18(pretrained=True)
num_classes = 20 #自己的数据集的类别
inchannel = model.fc.in_features
model.fc = nn.Linear(inchannel, num_classes) #修改全连接层

总结

import torchvision.models as models
model = models.resnet18(pretrained=True)
for p in model.parameters(): 
    p.requires_grad = False #设为False表示只训练最后全连接层的权重,其余层不训练
num_classes = 20 #自己的数据集的类别
inchannel = model.fc.in_features
model.fc = nn.Linear(inchannel, num_classes) #修改全连接层
model = nn.DataParallel(model).cuda() #gpu训练
model.eval() #排除BN和Dropout的影响

你可能感兴趣的:(深度学习)