PyTorch: removing the last layer (or a specific layer) from a pretrained model

Removing a given layer from one of the official pretrained models is already covered at https://blog.csdn.net/KHFlash/article/details/82345441.
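As a quick reminder of that case, here is a minimal sketch assuming torchvision's resnet18 (in newer torchvision releases the pretrained=True flag is spelled weights=...):

import torch.nn as nn
import torchvision.models as models

resnet = models.resnet18(pretrained=True)

# Option 1: replace the final fully-connected layer with a pass-through,
# so the forward pass returns the pooled features instead of class logits.
resnet.fc = nn.Identity()

# Option 2: rebuild the network without its last child module.
backbone = nn.Sequential(*list(resnet.children())[:-1])

This post is mainly about the non-official case, i.e. a checkpoint saved from your own model definition, as follows: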

import torch
from collections import OrderedDict
import os
import torch.nn as nn
import torch.nn.init as init
from xxx import new_VGG

def init_weight(modules):
    #standard initialisation for layers NOT covered by the pretrained weights:
    #Xavier for conv, constants for BatchNorm, small Gaussian for Linear
    for m in modules:
        if isinstance(m, nn.Conv2d):
            init.xavier_uniform_(m.weight.data)
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.BatchNorm2d):
            m.weight.data.fill_(1)
            m.bias.data.zero_()
        elif isinstance(m, nn.Linear):
            m.weight.data.normal_(0, 0.01)   #note the in-place normal_()
            m.bias.data.zero_()

def copyStateDict(state_dict):
    #checkpoints saved from nn.DataParallel prefix every key with 'module.';
    #strip that prefix so the keys match a plain single-GPU model
    if list(state_dict.keys())[0].startswith('module'):
        start_idx = 1
    else:
        start_idx = 0
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = '.'.join(k.split('.')[start_idx:])   #rejoin with '.', not ','
        new_state_dict[name] = v
    return new_state_dict

#load the pretrained checkpoint (pass map_location='cpu' to torch.load if it was saved on GPU)
state_dict = torch.load('/users/xxx/xxx.pth')

new_dict = copyStateDict(state_dict)
keys = []
for k,v in new_dict.items():
    if k.startswith('conv_cls'):    #skip keys starting with 'conv_cls' -- these belong to the layer(s) being removed
        continue
    keys.append(k)

#state dict with the specified layer(s) removed
new_dict = {k:new_dict[k] for k in keys}

net = new_VGG()   #your own model definition; the kept layers must match the checkpoint's layers in name and shape

#load the pretrained parameters into the new model; the filtered-out layers get
#no pretrained weights, so run init_weight on them before using the model.
#(net.state_dict().update(...) on its own would not change the model -- update a
#copy of the state dict and load it back with load_state_dict)
model_dict = net.state_dict()
model_dict.update(new_dict)
net.load_state_dict(model_dict)

#save the model with the specified layer(s) removed
torch.save(net.state_dict(), '/users/xxx/xxx.pth')
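An alternative to the update()/load_state_dict pattern above is load_state_dict with strict=False, which in recent PyTorch versions also reports which parameters were not covered by the checkpoint, i.e. exactly the layers that still need init_weight. A small sketch under the same assumptions as the code above (the 'conv_cls' prefix, the placeholder path, and new_VGG no longer defining a conv_cls block are all just this post's examples):

#load only the keys present in new_dict and list what is still missing
result = net.load_state_dict(new_dict, strict=False)
print(result.missing_keys)      #these layers only have their default init -> run init_weight on them

#sanity check: the re-saved checkpoint should no longer contain the removed layer
saved = torch.load('/users/xxx/xxx.pth')
assert not any(k.startswith('conv_cls') for k in saved)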

This was typed out by hand (I couldn't copy the code off the server), so if you spot any mistakes please point them out and I'll correct them.
