torchvision.models 定义了用于处理不同任务的模型,包括图像分类、像素语义分割、对象检测、实例分割、人员关键点检测、视频分类和光流等。
在__init__文件下可以看到实现的网络列表,0.12版本实现的分类网络包括alexnet、convnext、resnet、vgg、squeezenet、inception、densenet、googlenet、mobilenet、mnasnet、shufflenet、efficientnet、regnet等,还实现了一些目标检测、特征提取、语义分割的经典网络。
toechvision.models通过class类型保存网络结构及其所需功能模块,以dict类型保存与训练权重的下载地址,并定义一个函数用于外部调用(实现网络参数定义、预训练权重加载等功能),以下为resnet.py的格式框架。
__all__ = ["ResNet", "resnet18", "resnet34", "resnet50", "resnet101", "resnet152",
"resnext50_32x4d", "resnext101_32x8d", "wide_resnet50_2", "wide_resnet101_2"]
# download pre-trained model
model_urls = {
"resnet18": "https://download.pytorch.org/models/resnet18-f37072fd.pth",
"resnet34": "https://download.pytorch.org/models/resnet34-b627a593.pth",
"resnet50": "https://download.pytorch.org/models/resnet50-0676ba61.pth",
"resnet101": "https://download.pytorch.org/models/resnet101-63fe2227.pth",
"resnet152": "https://download.pytorch.org/models/resnet152-394f9c45.pth",
"resnext50_32x4d": "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
"resnext101_32x8d": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
"wide_resnet50_2": "https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth",
"wide_resnet101_2": "https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth",
}
# structure definition
class BasicBlock(nn.Module):...
class Bottleneck(nn.Module):...
class ResNet(nn.Module):...
# definition & initialization
def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:...
def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:...
def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:...
def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:...
def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:...
def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:...
def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:...
def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:...
def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:...
在__init__文件下可以看到实现的网络列表,按功能分为分类网络、目标检测网络、
语义分割网络等。0.12版本实现的分类网络包括alexnet、convnext、resnet、vgg、squeezenet、inception、densenet、googlenet、mobilenet、mnasnet、shufflenet、efficientnet、regnet等。
目前的主流方法是利用分类网络的特征提取部分作为其他任务的特征提取网络,即删除分类网络最后的分类层(通常包括avgpool、flatten、full-connected层等),直接返回一张或多张特征图。
在project目录下建立net文件夹存储backbone的py文件,不要在torchvision.models下的源文件里直接改写
例如,新建nets.GoogleNet.py文件,复制粘贴torchvision.models.googlenet.py的全部内容。
根据最后定义的调用函数找到该网络的主干class,分析forward函数找到分类层,以resnet为例,最后三层分别对特征图进行自适应池化、一维化、全连接分类,属于网络的分类层,其输入即网络提取到的特征图,因此删除后三层。
分类层在__init__函数中的定义也要删除,否则网络初始化时依然会为其分配空间,浪费显存
class ResNet(nn.Module):
def __init__(self, ...):
super(ResNet, self).__init__()
...
def _forward_impl(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x) # output feature map
# classification layer, delete here
# x = self.avgpool(x)
# x = torch.flatten(x, 1)
# x = self.fc(x)
return x
def forward(self, x):
return self._forward_impl(x)
如果你的目标检测网络只利用一层特征图进行目标预测,则对最后输出的特征图进行通道对齐即可。
如果目标检测网络使用了FPN等结构,需要多层特征图,则需要根据上述分类网络的结构进行分析,划分出三层或多层特征图进行输出。
特征层的选定有两种方法:
选定的输出层通道数应与neck部分的输入通道数相匹配,在此建议修改neck结构来适应backbone的输出,而非修改backbone的结构,以便最大化利用预训练的权重,提高迁移学习的效果。
在train.py中实现预训练权重加载,以下为使用torchvision.models实现model.backbone的权重加载。
参数strict=False
的作用是只加载结构匹配的参数,被删除的分类层、输出通道数被修改的特征提取层参数都不会被加载,因此建议在目标检测网络的neck部分修改通道参数,尽量不要改动backbone的结构,否则作为训练初期被冻结的backbone会导致网络训练效果不佳。
model_urls = {"resnet18": "https://download.pytorch.org/models/resnet18-f37072fd.pth",
"resnet34": "https://download.pytorch.org/models/resnet34-b627a593.pth",
"resnet50": "https://download.pytorch.org/models/resnet50-0676ba61.pth",
"resnet101": "https://download.pytorch.org/models/resnet101-63fe2227.pth",
"resnet152": "https://download.pytorch.org/models/resnet152-394f9c45.pth", }
model.backbone.load_state_dict(model_zoo.load_url(model_urls['resnet50'], model_dir='./model_data/'), strict=False)
若训练中断,或选取一个epoch.pth作为起点重新训练,则需要加载整个网络的权重,加载代码如下
model_path = "logs/weights/Epoch_44.pth"
model_dict = model.state_dict()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
pretrained_dict = torch.load(model_path, map_location=device)
pretrained_dict = {k: v for k, v in pretrained_dict.items() if np.shape(model_dict[k]) == np.shape(v)}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)