模型=网络结构+网络参数
网络结构:VGG ,RASNet……
网络参数:网络结构中kernel,weight之类的数据
我们有时候会借鉴别人的网络,大多数情况下他们会在GitHub中放上已经训练好的模型,这个时候,你就可以下载下来直接用。
ps:大多都在README的文件中会放一个云盘的链接,点链接下载。
if ite_num % save_frq == 0:
print(model_dir + model_name+"_bce_itr_%d_train_%3f_tar_%3f.pth" % (ite_num, running_loss / ite_num4val, running_tar_loss / ite_num4val))
torch.save(net, model_dir + model_name+"_bce_itr_%d_train_%3f_tar_%3f.pth" % (ite_num, running_loss / ite_num4val, running_tar_loss / ite_num4val))
我们可以设置一个阈值,规定都是个跑多少个ite的时候保存模型。
保存模型的时候最好把模型的相关数据一起加上去,便于后期的查看。
保存模型:torch.save(net,path)
保存参数:torch.save(net.state_dict(),path)
个人觉的保存模型比保存参数好
因为模型包含参数,可以从模型中读取参数(下面第3点介绍)。
而且保存的pth大小差距不大。
下载模型:net=torch.load("D:/1.pth",map_location='cpu')
加载参数:
net=VGG() #首先你要得到相关网络的模型,里面的参数会自动随机初始化`
net.load_state_dict(torch.load("D:/2.pth",map_location='cpu) #更新得到更优的网络
shu_ju=net.state_dict()
先冻结参数,然后修改优化器。
冻结参数有两种方法:
1.在原来的参数基础上直接与改进的参数一起训练。
2.冻结原来的数据,训练改机的网络一段时间后再解冻一起训练。
这个时候就要用三.2中的加载参数了,不同的是,我们在load_state_dict的时候添加了False的关键字,表示不严格的加载,即只加载关键字相同的参数(这里的False的F记得大写。。。踩过坑)。
net=VGG() #首先你要得到相关网络的模型,里面的参数会自动随机初始化`
net.load_state_dict(torch.load("D:/2.pth",map_location='cpu,strict=False) #更新得到更优的网络
这样就可以把net扔进循环开始训练了
在修改某个卷积核的大小,发现设置不严格加载还是加载不了。
个人觉得,之前加载的上去是因为新网络中的k,v不存在于要加载的模型数据中,所以不会报错。
而这次我只改了卷积核大小,而对应的k,v还是存在要加载的模型数据中,在更新的时候匹配不上导致报错。
一种方法:
https://blog.csdn.net/hxxjxw/article/details/119491163
如果发生上述情况的话,那就需要把加载到的模型的中,不匹配的那几项删掉,然后加载其他项
x = torch.load(self.weight)
del x['char_recognizer.classifier.bias']
del x['char_recognizer.classifier.weight']
self.load_state_dict(x, strict=False)
法二(推荐):
model = ResNet(Bottleneck, [3, 4, 6, 3])
if pretrained:
state_dict = load_state_dict_from_url("https://download.pytorch.org/models/resnet50-19c8e357.pth",
model_dir="./model_data")
pretrained_dict = {k: v for k, v in state_dict.items() if k in model}
model.load_state_dict(pretrained_dict)
model = ResNet(Bottleneck, [3, 4, 6, 3])
if pretrained:
state_dict = load_state_dict_from_url("https://download.pytorch.org/models/resnet50-19c8e357.pth",
model_dir="./model_data")
pretrained_dict = {k: v for k, v in state_dict.items() if k in model}
model.load_state_dict(pretrained_dict)
# ----------------------------------------------------------------------------#
# 获取特征提取部分,从conv1到model.layer3,最终获得一个38,38,1024的特征层
# ----------------------------------------------------------------------------#
features = list([model.conv1, model.bn1, model.relu, model.maxpool, model.layer1, model.layer2, model.layer3])
# ----------------------------------------------------------------------------#
# 获取分类部分,从model.layer4到model.avgpool
# ----------------------------------------------------------------------------#
classifier = list([model.layer4, model.avgpool])
features = nn.Sequential(*features)
classifier = nn.Sequential(*classifier)
可以通过nn.Sequential(*features)的方式来从已经训练好的网络中获取对应的参数。
这样可以便于自己的训练,并且他们训练的网络的肯定比我们训练的要好。
更新(2022年5月15日14:08:16)
怎么冻结?——> 神经网络如何训练?——> 依靠方向传播 ——> 梯度
网络中有一个requires_grad的参数的可以更改,如果为False的话,就可以达到冻结效果。
for p in self.parameters():
p.requires_grad = False
比如加载了resnet预训练模型之后,在resenet的基础上连接了新的模快,resenet模块那部分可以先暂时冻结不更新,只更新其他部分的参数,那么可以在下面加入上面那句话。
class RESNET_MF(nn.Module):
def __init__(self, model, pretrained):
super(RESNET_MF, self).__init__()
self.resnet = model(pretrained)
for p in self.parameters():
p.requires_grad = False
self.f = SpectralNorm(nn.Conv2d(2048, 512, 1))
self.g = SpectralNorm(nn.Conv2d(2048, 512, 1))
self.h = SpectralNorm(nn.Conv2d(2048, 2048, 1))
...
我个人认为也可以这样:
法一:
net=load("D:/2.pth",map_loaction='cpu')
for key, value in net.named_parameters():
value.requires_grad = False
先直接把原来的模型中的requires_grad 全改为False,然后再更新到新的网络中。。
法二:
直接更新网络
net.load_state_dirc(torch.load(D:/1.pth),False)
#跳过改进的地方
for key, value in net.named_parameters():
if not ("beta" in key or "gamma" in key or 'alpha' in key):
value.requires_grad = True
在优化器中添加:
filter(lambda p: p.requires_grad, model.parameters())
用于过滤冻结的参数
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001, betas=(0.9, 0.999),
eps=1e-08, weight_decay=1e-5)
训练一段时间后,我们就可以解冻,然后进行微调网络。
model的配置没有错误,但是就是加载不上。报错的时候,keys前面多了’model‘或者’modul’
那是多卡训练导致,可以参考下面这个链接
PyTorch单卡/多卡下模型保存/加载