pytorch从预训练权重加载完全相同的层

saved = torch.load(cfg.Transfer, map_location=device)
# model.load_state_dict(saved['state_dict'])
old_state_dict = saved['state_dict']

# 新建一个空的字典,用于存储新模型加载的权重
new_state_dict = {}

# 将旧模型中相同层的权重复制到新模型中
all_layer=len(model.state_dict())
num=0
for key in model.state_dict():
    if key in old_state_dict and old_state_dict[key].shape==model.state_dict()[key].shape:
        new_state_dict[key] = old_state_dict[key]
        num+=1
    else:
        new_state_dict[key] = model.state_dict()[key]

# 使用load_state_dict()加载新模型的部分权重
model.load_state_dict(new_state_dict)
print("从预训练模型中加载了{}/{}层".format(num,all_layer))

你可能感兴趣的:(pytorch,人工智能,python)