pytorch 迁移学习加载预训练模型,并进行修改

在迁移学习中,要加载预训练模型,如果是torch内置的一些模型网上有很多的方法很简单,但是当加载自己训练完成的模型以后,如何解决最后连接层输入维度不一致的问题,看了好几个帖子都不成功,最后看了一下,既然是将参数param转换成dict,不如直接将不需要的删掉即可。
于是可以用如下方法

	pretrained_params = torch.load(path)
    net = YOUR_OWN_MODEL()
    state_dict = pretrained_params.state_dict()
    del state_dict[xxx] #需要被删掉的参数
    net.load_state_dict(state_dict, strict=False)

简单有效!

你可能感兴趣的:(深度学习,tensorflow,深度学习,pytorch)