PyTorch: Loading Partial Weights

How to load only some of the layer weights from a saved PyTorch model, i.e. skip specified layers when loading.

1. Model Overview

The HardNet network structure is as follows (imports and the class declaration are added here for completeness; `weights_init` is a custom initializer defined elsewhere):

import torch
import torch.nn as nn

class HardNet(nn.Module):
    def __init__(self):
        super(HardNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(32, affine=False),
            nn.ReLU(),

            nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(32, affine=False),
            nn.ReLU(),

            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(64, affine=False),
            nn.ReLU(),

            nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(64, affine=False),
            nn.ReLU(),

            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128, affine=False),
            nn.ReLU(),

            nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(128, affine=False),
            nn.ReLU(),

            nn.Dropout(0.3),
            nn.Conv2d(128, 128, kernel_size=8, bias=False),
            nn.BatchNorm2d(128, affine=False),
        )

        # weights_init is a custom initialization function defined elsewhere
        self.features.apply(weights_init)

    def forward(self, input):
        features = self.features(input)
        return features

The trained HardNet weights are saved under the file name

model-final-a1-e20.pt

Removing the final block yields the partial model HardNet_old:

class HardNet_old(nn.Module):
    def __init__(self):
        super(HardNet_old, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(32, affine=False),
            nn.ReLU(),

            nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(32, affine=False),
            nn.ReLU(),

            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(64, affine=False),
            nn.ReLU(),

            nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(64, affine=False),
            nn.ReLU(),

            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128, affine=False),
            nn.ReLU(),

            nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(128, affine=False),
            nn.ReLU(),
        )

2. Test

The approach is to load the weights with model.load_state_dict(pretrained_dict, strict=False):

import copy
import torch

# Load the pretrained HardNet weights
checkpoint_name_matching = './model-final-a1-e20.pt'
checkpoint_matching = torch.load(checkpoint_name_matching, map_location={'cuda:0': 'cuda:0'})

net = HardNet_old().cuda(0)  # the modified HardNet with the last block removed
net_old = copy.deepcopy(net.state_dict())  # PyTorch initializes the parameters automatically
# print(net_old.items())
net.load_state_dict(checkpoint_matching, strict=False)
net_new = copy.deepcopy(net.state_dict())  # inspect the HardNet_old parameters after loading
# net.load_state_dict(checkpoint_matching)  # strict=True (the default) would raise an error here
print(isinstance(net_old, dict), '\n')  # True
print(isinstance(net_new, dict), '\n')  # True

Stepping through in the debugger shows that the parameters in net_old and net_new differ, i.e. the pretrained weights were transferred, so the method works. If strict is set to True instead, loading raises an error, because the checkpoint contains keys for the removed last block that HardNet_old does not have.
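The behavior of strict=False can be illustrated with a plain-dict sketch. This mimics the matching rule, it is not PyTorch's actual implementation, and the keys and shapes below are illustrative only:

```python
def load_partial(model_state, checkpoint):
    """Copy an entry only when its key exists in both dicts; report
    extra/missing keys instead of raising (as strict=False does)."""
    missing = [k for k in model_state if k not in checkpoint]
    unexpected = [k for k in checkpoint if k not in model_state]
    for k in model_state:
        if k in checkpoint:
            model_state[k] = checkpoint[k]
    return missing, unexpected

# HardNet_old keeps keys like features.0 .. features.17, while the
# checkpoint also holds the removed last-layer key features.19.
model_state = {"features.0.weight": [0.0] * 4, "features.3.weight": [0.0] * 4}
checkpoint = {"features.0.weight": [1.0] * 4, "features.3.weight": [2.0] * 4,
              "features.19.weight": [3.0] * 4}
missing, unexpected = load_partial(model_state, checkpoint)
print(unexpected)  # ['features.19.weight'] -- skipped instead of raising
```

With strict=True the presence of any entry in `unexpected` (or `missing`) would be an error; with strict=False the overlapping keys are copied and the rest are ignored.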

3. Explanation

PyTorch stores weights in a dict-like structure, and load_state_dict transfers a tensor only when the layer's name in the model matches a key in the checkpoint; non-matching entries are not transferred. For example, if in the model above we instead kept only the last block, loading with model.load_state_dict(pretrained_dict, strict=False) would leave net_old and net_new identical, meaning no parameters were transferred. This is because layers built inside an nn.Sequential are named after the attribute (features) followed by a dot and the module's index, and those indices would no longer line up with the checkpoint's.
To be able to transfer the parameters of specific layers, it is advisable to name each block separately, as follows:

        self.e1 = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(32, affine=False),
            nn.ReLU(),
        )
        self.e2 = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(32, affine=False),
            nn.ReLU(),
        )
        self.e3 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(64, affine=False),
            nn.ReLU(),
        )
        self.e4 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(64, affine=False),
            nn.ReLU(),
        )
        self.e5 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128, affine=False),
            nn.ReLU(),
        )
        self.e6 = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(128, affine=False),
            nn.ReLU(),
        )
        self.e7 = nn.Sequential(
            nn.Dropout(0.3),
            nn.Conv2d(128, 128, kernel_size=8, bias=False),
            nn.BatchNorm2d(128, affine=False),
        )
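With per-block names, the state_dict keys change from features.N to e{block}.{index}, where indexing restarts at 0 inside each inner Sequential. Since every block above contains exactly three modules, an old checkpoint key can be translated with a small hypothetical helper (a sketch specific to this layout, not a general-purpose tool) before loading:

```python
def remap(old_key):
    """Map a key from the single-Sequential design, e.g.
    'features.15.weight', to the per-block name, e.g. 'e6.0.weight'.
    Assumes every block holds exactly three modules, as above."""
    _, idx, rest = old_key.split(".", 2)
    idx = int(idx)
    block = idx // 3 + 1   # which e-block the module falls into
    offset = idx % 3       # position inside that block
    return f"e{block}.{offset}.{rest}"

print(remap("features.0.weight"))   # e1.0.weight  (first conv)
print(remap("features.19.weight"))  # e7.1.weight  (last conv, after the Dropout)
```

Remapping the checkpoint's keys this way lets you choose exactly which blocks receive pretrained weights, rather than relying on positional indices matching by accident.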
