关于pytorch直接加载resnet50模型及模型参数

1.由于与resnet50的分类数不一样,所以在调用时,要使用num_classes=分类数

model = torchvision.models.resnet50(pretrained=True,num_classes=5000)   #pretrained=True 既要加载网络模型结构,又要加载模型参数

如果需要加载模型本身的参数,需要使用pretrained=True

2.由于最后一层的分类数不一样,所以最后一层的参数数目也就不一样,所以在加载模型参数时要去掉最后一层

def _resnet(
    arch: str,
    block: Type[Union[BasicBlock, Bottleneck]],
    layers: List[int],
    pretrained: bool,
    progress: bool,
    **kwargs: Any
) -> ResNet:
    model = ResNet(block, layers, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        
        for k in list(state_dict.keys()):  #固定遍历对象
            print(k)
            if k == "fc.weight" or k == "fc.bias":
                state_dict.pop(k)  #删除最后一层的模型参数
         
        
        model.load_state_dict(state_dict,strict=False)  #非严格加载模型参数
    return model

由于字典中的元素是不固定的,所以在遍历的时候需要使用list,将其变为列表,这样元素位置就固定了,才可以进行后面的pop操作。

关于pytorch直接加载resnet50模型及模型参数_第1张图片

 由于没有加载最后一层,所以参数中需要加上strict=False

3.总结一下如何调用pytorch框架中已有的模型及其参数(如果是分类器,且最后一层分类数不一样)

a.实例化model

 b.点击resnet50,到源文件中去修改去除最后一层参数

你可能感兴趣的:(pytorch,深度学习,python)