Loading Pretrained Models in PyTorch

  1. Loading a single-GPU model
import torch
model = net()  # net is your own model class
pretrained_dict = torch.load("abc.pth")
model.load_state_dict(pretrained_dict)
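A checkpoint saved on a GPU machine may need to be loaded on a machine without one. The sketch below shows one way to handle that with the `map_location` argument of `torch.load`. `TinyNet` and the temporary file path are hypothetical stand-ins for your own `net` and checkpoint, not part of the original example.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for the article's `net`."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


# Save only the weights (the state_dict), not the whole model object.
source = TinyNet()
path = os.path.join(tempfile.mkdtemp(), "abc.pth")
torch.save(source.state_dict(), path)

# map_location="cpu" remaps tensors that were saved on a GPU onto the CPU.
pretrained_dict = torch.load(path, map_location="cpu")

model = TinyNet()
result = model.load_state_dict(pretrained_dict)  # strict=True by default
model.eval()  # switch to inference mode before evaluating

weights_match = torch.equal(model.fc.weight, source.fc.weight)
```

`load_state_dict` returns an object with `missing_keys` and `unexpected_keys`, which is a quick way to confirm that every parameter was matched.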
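  1. Loading a multi-GPU model
import torch
model = torch.nn.DataParallel(net())  # wrap the model so that model.module exists
pretrained_dict = torch.load("m_abc.pth")
# Load through .module when the checkpoint keys have no "module." prefix;
# if they do have it, call model.load_state_dict(pretrained_dict) instead.
model.module.load_state_dict(pretrained_dict)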
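Checkpoints saved from a `DataParallel`-wrapped model carry a `module.` prefix on every key, so they do not load directly into an unwrapped model. A minimal sketch of one common workaround is to strip that prefix first. `TinyNet` and `strip_module_prefix` are hypothetical names introduced here for illustration.

```python
from collections import OrderedDict

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for the article's `net`."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


def strip_module_prefix(state_dict):
    """Hypothetical helper: drop a leading "module." from each key."""
    prefix = "module."
    return OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in state_dict.items()
    )


# Wrapping in DataParallel prefixes every state_dict key with "module.".
wrapped = nn.DataParallel(TinyNet())
multi_gpu_dict = wrapped.state_dict()

# Remove the prefix so the keys match a plain, unwrapped model.
single_gpu_dict = strip_module_prefix(multi_gpu_dict)
model = TinyNet()
model.load_state_dict(single_gpu_dict)

keys = list(single_gpu_dict.keys())
same_weights = torch.equal(model.fc.weight.cpu(), wrapped.module.fc.weight.cpu())
```

Saving `model.module.state_dict()` in the first place avoids the prefix entirely, which keeps the checkpoint usable with or without `DataParallel`.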
  1. Loading only part of the pretrained parameters
import torch
model = net()
pretrained_dict = torch.load("abc.pth")
model_dict = model.state_dict()
# filter out unnecessary keys
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)
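Filtering by key alone still fails when a key exists in both models but the tensor shapes differ, for example a classifier head with a different number of classes. The sketch below adds a shape check to the same filter. `Backbone` and the class counts are hypothetical and only serve to make the example runnable.

```python
import torch
import torch.nn as nn


class Backbone(nn.Module):
    """Hypothetical network whose classifier size depends on num_classes."""

    def __init__(self, num_classes):
        super().__init__()
        self.features = nn.Linear(8, 16)
        self.classifier = nn.Linear(16, num_classes)

    def forward(self, x):
        return self.classifier(torch.relu(self.features(x)))


# Pretend these weights came from a checkpoint trained on 10 classes.
pretrained_dict = Backbone(num_classes=10).state_dict()

# The new task has 3 classes, so the classifier shapes no longer match.
model = Backbone(num_classes=3)
model_dict = model.state_dict()

# Keep entries whose key exists in the new model and whose shape matches.
compatible = {
    k: v
    for k, v in pretrained_dict.items()
    if k in model_dict and v.shape == model_dict[k].shape
}
skipped = sorted(set(pretrained_dict) - set(compatible))

# Overwrite the matching entries, then load the merged state dict.
model_dict.update(compatible)
model.load_state_dict(model_dict)
```

Here `skipped` lists the classifier weight and bias, which keep their fresh initialization. Passing the filtered dictionary with `model.load_state_dict(compatible, strict=False)` is an alternative that skips the merge step.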
