联系方式:
e-mail: [email protected]
QQ: 973926198
github: https://github.com/FesianXu
如有谬误,请联系指正。转载请注明出处
pytorch
的时候,经常有需要使用一些通用的模型模块作为子模块,比如著名的
resnet
,
densenet
,
alexnet
,
inception
等等,在使用这些模型的时候,通常我们希望可以加载该模型在别的数据集,如
ImageNet
上进行训练后的权值参数,以便加速整个模型训练过程[1]。在此为了简便,称之为这些预训练模型为
基模型,加载基模型参数这个过程按照需求大致可以分为两类:
在对模型的参数进行fine-tune[2-3]的时候,按照需求也可以大致分为两类:
我们接下来基于pytorch框架[4]对其进行讨论。
在pytorch中,保存一个模型的参数特别容易,用torch.save()
即可,例如:
model = CNNNet(params)
opt = torch.optim.Adam(model.parameters(), lr=1e-4)
model.train()
# here we train the models, skip these codes
saved_dict = {
'model': model.state_dict(),
'opt': opt.state_dict()
}
torch.save(saved_dict, './model.pth.tar')
我们发现,torch.save()
保存的是一个字典,其中的keys
可以自定义。这里有一点要注意的是,如果你用的优化器是例如Adam优化器
[5-6]这类内部有参数需要持久化的,最好也将其保存下来。
如果是需要加载整个模型,直接用torch.load()
和model.load_state_dict()
即可,如:
model = CNNNet(params)
opt = torch.optim.Adam(model.parameters(), lr=1e-4)
# yes you also need to define the model and optimizer
checkpoint = torch.load('./model.pth.tar')
# here, checkpoint is a dict with the keys you defined before
model.load_state_dict(checkpoint['model'])
opt.load_state_dict(checkpoint['opt'])
这个过程中torch.load()
只是负责读取模型参数,而用model.load_state_dict()
进行加载,这个加载是按照名字进行索引的,如果名字对不上或者是参数的形状,类型对不上,就会报错。我们可以打印出其名字进行观察,如:
for name in checkpoint ['model'].keys():
print(name)
输出如:
stgcn.weight_model.conv_models.0.conv_layer.weight
stgcn.weight_model.conv_models.0.conv_layer.bias
stgcn.weight_model.conv_models.1.conv_layer.weight
stgcn.weight_model.conv_models.1.conv_layer.bias
stgcn.weight_model.conv_models.1.batch_norm.weight
stgcn.weight_model.conv_models.1.batch_norm.bias
如果定义的模型和持久化的模型的参数名,形状,类型都能完全符合,就能正确加载。我们同时还注意到,变量的名字是由pytorch自行命名的,其命名根据就是你的变量名字,比如:
import torch.nn as nn
class model_A(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(10,10)
class model_B(nn.Module):
def __init__(self):
super().__init__()
self.sub_model = model_A()
self.fc = nn.Linear(10,1)
model = model_B()
那么如果你打印prin(model_B)
,你就会发现子模块的名字为
sub_model.fc.weight
sub_model.fc.bias
fc.weight
fc.bias
我们观察到名字是以你的变量标识符命名的,这点和TensorFlow的命名机制完全不同,请注意。
同时我们还观察到,只要是权值weight
其命名后缀都是weight
,同样偏置bias
的后缀是bias
,因此根据此,可以单独对权值进行L2正则[7],具体过程见[8]。
根据上面的分析,我们便发现只要过滤掉不需要加载的模型的名字,即可实现部分模型加载了,例子如:
model = CNNNet()
checkpoint = torch.load('./model.pth.tar')
for name, params in model.stgcn.st_gcn_networks.named_parameters():
params_name = 'stgcn.st_gcn_networks.'+name
if params_name in model.state_dict():
model.state_dict()[params_name].copy_(checkpoint['model'][params_name])
我们发现,通过这个代码,我们可以仅对model.stgcn.st_gcn_networks
这个子模块的参数进行加载,而其他的参数保持初始化情况不变。
在模型的Fine-Tune(微调)或者联合调试过程中,我们经常需要固定某个模型的参数,而去调整其他模型的参数,主要方法有两个:
requires_grad=False
笔者在实践过程中最常用的是第三种方法,暂时只介绍第三种方法。代码很简单,例子如下:
trainable_vars = list(model.stgcn.weight_model.parameters())+ \
list(model.stgcn.fcn.parameters())+ \
list(model.stgcn.data_bn.parameters())+ \
list(model.stgcn.dim_map.parameters())+ \
list(model.aux_cls.parameters())
opt = torch.optim.SGD(trainable_vars , lr=1e-4, momentum=0.9)
简单粗暴,但是其实很好用,当需要训练的变量很多,而需要固定的变量很少的时候,可以用对整个模型参数求补的方式求得,这里不多介绍了。
当需要对整个模型进行微调时,只需要:
opt = torch.optim.SGD(model.parameters(), lr=1e-6, momentum=0.9)
在模型训练过程中,有些模块,比如对抗生成网络GAN[9]的生成器和判别器经常需要设置不同的学习率,以求得更好的效果或者不同模型之间的平衡。详细内容见我以前文章[8]中所述,这里不再累述。
pytorch内置有一些经常使用的模型和其在大规模数据集上的预训练参数,只需要安装了torchvision
便可轻松使用,模型有:
resnet
: resnet18
,resnet34
,resnet50
,resnet101
,resnet152
vgg
: vgg11
, vgg13
, vgg16
, vgg19
alexnet
densenet
inception
squeezenet
具体模型定义见:Github click me
使用方法很简单,如:
import torchvision.models as models
resnet18 = models.resnet18(pretrained=False)
在这里如果指定pretrained=True
可以联网加载预训练模型,但是由于大陆因为你懂得原因,所以需要你懂得辅助工具,建议读者自行去下载模型的文件后手动加载。模型文件的地址可以在模型定义文件中找到,如[https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py]中的resnet模型:
model_urls = {
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}
[1]. Kaiming He, Ross Girshick, Piotr Dollár. Rethinking ImageNet Pre-training[J]. arXiv preprint, https://arxiv.org/abs/1811.08883
[2]. 迁移学习与fine-tuning有什么区别?
[3]. Fine tuning
[4]. pytorch
[5]. Adam 算法
[6]. Kingma D P, Ba J. Adam: A method for stochastic optimization[J]. arXiv preprint arXiv:1412.6980, 2014.
[7]. 曲线拟合问题与L2正则
[8]. pytorch中的L2和L1正则化,自定义优化器设置等操作
[9]. Goodfellow I, Pouget-Abadie J, Mirza M, et al. Generative adversarial nets[C]//Advances in neural information processing systems. 2014: 2672-2680.