解释说明:目前很多主流的网络模型主要由backbone和其他结构(分类、回归)组成,那么如何在训练自己的网络模型时使用别人已经训练好的网络模型权重呢?本文以ResNet50为例,构建一个基于ResNet50的网络模型,并说明其预训练权重的载入过程。
torchvision中封装了ResNet系列、VGG系列、Inception系列等网络模型,且内部给出了每个网络模型预训练权重的url路径
如下图所示,为torchvision官方封装的ResNet系列网络
解释说明:根据自己的理解,使用预训练权重过程主要包含以下几个步骤
模型权重载入完毕后,这时需要根据个人需要,选择训练时是更新网络的全部参数,还是冻结部分参数、只更新后续的其他部分
下面开始撸代码
解释说明:这里我创建了一个基于resnet50网络的模型(这个网络是干什么的在此不做解释),网络结构如下
import torch
from torch.nn import Sequential, Conv2d, MaxPool2d, ReLU, BatchNorm2d
from torch import nn
from torch.utils import model_zoo
CLASS_NUM = 20 # must be changed when training on a different dataset
class Bottleneck(nn.Module):
    """ResNet bottleneck block: 1x1 -> 3x3 -> 1x1 convolutions plus a shortcut.

    The main path ends with ``out_channel * 4`` channels. The shortcut is
    either an identity (``downsample is False``) or a strided 1x1 convolution
    followed by batch norm that projects the input to the main path's shape.

    Args:
        in_channel: Number of channels of the input feature map.
        out_channel: Width of the first two convolutions; the block outputs
            ``out_channel * 4`` channels.
        stride: Sequence of three ints: the strides of the 1x1, 3x3 and 1x1
            convolutions, in that order.
        downsample: ``False`` builds an identity shortcut; any other value
            builds the projection shortcut. With the identity shortcut the
            input must already match the main path's output shape, otherwise
            the residual addition in ``forward`` fails.
    """

    def __init__(self, in_channel, out_channel, stride, downsample):
        super().__init__()
        # NOTE: sub-modules are registered in the same order as before
        # (relu, bottleneck, shortcut). The state_dict key order depends on it,
        # and the positional weight copy used elsewhere in this file relies on
        # that order.
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.in_channel = in_channel
        self.out_channel = out_channel
        expanded_channel = out_channel * 4

        self.bottleneck = Sequential(
            Conv2d(in_channel, out_channel, kernel_size=1, stride=stride[0], padding=0, bias=False),
            BatchNorm2d(out_channel),
            ReLU(inplace=True),
            Conv2d(out_channel, out_channel, kernel_size=3, stride=stride[1], padding=1, bias=False),
            BatchNorm2d(out_channel),
            ReLU(inplace=True),
            Conv2d(out_channel, expanded_channel, kernel_size=1, stride=stride[2], padding=0, bias=False),
            BatchNorm2d(expanded_channel),
        )

        if self.downsample is False:
            # Identity block: the input is added to the main path unchanged.
            self.shortcut = Sequential()
        else:
            # Conv block: the shortcut must shrink the feature map exactly as
            # much as the main path does. Using the product of the three
            # strides keeps both paths aligned wherever the stride is placed;
            # it equals stride[0] whenever the other two strides are 1.
            shortcut_stride = stride[0] * stride[1] * stride[2]
            self.shortcut = Sequential(
                Conv2d(in_channel, expanded_channel, kernel_size=1, stride=shortcut_stride, bias=False),
                BatchNorm2d(expanded_channel),
            )

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        out = self.bottleneck(x)
        out += self.shortcut(x)
        return self.relu(out)
