PyTorch：将部分模型参数迁移到新模型

def func_converse(src_path=r'./resbinet/ressgd_child.pth', dst_path='sgd.pth', model=None):
    """Copy compatible pretrained weights into a fresh model and save the result.

    A checkpoint entry is transferred only when its name exists in the target
    model AND its tensor shape matches the target's. All other parameters keep
    the target model's freshly initialised values. Without the shape check, a
    single same-named layer with a different size (e.g. a resized output head)
    would make ``load_state_dict`` raise ``RuntimeError``.

    Args:
        src_path: Checkpoint to read. Assumed to be a plain state dict
            (parameter name -> tensor) -- TODO confirm for other checkpoints.
        dst_path: File the merged state dict of the target model is saved to.
        model: Target ``nn.Module``. When ``None`` (the default), a
            ``deletedres.ResBiNet(n_classes=1, heads=...)`` is built with the
            original head configuration.

    Returns:
        Sorted list of parameter names that were copied from the checkpoint.
    """
    if model is None:
        heads = {'hm': 1, 'vaf': 2, 'haf': 1}
        model = deletedres.ResBiNet(n_classes=1, heads=heads)

    # map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only machine;
    # load_state_dict below copies the values into the model's own tensors.
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    pretrained = torch.load(src_path, map_location='cpu')
    model_dict = model.state_dict()

    # Keep only entries that the target model can actually accept.
    transferable = {
        k: v for k, v in pretrained.items()
        if k in model_dict and v.shape == model_dict[k].shape
    }
    skipped = sorted(set(pretrained) - set(transferable))
    print('loaded:', sorted(transferable))
    if skipped:
        print('skipped (missing in model or shape mismatch):', skipped)

    # Overlay the transferable weights on the model's own state so that every
    # key is present and load_state_dict's strict check passes.
    model_dict.update(transferable)
    model.load_state_dict(model_dict)
    torch.save(model.state_dict(), dst_path)
    return sorted(transferable)

你可能感兴趣的:(PYTHON,pytorch,深度学习,人工智能)