I have been writing this series on and off for quite a while, and this is the final installment. It has been exhausting to write, and posts like this do not attract many readers, so I may switch to hands-on projects next.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

# Generate toy data: y = x^3 plus noise
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y = x.pow(3) + 0.1 * torch.randn(x.size())
x, y = (Variable(x), Variable(y))

class Net(nn.Module):
    def __init__(self, n_input, n_hidden, n_output):
        super(Net, self).__init__()
        self.hidden1 = nn.Linear(n_input, n_hidden)
        self.hidden2 = nn.Linear(n_hidden, n_hidden)
        self.predict = nn.Linear(n_hidden, n_output)

    def forward(self, input):
        out = self.hidden1(input)
        out = F.relu(out)
        out = self.hidden2(out)
        out = F.relu(out)
        out = self.predict(out)
        return out

net = Net(1, 2, 1)  # hidden size 2, matching the state_dict printed below and the later examples
optimizer = torch.optim.SGD(net.parameters(), lr=0.05)
loss_func = torch.nn.MSELoss()

for t in range(200):
    prediction = net(x)
    loss = loss_func(prediction, y)
    if t % 10 == 0:
        print('LOSS:', loss.data)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

torch.save(net, 'net.pkl')
torch.save(net.state_dict(), 'net_parameter.pkl')
Above we built a three-layer fully connected network and saved it with torch.save(). Here:
torch.save(net, path) --> saves the entire model;
torch.save(net.state_dict(), path) --> saves only the model parameters.
This series is meant to walk through the nn.Module source code, so I will explain how Module produces what gets saved.
torch.save(net, path) simply serializes (pickles) the whole net object, so there is nothing Module-specific to cover there.
Here is the source behind net.state_dict():
def _save_to_state_dict(self, destination, prefix, keep_vars):
    for name, param in self._parameters.items():
        if param is not None:
            destination[prefix + name] = param if keep_vars else param.detach()
    for name, buf in self._buffers.items():
        if buf is not None and name not in self._non_persistent_buffers_set:
            destination[prefix + name] = buf if keep_vars else buf.detach()

def state_dict(self, destination=None, prefix='', keep_vars=False):
    if destination is None:
        destination = OrderedDict()
        destination._metadata = OrderedDict()
    destination._metadata[prefix[:-1]] = local_metadata = dict(version=self._version)
    self._save_to_state_dict(destination, prefix, keep_vars)
    for name, module in self._modules.items():
        if module is not None:
            module.state_dict(destination, prefix + name + '.', keep_vars=keep_vars)  # recurse into child modules
    for hook in self._state_dict_hooks.values():
        hook_result = hook(self, destination, prefix, local_metadata)
        if hook_result is not None:
            destination = hook_result
    return destination
The logic in brief: an ordered dict is created first, _save_to_state_dict() then adds this module's own parameters and buffers to it, and finally state_dict() is called recursively on every child module (with the child's name appended to the prefix), so the parameters of the whole model are collected one by one. Hooks will be covered in a later post.
At this point you can print(net.state_dict()) and get something like:
OrderedDict([('hidden1.weight', tensor([[-0.8167], [ 0.7373]])),
             ('hidden1.bias',   tensor([0.2706, 0.4230])),
             ('hidden2.weight', tensor([[0.4133, 0.0363], [0.2908, 0.4775]])),
             ('hidden2.bias',   tensor([-0.5193, 0.1890])),
             ('predict.weight', tensor([[-0.4416, 0.0453]])),
             ('predict.bias',   tensor([-0.0529]))])
So a state_dict is really just an ordered dictionary.
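To see where the dotted key names come from, here is a small sketch (the nested modules Inner/Outer are made up purely for illustration) showing how each level of nesting prepends its attribute name to the prefix, exactly as module.state_dict(destination, prefix + name + '.') does above:

import torch
import torch.nn as nn

# Hypothetical nested modules, just to illustrate key prefixing.
class Inner(nn.Module):
    def __init__(self):
        super(Inner, self).__init__()
        self.fc = nn.Linear(2, 2)

class Outer(nn.Module):
    def __init__(self):
        super(Outer, self).__init__()
        self.block = Inner()
        self.head = nn.Linear(2, 1)

outer = Outer()
# Each level of nesting adds "name." to the prefix.
print(list(outer.state_dict().keys()))
# -> ['block.fc.weight', 'block.fc.bias', 'head.weight', 'head.bias']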
Loading is the mirror image:
torch.load(path) --> loads the entire model;
net.load_state_dict(torch.load(path)) --> loads the model parameters.
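Before digging into the source, here is a minimal sketch of both loading styles, assuming the net.pkl and net_parameter.pkl files saved above. Note that torch.load on a whole pickled model still needs the Net class to be importable, whereas the state_dict route requires you to rebuild the model yourself first:

# Option 1: load the entire pickled model (the Net class must be importable)
net_full = torch.load('net.pkl')

# Option 2: rebuild the architecture, then load only its parameters
net_restored = Net(1, 2, 1)
net_restored.load_state_dict(torch.load('net_parameter.pkl'))
net_restored.eval()  # switch to eval mode before running inference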
As above, let's look at the source of net.load_state_dict.
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                          missing_keys, unexpected_keys, error_msgs):
    for hook in self._load_state_dict_pre_hooks.values():
        hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

    persistent_buffers = {k: v for k, v in self._buffers.items()
                          if k not in self._non_persistent_buffers_set}
    local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
    local_state = {k: v for k, v in local_name_params if v is not None}

    for name, param in local_state.items():
        key = prefix + name
        if key in state_dict:
            input_param = state_dict[key]
            # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
            if len(param.shape) == 0 and len(input_param.shape) == 1:
                input_param = input_param[0]
            if input_param.shape != param.shape:
                # local shape should match the one in checkpoint
                error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
                                  'the shape in current model is {}.'
                                  .format(key, input_param.shape, param.shape))
                continue
            try:
                with torch.no_grad():
                    param.copy_(input_param)
            except Exception as ex:
                error_msgs.append('While copying the parameter named "{}", '
                                  'whose dimensions in the model are {} and '
                                  'whose dimensions in the checkpoint are {}, '
                                  'an exception occurred : {}.'
                                  .format(key, param.size(), input_param.size(), ex.args))
        elif strict:
            missing_keys.append(key)

    if strict:
        for key in state_dict.keys():
            if key.startswith(prefix):
                input_name = key[len(prefix):]
                input_name = input_name.split('.', 1)[0]  # get the name of param/buffer/child
                if input_name not in self._modules and input_name not in local_state:
                    unexpected_keys.append(key)
def load_state_dict(self, state_dict: Union[Dict[str, Tensor], Dict[str, Tensor]],
                    strict: bool = True):
    missing_keys = []
    unexpected_keys = []
    error_msgs = []

    # copy state_dict so _load_from_state_dict can modify it
    metadata = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    def load(module, prefix=''):
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        module._load_from_state_dict(
            state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + '.')

    load(self)
    load = None  # break load->load reference cycle

    if strict:
        if len(unexpected_keys) > 0:
            error_msgs.insert(
                0, 'Unexpected key(s) in state_dict: {}. '.format(
                    ', '.join('"{}"'.format(k) for k in unexpected_keys)))
        if len(missing_keys) > 0:
            error_msgs.insert(
                0, 'Missing key(s) in state_dict: {}. '.format(
                    ', '.join('"{}"'.format(k) for k in missing_keys)))

    if len(error_msgs) > 0:
        raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
            self.__class__.__name__, "\n\t".join(error_msgs)))
    return _IncompatibleKeys(missing_keys, unexpected_keys)
The idea in brief: three lists are created: missing_keys, unexpected_keys and error_msgs. If any of them ends up non-empty under strict loading, something went wrong and a RuntimeError is raised. Loading likewise recurses over the child modules, walking the ordered dict and copying each checkpoint value into the parameter with the matching key.
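As a quick illustration, reusing the Net class defined above, deliberately dropping one key from a checkpoint triggers exactly this bookkeeping:

net = Net(1, 2, 1)
good_state = net.state_dict()

# Drop one entry to provoke a "Missing key(s)" error under strict loading.
broken_state = {k: v for k, v in good_state.items() if k != 'predict.bias'}

try:
    net.load_state_dict(broken_state)  # strict=True by default
except RuntimeError as e:
    print(e)  # Error(s) in loading state_dict for Net: Missing key(s) in state_dict: "predict.bias".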
The previous three parts really covered only two things:
(1) Saving a model
torch.save(net, path)
torch.save(net.state_dict(), path)
(2) Loading a model
torch.load(path)
net.load_state_dict(torch.load(path))
But the above is not the most common pattern in practice. During real training we save the parameters after every so many epochs. The problem is that if training is interrupted for some reason, the learning-rate state and the epoch at the point of interruption are lost. To make the program more robust, we therefore save the epoch, the parameters, and the optimizer state together. Okay, taking the code from the first part as an example, the modified code is:
net = Net(1, 2, 1)
start_epoch = -1
optimizer = torch.optim.SGD(net.parameters(), lr=0.05)
loss_func = torch.nn.MSELoss()

for epoch in range(start_epoch + 1, 20):
    prediction = net(x)
    loss = loss_func(prediction, y)
    if epoch % 5 == 0:  # every 5 epochs, print the loss and save a checkpoint
        print('LOSS:', loss.data)
        checkpoint = {
            'net': net.state_dict(),              # model parameters
            'optimizer': optimizer.state_dict(),  # optimizer state
            'epoch': epoch                        # current epoch
        }
        torch.save(checkpoint, 'net%s.pkl' % (str(epoch)))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
Okay, we are still missing the resume-from-checkpoint part. Let's modify it once more:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

# Generate toy data: y = x^3 plus noise
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y = x.pow(3) + 0.1 * torch.randn(x.size())
x, y = (Variable(x), Variable(y))

class Net(nn.Module):
    def __init__(self, n_input, n_hidden, n_output):
        super(Net, self).__init__()
        self.hidden1 = nn.Linear(n_input, n_hidden)
        self.hidden2 = nn.Linear(n_hidden, n_hidden)
        self.predict = nn.Linear(n_hidden, n_output)

    def forward(self, input):
        out = self.hidden1(input)
        out = F.relu(out)
        out = self.hidden2(out)
        out = F.relu(out)
        out = self.predict(out)
        return out

net = Net(1, 2, 1)
start_epoch = -1
optimizer = torch.optim.SGD(net.parameters(), lr=0.05)
loss_func = torch.nn.MSELoss()

Resume = True
if Resume:  # resume from a checkpoint
    checkpoint = torch.load("/home/wujian/leleDetections/Save_and_Load/net10.pkl")
    net.load_state_dict(checkpoint['net'])              # restore the learnable parameters
    optimizer.load_state_dict(checkpoint['optimizer'])  # restore the optimizer state
    start_epoch = checkpoint['epoch']                   # restore the epoch counter

for epoch in range(start_epoch + 1, 20):
    prediction = net(x)
    loss = loss_func(prediction, y)
    if epoch % 5 == 0:  # every 5 epochs, print the loss and save a checkpoint
        print('epoch:%s, LOSS:%.4f' % (str(epoch), loss.data))
        checkpoint = {
            'net': net.state_dict(),              # model parameters
            'optimizer': optimizer.state_dict(),  # optimizer state
            'epoch': epoch                        # current epoch
        }
        torch.save(checkpoint, 'net%s.pkl' % (str(epoch)))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
The Resume flag controls whether the checkpoint is loaded.
In practice we often need someone else's pre-trained model, for example the official vgg.pth. But if you implement the same VGG yourself, loading vgg.pth will almost certainly raise a key-mismatch error. Take a simple example:
checkpoint = torch.load('net_parameter.pkl')
print(checkpoint.keys())
The output is:
odict_keys(['hidden1.weight', 'hidden1.bias', 'hidden2.weight', 'hidden2.bias', 'predict.weight', 'predict.bias'])
From this we can see that our network's parameters are named hidden1, hidden2, and so on.
Now suppose we build a structurally identical network (net1 below) whose submodules just have different names. How do we load the checkpoint then?
class Net1(nn.Module):
    def __init__(self, n_input, n_hidden, n_output):
        super(Net1, self).__init__()
        self.le1 = nn.Linear(n_input, n_hidden)
        self.le2 = nn.Linear(n_hidden, n_hidden)
        self.prele = nn.Linear(n_hidden, n_output)

    def forward(self, input):
        out = self.le1(input)
        out = F.relu(out)
        out = self.le2(out)
        out = F.relu(out)
        out = self.prele(out)
        return out

net1 = Net1(1, 2, 1)
checkpoint = torch.load('net_parameter.pkl')  # load the saved parameters
net1.load_state_dict(checkpoint)
This raises an error (see the source walkthrough in the second part above for why): a RuntimeError reporting Missing key(s) (le1..., le2..., prele...) and Unexpected key(s) (hidden1..., hidden2..., predict...).
There are two ways to fix this:
(1) Rewrite the keys of the checkpoint dict so that they match the new network one-to-one.
Code:
class Net1(nn.Module):
    def __init__(self, n_input, n_hidden, n_output):
        super(Net1, self).__init__()
        self.le1 = nn.Linear(n_input, n_hidden)
        self.le2 = nn.Linear(n_hidden, n_hidden)
        self.prele = nn.Linear(n_hidden, n_output)

    def forward(self, input):
        out = self.le1(input)
        out = F.relu(out)
        out = self.le2(out)
        out = F.relu(out)
        out = self.prele(out)
        return out

def load_from_state_dict(checkpoint, net):
    ori_keys = checkpoint.keys()        # keys in the saved weights
    now_keys = net.state_dict().keys()  # keys in the current network
    # Rename every checkpoint key to the corresponding key of the current network.
    # NOTE: this assumes both networks have the same parameters in the same order.
    for ori_key, now_key in zip(list(ori_keys), list(now_keys)):
        checkpoint[now_key] = checkpoint.pop(ori_key)  # replace the old key
    return checkpoint

net1 = Net1(1, 2, 1)
checkpoint = torch.load('net_parameter.pkl')
checkpoint = load_from_state_dict(checkpoint, net1)
print('Renamed keys:\n', checkpoint.keys())
net1.load_state_dict(checkpoint)  # load the renamed parameters
(2) Use the strict parameter of load_state_dict(checkpoint, strict).
The second approach is simpler than the first: just pass strict=False:
class Net1(nn.Module):
    def __init__(self, n_input, n_hidden, n_output):
        super(Net1, self).__init__()
        self.le1 = nn.Linear(n_input, n_hidden)
        self.le2 = nn.Linear(n_hidden, n_hidden)
        self.prele = nn.Linear(n_hidden, n_output)

    def forward(self, input):
        out = self.le1(input)
        out = F.relu(out)
        out = self.le2(out)
        out = F.relu(out)
        out = self.prele(out)
        return out

net1 = Net1(1, 2, 1)
checkpoint = torch.load('net_parameter.pkl')
net1.load_state_dict(checkpoint, strict=False)
Why this works, from the source: the strict parameter controls whether the checkpoint keys must exactly match the keys of the new network. It defaults to True (strict matching); on a mismatch the offending keys are collected into missing_keys / unexpected_keys, which ultimately triggers a RuntimeError. With strict=False that bookkeeping is skipped, and only the shapes of parameters whose keys do match are checked. Note that strict=False merely suppresses the error: parameters whose keys do not match are left at their current (randomly initialized) values.
        elif strict:
            missing_keys.append(key)

    if strict:
        for key in state_dict.keys():
            if key.startswith(prefix):
                input_name = key[len(prefix):]
                input_name = input_name.split('.', 1)[0]  # get the name of param/buffer/child
                if input_name not in self._modules and input_name not in local_state:
                    unexpected_keys.append(key)
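When strict=False, load_state_dict does not raise; it returns the missing and unexpected keys so you can inspect what was (or was not) loaded. A small sketch, reusing net1 and the freshly loaded checkpoint from above:

result = net1.load_state_dict(checkpoint, strict=False)
# Keys that exist in net1 but were absent from the checkpoint
print(result.missing_keys)     # ['le1.weight', 'le1.bias', 'le2.weight', 'le2.bias', 'prele.weight', 'prele.bias']
# Keys that exist in the checkpoint but have no counterpart in net1
print(result.unexpected_keys)  # ['hidden1.weight', 'hidden1.bias', 'hidden2.weight', 'hidden2.bias', 'predict.weight', 'predict.bias']
# In this extreme case nothing matched, so nothing was actually copied.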
The cases above all assume the two networks have the same structure, which is still not robust enough. In real tasks the pre-trained model and the actual model often differ: you may only want to load part of the weights, or you may have added new structure to the new network. How do we load then?
(1) For example, in the code below I add one more layer on top of the original network; loading still works directly with strict=False.
class Net1(nn.Module):
    def __init__(self, n_input, n_hidden, n_output):
        super(Net1, self).__init__()
        self.le1 = nn.Linear(n_input, n_hidden)
        self.le2 = nn.Linear(n_hidden, n_hidden)
        self.prele = nn.Linear(n_hidden, n_output)
        self.add_le = nn.Linear(n_output, n_output)  # an extra layer not present in the checkpoint

    def forward(self, input):
        out = self.le1(input)
        out = F.relu(out)
        out = self.le2(out)
        out = F.relu(out)
        out = self.prele(out)
        out = self.add_le(out)
        return out

net1 = Net1(1, 2, 1)
checkpoint = torch.load('net_parameter.pkl')
print('Original weights:', checkpoint)
net1.load_state_dict(checkpoint, strict=False)
print('Network weights after loading:', list(net1.named_parameters()))
Summary!!!: In practice, when loading, we usually just pass strict=False. load_state_dict then compares the keys in the checkpoint against the keys of your own network: where a key matches (and the shapes agree), the corresponding weight is loaded; where it does not, that entry is skipped.
Taking this a step further: say AlexNet was trained on a 10-class handwritten-digit problem and I now want a 6-class task, but the front part of my net should reuse the AlexNet weights. I can set strict=False and simply give my classification layer a key name different from the one in the AlexNet weights, so only the backbone weights are matched. The core code is:
pretrained_dict = torch.load('models/cifar10_statedict.pkl')
model_dict = model.state_dict()
print('Randomly initialized first layer:', model_dict['conv1.0.weight'])

# Drop every key in pretrained_dict that does not exist in model_dict
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
print('Pre-trained first layer:', pretrained_dict['conv1.0.weight'])

# Overwrite the matching entries of model_dict with the pre-trained values
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)
In one sentence: the code above is just one way to do it. Once you understand the principle you can write it however you like and load exactly the layers you want (see the short sketch below); don't be constrained by any single recipe.
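For instance, a minimal sketch (the helper load_partial and its prefix choice are made up for illustration, reusing the toy Net and net_parameter.pkl from above) that copies only the layers whose names start with chosen prefixes and leaves everything else at its initialized values:

def load_partial(net, checkpoint, prefixes=('hidden1.', 'hidden2.')):
    """Copy only checkpoint entries whose key starts with one of the given prefixes."""
    own_state = net.state_dict()
    selected = {k: v for k, v in checkpoint.items()
                if k.startswith(prefixes) and k in own_state and v.shape == own_state[k].shape}
    own_state.update(selected)   # overwrite only the selected entries
    net.load_state_dict(own_state)
    return list(selected.keys())

net = Net(1, 2, 1)
loaded = load_partial(net, torch.load('net_parameter.pkl'))
print('Loaded layers:', loaded)  # e.g. ['hidden1.weight', 'hidden1.bias', 'hidden2.weight', 'hidden2.bias']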
(1) Loading an arbitrary slice of weights
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                          missing_keys, unexpected_keys, error_msgs):
    # override the _load_from_state_dict function
    # convert the backbone weights pre-trained in Mask R-CNN
    # use list(state_dict.keys()) to avoid
    # RuntimeError: OrderedDict mutated during iteration
    for key_name in list(state_dict.keys()):
        key_changed = True
        if key_name.startswith('backbone.'):
            new_key_name = f'img_backbone{key_name[8:]}'
        elif key_name.startswith('neck.'):
            new_key_name = f'img_neck{key_name[4:]}'
        elif key_name.startswith('rpn_head.'):
            new_key_name = f'img_rpn_head{key_name[8:]}'
        elif key_name.startswith('roi_head.'):
            new_key_name = f'img_roi_head{key_name[8:]}'
        else:
            key_changed = False
        if key_changed:
            logger = get_root_logger()
            print_log(
                f'{key_name} renamed to be {new_key_name}', logger=logger)
            state_dict[new_key_name] = state_dict.pop(key_name)
    super()._load_from_state_dict(state_dict, prefix, local_metadata,
                                  strict, missing_keys, unexpected_keys,
                                  error_msgs)
(2) Multi-step SGD training with a learning-rate scheduler:
import os

# (Skeleton: assumes `model` and the RESUME flag are defined elsewhere.)
start_epoch = -1
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# Different epochs use different learning rates: at every milestone (epoch 10, 20, 30, ...)
# the learning rate is multiplied by 0.1, i.e. it drops by one order of magnitude.
lr_schedule = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10, 20, 30, 40, 50], gamma=0.1)

# Resume
if RESUME:
    path_checkpoint = "./model_parameter/test/ckpt_best_50.pth"  # checkpoint path
    checkpoint = torch.load(path_checkpoint)                     # load the checkpoint
    model.load_state_dict(checkpoint['net'])                     # restore the learnable parameters
    optimizer.load_state_dict(checkpoint['optimizer'])           # restore the optimizer state
    start_epoch = checkpoint['epoch']                            # restore the epoch counter
    lr_schedule.load_state_dict(checkpoint['lr_schedule'])       # restore the lr scheduler

# Save
for epoch in range(start_epoch + 1, 80):
    optimizer.zero_grad()
    optimizer.step()
    lr_schedule.step()
    if epoch % 10 == 0:
        print('epoch:', epoch)
        print('learning rate:', optimizer.state_dict()['param_groups'][0]['lr'])
        checkpoint = {
            "net": model.state_dict(),
            'optimizer': optimizer.state_dict(),
            "epoch": epoch,
            'lr_schedule': lr_schedule.state_dict()
        }
        if not os.path.isdir("./model_parameter/test"):
            os.mkdir("./model_parameter/test")
        torch.save(checkpoint, './model_parameter/test/ckpt_best_%s.pth' % (str(epoch)))
https://zhuanlan.zhihu.com/p/133250753