class output_net(nn.Module):
    """Residual block with a dilated 3x3 convolution and no channel expansion.

    The main path is 1x1 conv -> 3x3 conv (dilation 2, padding 2) -> 1x1 conv,
    each followed by batch norm. The shortcut is an empty ``nn.Sequential``
    (identity) unless a projection is required.

    Args:
        in_planes: Number of input channels.
        planes: Number of channels inside the block and, because
            ``expansion`` is 1, also the number of output channels.
        stride: Stride of the 3x3 convolution and of the projection shortcut.
        block_type: ``'B'`` always uses a 1x1 conv + batch norm shortcut.
            Any other value uses it only when ``stride != 1`` or the channel
            count changes.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1, block_type='A'):
        super().__init__()
        out_planes = self.expansion * planes

        # Attribute order is kept as-is: it fixes the state_dict key order.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=2, bias=False, dilation=2)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)
        self.downsample = nn.Sequential()
        self.relu = nn.ReLU(inplace=True)

        needs_projection = stride != 1 or in_planes != out_planes or block_type == 'B'
        if needs_projection:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        """Return ``relu(main_path(x) + downsample(x))``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.downsample(x)
        return self.relu(out)
class ResNet50(nn.Module):
    """ResNet-50 style backbone followed by a small convolutional output head.

    ``layer0``-``layer4`` form the backbone (stem plus 3/4/6/3 bottleneck
    blocks), ``layer5`` is a stack of three ``output_net`` blocks, and the
    final conv + batch norm + sigmoid produce a channels-last prediction map.

    Args:
        block: Block class used for ``layer1``-``layer4``; it is called as
            ``block(in_channel, out_channel, stride, downsample=...)``.
        num_classes: Number of classes; the head outputs ``num_classes + 10``
            channels. ``None`` (the default) uses the module-level
            ``CLASS_NUM``.
    """

    def __init__(self, block, num_classes=None):
        super().__init__()
        if num_classes is None:
            # Read at call time so the module-level setting keeps working.
            num_classes = CLASS_NUM
        # NOTE(review): the extra 10 channels presumably hold two boxes of
        # (x, y, w, h, confidence) each, as in a YOLO-v1 style head - confirm.
        head_channels = int(num_classes + 10)

        self.block = block
        self.layer0 = Sequential(
            Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
            BatchNorm2d(64),
            ReLU(inplace=True),
            MaxPool2d(kernel_size=3, stride=2, padding=1),
        )
        # NOTE(review): stride1=[2, 1, 1] puts the stride-2 on the first 1x1
        # conv of each stage. torchvision's ResNet-50 is understood to stride
        # the 3x3 conv instead; the parameter shapes are the same either way,
        # so pretrained weights still load - confirm this is intended.
        self.layer1 = self.make_layer(self.block, channel=[64, 64], stride1=[1, 1, 1], stride2=[1, 1, 1], n_re=3)
        self.layer2 = self.make_layer(self.block, channel=[256, 128], stride1=[2, 1, 1], stride2=[1, 1, 1], n_re=4)
        self.layer3 = self.make_layer(self.block, channel=[512, 256], stride1=[2, 1, 1], stride2=[1, 1, 1], n_re=6)
        self.layer4 = self.make_layer(self.block, channel=[1024, 512], stride1=[2, 1, 1], stride2=[1, 1, 1], n_re=3)
        self.layer5 = self._make_output_layer(in_channels=2048)
        self.avgpool = nn.AvgPool2d(2)  # kernel_size = 2, stride = 2
        self.conv_end = nn.Conv2d(256, head_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn_end = nn.BatchNorm2d(head_channels)

    def make_layer(self, block, channel, stride1, stride2, n_re):
        """Build one stage of ``n_re`` blocks.

        Args:
            block: Block class to instantiate.
            channel: ``[in_channel, out_channel]`` of the first block; later
                blocks take ``out_channel * 4`` input channels.
            stride1: Strides passed to the first (projection) block.
            stride2: Strides passed to the remaining (identity) blocks.
            n_re: Number of blocks in the stage.

        Returns:
            A ``Sequential`` holding the blocks in order.
        """
        layers = []
        for index in range(n_re):
            if index == 0:
                layers.append(block(channel[0], channel[1], stride1, downsample=True))
            else:
                layers.append(block(channel[1] * 4, channel[1], stride2, downsample=False))
        return Sequential(*layers)

    def _make_output_layer(self, in_channels):
        """Build the head: one type-'B' ``output_net`` block, then two type-'A' blocks, all 256 wide."""
        return nn.Sequential(
            output_net(in_planes=in_channels, planes=256, block_type='B'),
            output_net(in_planes=256, planes=256, block_type='A'),
            output_net(in_planes=256, planes=256, block_type='A'),
        )

    def forward(self, x):
        """Run the network and return a channels-last map of sigmoid outputs.

        The shape comments below are for a 3x448x448 input and the default
        ``CLASS_NUM`` of 20.
        """
        out = self.layer0(x)       # 64 x 112 x 112
        out = self.layer1(out)     # 256 x 112 x 112
        out = self.layer2(out)     # 512 x 56 x 56
        out = self.layer3(out)     # 1024 x 28 x 28
        out = self.layer4(out)     # 2048 x 14 x 14
        out = self.layer5(out)     # batch_size x 256 x 14 x 14
        out = self.avgpool(out)    # batch_size x 256 x 7 x 7
        out = self.conv_end(out)   # batch_size x 30 x 7 x 7
        out = self.bn_end(out)
        out = torch.sigmoid(out)
        out = out.permute(0, 2, 3, 1)  # batch_size x 7 x 7 x 30
        return out
def resnet50():
    """Return a new, untrained ``ResNet50`` built from ``Bottleneck`` blocks."""
    return ResNet50(Bottleneck)
通过下面代码,分别载入自己的网络模型和torch官方的网络模型,看看模型结构有什么不同
from torchvision import models
import torch
from new_resnet import resnet50

# Build torchvision's official ResNet-50 and fill it with pretrained weights.
# pretrained=False skips the built-in download (it is too slow); the
# checkpoint is downloaded to disk beforehand and loaded manually instead.
resnet = models.resnet50(pretrained=False)
state_dict = torch.load(r"resnet50-0676ba61.pth")
resnet.load_state_dict(state_dict)
new_state_dict = resnet.state_dict()

# Build our own ResNet-50 based model; its weights are still untrained.
net = resnet50()
op = net.state_dict()

print(len(new_state_dict))  # number of entries in the official state dict
print(len(op))  # number of entries in our own model's state dict
从图中可以看出,torch官方网络模型共有320个key,我们创建的网络模型共有384个key
分别输出两种key有什么不同
from torchvision import models
import torch
from new_resnet import resnet50

# Build torchvision's official ResNet-50 and fill it with pretrained weights.
# pretrained=False skips the built-in download (it is too slow); the
# checkpoint is downloaded to disk beforehand and loaded manually instead.
resnet = models.resnet50(pretrained=False)
state_dict = torch.load(r"resnet50-0676ba61.pth")
resnet.load_state_dict(state_dict)
new_state_dict = resnet.state_dict()

# Build our own ResNet-50 based model; its weights are still untrained.
net = resnet50()
op = net.state_dict()

print(len(new_state_dict.keys()))
print(len(op.keys()))

# List the entry names of the official model (320 keys in total).
for official_key in new_state_dict:
    print(official_key)

# List the entry names of our own model (384 keys in total).
for own_key in op:
    print(own_key)
从图中可以看出,我们创建的网络模型和torch官方的网络模型的前318个key对应的结构都是一样的(即网络的backbone)。官方的网络模型最后使用一层全连接层做分类,它对应最后两个key(fc.weight和fc.bias),因此我们预训练是不需要这两个参数的,我们只要前面的backbone参数。
两种载入方式:通过2.1可以知道,两个网络的backbone结构是一样的,第318个key之后的部分是不一样的。通过观察网络的key可以发现,torch官方的resnet网络模型的key名字和我们自己创建的基于resnet50网络模型的key名字不一样,因此参数的载入主要有两种:
当权重字典中的key名字一样时
from torchvision import models
import torch
from new_resnet import resnet50

# Build torchvision's official ResNet-50 and fill it with pretrained weights.
# pretrained=False skips the built-in download (it is too slow); the
# checkpoint is downloaded to disk beforehand and loaded manually instead.
resnet = models.resnet50(pretrained=False)
state_dict = torch.load(r"resnet50-0676ba61.pth")
resnet.load_state_dict(state_dict)
new_state_dict = resnet.state_dict()

# Build our own ResNet-50 based model; its weights are still untrained.
net = resnet50()
op = net.state_dict()

# Keep only the pretrained entries that our model also has under the same
# name. The shape is compared as well, because a same-named entry of a
# different size would make load_state_dict fail with a size mismatch.
pretrained_dict = {
    k: v
    for k, v in new_state_dict.items()
    if k in op and v.shape == op[k].shape
}
# Overwrite the matching entries of our model's state dict.
op.update(pretrained_dict)
# Load the merged state dict back into the model.
net.load_state_dict(op)
当权重字典中的key名字不一样时
from torchvision import models
import torch
from new_resnet import resnet50

# Build torchvision's official ResNet-50 and fill it with pretrained weights.
# pretrained=False skips the built-in download (it is too slow); the
# checkpoint is downloaded to disk beforehand and loaded manually instead.
resnet = models.resnet50(pretrained=False)
state_dict = torch.load(r"resnet50-0676ba61.pth")
resnet.load_state_dict(state_dict)
new_state_dict = resnet.state_dict()

# Build our own ResNet-50 based model; its weights are still untrained.
net = resnet50()
op = net.state_dict()

# Of the 320 official entries, the last two belong to the fully connected
# layer and are not needed; only the first 318 (the backbone) are copied.
NUM_BACKBONE_KEYS = 318

# Positional copy: works whether or not the key names match, but it relies on
# both state dicts listing the backbone entries in the same order.
# zip pairs the i-th key of our model with the i-th official tensor in a
# single pass (and stops at the shorter of the two sequences).
backbone_keys = list(op)[:NUM_BACKBONE_KEYS]
for op_key, pretrained_value in zip(backbone_keys, new_state_dict.values()):
    op[op_key] = pretrained_value

# The edited state dict only takes effect once it is loaded into the model.
net.load_state_dict(op)
从上面两种方式可以看出,第二种方式更适合我们。综上所述,参数的载入过程主要分为
解释说明:预训练参数载入后,在网络模型训练过程中,我们可以选择让这部分参数参与参数更新,还是不参与参数更新。
如果参与参数更新的话直接进行后续的网络训练就行了,无处理操作
若不参与网络的更新,需要将该参数参与更新的bool值设为False。通过参数的requires_grad属性可以获取(或设置)该参数是否参与更新。对应的,在训练时候,optimizer里面只能更新requires_grad = True的参数,于是需要在构造optimizer时把被冻结的参数过滤掉:
from torchvision import models
import torch
from new_resnet import resnet50

# Build torchvision's official ResNet-50 and fill it with pretrained weights.
# pretrained=False skips the built-in download (it is too slow); the
# checkpoint is downloaded to disk beforehand and loaded manually instead.
resnet = models.resnet50(pretrained=False)
state_dict = torch.load(r"resnet50-0676ba61.pth")
resnet.load_state_dict(state_dict)
new_state_dict = resnet.state_dict()

# Build our own ResNet-50 based model; its weights are still untrained.
net = resnet50()
op = net.state_dict()

# Of the 320 official entries, the last two belong to the fully connected
# layer and are not needed; only the first 318 (the backbone) are copied.
NUM_BACKBONE_KEYS = 318

# Positional copy: works whether or not the key names match, but it relies on
# both state dicts listing the backbone entries in the same order.
backbone_keys = list(op)[:NUM_BACKBONE_KEYS]
for op_key, pretrained_value in zip(backbone_keys, new_state_dict.values()):
    op[op_key] = pretrained_value

# The edited state dict only takes effect once it is loaded into the model.
net.load_state_dict(op)

# Freeze the first 100 parameter tensors. This counts the tensors yielded by
# net.parameters(), not layers.
# NOTE(review): requires_grad=False only stops gradient updates; BatchNorm
# running statistics presumably still change while the model is in train()
# mode - confirm whether those layers should also be put in eval() mode.
NUM_FROZEN_PARAMS = 100
for frozen_param in list(net.parameters())[:NUM_FROZEN_PARAMS]:
    frozen_param.requires_grad = False

# Hand only the parameters that are still trainable to the optimizer.
trainable_params = [p for p in net.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable_params, lr=0.001)