加载pytorch已有模型,修改最后分类头

在加载pytorch已有模型的时候,我们必须要明确的事情:

1 如何获取到pytorch所提供的模型,通过什么方式。
2 模型的结构,也就是模型的每个层的名字(key)。
3 我们要把需要加载的模型,尽量封装成一个类。

下面我们针对上面来给出答案。
答1:以 resnet18 举例

# ----------------1 导入库 -----------------
import torchvision.models as model
# -------------2 将resnet18导入到新模型。--------------
base_model = 'resnet18'
if 'resnet' in base_model:
	model = getattr(model,base_model)

答2 :我们在了解模型的时候,经常使用 dict(model.named_parameters()),它会返回一个字典,我们通过 字典.items()来得到字典的key和value值。我们要知道最后一层的分类层名字叫什么。

for (key,value) in dict(model.named_parameters()).items():
    print(key)

加载pytorch已有模型,修改最后分类头_第1张图片
最后一层的名字叫 fc ,这样我们可以通过最后一层的名字来修改最后一层。

num_class = 51
fc = getattr(model, 'fc')
feature_dim = fc.in_features
setattr(model,'fc',nn.Linear(feature_dim,num_class))
print(model)

在这里插入图片描述
这样就把最后一层修改完成了。

答3 :最后封装成新的模型类

import torchvision.models as model
import torch.nn as nn


class Model(nn.Module):
    def __init__(self, num_class,base_model= 'resnet18'):
        super().__init__()
        self._prepare_base_model(num_class = num_class,base_model = base_model )
  
    def _prepare_base_model(self, base_model,num_class):  
        if 'resnet' in base_model:
            self.model = getattr(model, base_model)(pretrained=True)
            feature_dim = getattr(self.model, 'fc').in_features
            setattr(self.model,'fc',nn.Linear(feature_dim,num_class))      
        else:
            raise ValueError('Unknown base model: {}'.format(base_model))
            
    def forward(self,x):
        out = self.model(x)
        return out

net = Model(num_class=51,base_model='resnet18')
print(net)

在这里插入图片描述

你可能感兴趣的:(pytorch,深度学习,pytorch,深度学习)