pytorch:加载预训练模型中的部分参数,并固定该部分参数

pytorch:加载预训练模型中的部分参数,并固定该部分参数

https://www.jianshu.com/p/d67d62982a24

		initial_cnn = models.densenet121(pretrained=False)
        self.cnn = torch.nn.Sequential(*(list(initial_cnn.children())[:-1]))
        device = torch.device('cuda')
        cnn = self.cnn.to(device)
        cnn = nn.DataParallel(cnn)

        model_dict = cnn.state_dict()

        pretrained_dict = torch.load('epoch_11.pth')

        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}

        model_dict.update(pretrained_dict)

        cnn.load_state_dict(model_dict)

你可能感兴趣的:(pytorch,深度学习,机器学习)