在深度学习训练网络模型的时候我们经常会使用到迁移学习,但是别人预训练好的模型往往不能直接拿来使用,比如ResNet模型,他们最后的输出是1000个类别,那么如果我想用于10分类,或者在分割模型上怎么使用呢?
我们先来看下分类的时候怎么迁移学习,以ResNet18为例。
import torch
import torch.nn as nn
import torchvision.models as models
# 加载预训练模型
model = models.resnet18(pretrained=True)
# 替换最后一层全连接层
num_classes = 10 # 输出10个类别
model.fc = nn.Conv2d(512, num_classes, kernel_size=1)
# 冻结 ResNet18 的前面部分卷积层的参数
for name, param in model.named_parameters():
if not name.startswith('fc'):
param.requires_grad = False
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.fc.parameters(), lr=0.001)
# 训练模型
for epoch in range(num_epochs):
for images, labels in dataloader:
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
这里就以简单的显著性分割为例(最后的输出是(B, 1, H, W))。这里我们把全连接部分全部删除,直接让ResNet输出一个(B, 1, H, W)。
import torch
import torch.nn as nn
import torchvision.models as models
# 加载预训练模型
model = models.resnet18(pretrained=True)
# 截断全连接层
modules = list(model.children())[:-1]
model = nn.Sequential(*modules)
# 添加卷积层和上采样层
channel = 1 # 二值分割输出为单通道
model.add_module('conv', nn.Conv2d(512, 256, kernel_size=3, padding=1))
model.add_module('relu', nn.ReLU(inplace=True))
model.add_module('upsample', nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True))
model.add_module('conv_out', nn.Conv2d(256, num_channel, kernel_size=3, padding=1))
# 冻结 ResNet18 的前面部分卷积层的参数
for name, param in model.named_parameters():
if name.startswith('0') or name.startswith('1') or name.startswith('2') or name.startswith('3'):
param.requires_grad = False
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
for epoch in range(num_epochs):
for images, labels in dataloader:
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
在 PyTorch 中,我们可以通过以下两种方式来进行迁移学习:
冻结参数:我们可以冻结预训练模型的一些参数,只微调新添加的层的参数。这可以通过将参数的 requires_grad 属性设置为 False 来实现。
替换层:我们可以用新的层替换预训练模型的层,并且只微调新添加的层的参数。这可以通过从预训练模型中提取需要的层,并将其与新的层组合成一个新的模型来实现。