方式:将不更新的参数的requires_grad设置为Fasle,同时不将该参数传入optimizer.
(1)不更新的参数的requires_grad设置为Fasle
# 冻结fc1层的参数
for name, param in model.named_parameters():
if "fc1" in name:
param.requires_grad = False
(2)不将该参数传入optimizer
# 定义一个fliter,只传入requires_grad=True的模型参数
optimizer = optim.SGD(filter(lambda p : p.requires_grad, model.parameters()), lr=1e-2)
总的代码:
# 最优写法
loss_fn = nn.CrossEntropyLoss()
# # 训练前的模型参数
print("model.fc1.weight", model.fc1.weight)
print("model.fc2.weight", model.fc2.weight)
print("model.fc1.weight.requires_grad:", model.fc1.weight.requires_grad)
print("model.fc2.weight.requires_grad:", model.fc2.weight.requires_grad)
# 冻结fc1层的参数
for name, param in model.named_parameters():
if "fc1" in name:
param.requires_grad = False
optimizer = optim.SGD(filter(lambda p : p.requires_grad, model.parameters()), lr=1e-2) # 定义一个fliter,只传入requires_grad=True的模型参数
for epoch in range(10):
x = torch.randn((3, 8))
label = torch.randint(0,3,[3]).long()
output = model(x)
loss = loss_fn(output, label)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print("model.fc1.weight", model.fc1.weight)
print("model.fc2.weight", model.fc2.weight)
print("model.fc1.weight.requires_grad:", model.fc1.weight.requires_grad)
print("model.fc2.weight.requires_grad:", model.fc2.weight.requires_grad)
最优写法能够节省显存和提升速度:
节省显存:不将不更新的参数传入optimizer
提升速度:将不更新的参数的requires_grad设置为False,节省了计算这部分参数梯度的时间
# ============================ step 4/5 优化器 ============================
# --------------------法2 : conv卷积层较小学习率,全连接层较大学习率----------------
# flag = 0
flag = 1 # 真
# Python程序语言指定任何非0和非空(null)值为true,0 或者 null为false。"判断条件"成立时/条件为“true"(非零),则执行后面的语句
if flag:
fc_params_id = list(map(id, resnet18_ft.fc.parameters())) # 返回的是fc层的parameters的内存地址
base_params = filter(lambda p: id(p) not in fc_params_id, resnet18_ft.parameters()) # 将resnet18中的参数过滤掉fc层的参数后得到base_params==卷积层参数
optimizer = optim.SGD([
{'params': base_params, 'lr': LR*0.1}, # 第一个LR*0.1
{'params': resnet18_ft.fc.parameters(), 'lr': LR}], momentum=0.9) # 10^-3,两个参数的momentum都是0.9
else:
optimizer = optim.SGD(resnet18_ft.parameters(), lr=LR, momentum=0.9) # 选择优化器
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_deca y_step, gamma=0.1) # 设置学习率下降策略
另外一个小技巧就是在nn.Module里冻结参数,这样前面的参数就是False,而后面的不变。
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 6, 5)
for p in self.parameters():
p.requires_grad=False
self.fc1 = nn.Linear(16 * 5 * 5, 120)
唯一需要注意的是,pytorch 固定部分参数训练时需要在优化器中施加过滤。
optimizer = optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=0.1)
所有代码如下:finetune_resnet18.py
# -*- coding: utf-8 -*-
"""
# @file name : finetune_resnet18.py
# @brief : 模型finetune方法
"""
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torch.optim as optim
from matplotlib import pyplot as plt
from tools.my_dataset import AntsDataset
from tools.common_tools import set_seed
import torchvision.models as models
import torchvision
BASEDIR = os.path.dirname(os.path.abspath(__file__))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("use device :{}".format(device))
set_seed(1) # 设置随机种子
label_name = {"ants": 0, "bees": 1}
# 参数设置
MAX_EPOCH = 25
BATCH_SIZE = 16
LR = 0.001
log_interval = 10
val_interval = 1
classes = 2
start_epoch = -1
lr_decay_step = 7
# ============================ step 1/5 数据 ============================
data_dir = os.path.join(BASEDIR, "..", "..", "data/hymenoptera_data")
train_dir = os.path.join(data_dir, "train")
valid_dir = os.path.join(data_dir, "val")
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]
train_transform = transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(norm_mean, norm_std),
])
valid_transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(norm_mean, norm_std),
])
# 构建MyDataset实例
train_data = AntsDataset(data_dir=train_dir, transform=train_transform)
valid_data = AntsDataset(data_dir=valid_dir, transform=valid_transform)
# 构建DataLoder
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)
# ============================ step 2/5 模型 ============================
# 1/3 构建模型
resnet18_ft = models.resnet18()
# 2/3 加载预训练模型的参数
# flag = 0
flag = 1
if flag:
path_pretrained_model = os.path.join(BASEDIR, "..", "..", "data/resnet18-5c106cde.pth")
state_dict_load = torch.load(path_pretrained_model)
resnet18_ft.load_state_dict(state_dict_load)
# ----------------------法1 : 冻结卷积层------------------------
flag_m1 = 0
# flag_m1 = 1
if flag_m1:
for param in resnet18_ft.parameters():
param.requires_grad = False #参数不求取梯度,即参数不再更新
print("conv1.weights[0, 0, ...]:\n {}".format(resnet18_ft.conv1.weight[0, 0, ...]))
# 3/3 替换fc层,适应新任务:设置新的Linear层
num_ftrs = resnet18_ft.fc.in_features #获取原始resnet18模型fc层features的个数
# 构建Linear层的两个参数:输入神经元个数,输出神经元个数=
resnet18_ft.fc = nn.Linear(num_ftrs, classes)
resnet18_ft.to(device) #将模型放到cpu or gpu
# ============================ step 3/5 损失函数 ============================
criterion = nn.CrossEntropyLoss() # 选择损失函数
# ============================ step 4/5 优化器 ============================
# --------------------法2 : conv卷积层较小学习率,全连接层较大学习率----------------
# flag = 0
flag = 1 # 真
# Python程序语言指定任何非0和非空(null)值为true,0 或者 null为false。"判断条件"成立时/条件为“true"(非零),则执行后面的语句
if flag:
fc_params_id = list(map(id, resnet18_ft.fc.parameters())) # 返回的是parameters的 内存地址
base_params = filter(lambda p: id(p) not in fc_params_id, resnet18_ft.parameters()) # 将resnet18中的参数过滤掉fc层的参数后得到base_params
optimizer = optim.SGD([
{'params': base_params, 'lr': LR*0}, # 0
{'params': resnet18_ft.fc.parameters(), 'lr': LR}], momentum=0.9)
else:
optimizer = optim.SGD(resnet18_ft.parameters(), lr=LR, momentum=0.9) # 选择优化器
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_deca y_step, gamma=0.1) # 设置学习率下降策略
# ============================ step 5/5 训练 ============================
train_curve = list()
valid_curve = list()
for epoch in range(start_epoch + 1, MAX_EPOCH):
loss_mean = 0.
correct = 0.
total = 0.
resnet18_ft.train()
for i, data in enumerate(train_loader):
# forward
inputs, labels = data
inputs, labels = inputs.to(device), labels.to(device) #将数据、标签放到cpu or gpu
outputs = resnet18_ft(inputs)
# backward
optimizer.zero_grad()
loss = criterion(outputs, labels)
loss.backward()
# update weights
optimizer.step()
# 统计分类情况
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).squeeze().cpu().sum().numpy()
# 打印训练信息
loss_mean += loss.item()
train_curve.append(loss.item())
if (i+1) % log_interval == 0:
loss_mean = loss_mean / log_interval
print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
epoch, MAX_EPOCH, i+1, len(train_loader), loss_mean, correct / total))
loss_mean = 0.
# if flag_m1:
print("epoch:{} conv1.weights[0, 0, ...] :\n {}".format(epoch, resnet18_ft.conv1.weight[0, 0, ...]))
scheduler.step() # 更新学习率
# validate the model
if (epoch+1) % val_interval == 0:
correct_val = 0.
total_val = 0.
loss_val = 0.
resnet18_ft.eval()
with torch.no_grad():
for j, data in enumerate(valid_loader):
inputs, labels = data
inputs, labels = inputs.to(device), labels.to(device)
outputs = resnet18_ft(inputs)
loss = criterion(outputs, labels)
_, predicted = torch.max(outputs.data, 1)
total_val += labels.size(0)
correct_val += (predicted == labels).squeeze().cpu().sum().numpy()
loss_val += loss.item()
loss_val_mean = loss_val/len(valid_loader)
valid_curve.append(loss_val_mean)
print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
epoch, MAX_EPOCH, j+1, len(valid_loader), loss_val_mean, correct_val / total_val))
resnet18_ft.train()
train_x = range(len(train_curve))
train_y = train_curve
train_iters = len(train_loader)
valid_x = np.arange(1, len(valid_curve)+1) * train_iters*val_interval # 由于valid中记录的是epochloss,需要对记录点进行转换到iterations
valid_y = valid_curve
plt.plot(train_x, train_y, label='Train')
plt.plot(valid_x, valid_y, label='Valid')
plt.legend(loc='upper right')
plt.ylabel('loss value')
plt.xlabel('Iteration')
plt.show()
感谢作者:【PyTorch框架】 迁移学习 & 模型微调_迁移学习模型微调_HUI 别摸鱼了的博客-CSDN博客
当有大部分层需要冻结,只有少部分层需要权重更新时,一层一层的写比较麻烦。那么就可以在网络定义时直接设置requires_grad=False,具体如下:
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d()
self.conv2 = nn.Conv2d()
self.fc1 = nn.Squential(
nn.Linear(),
nn.Linear(),
ReLU(inplace=True),
)
for param in self.parameters():
param.requires_grad = False
#这样for循环之前的参数都被冻结,其后的正常更新。
self.classifier = nn.Linear()
同样,在定义过滤器时要进行过滤,如下:
目的是,告诉优化器,哪些需要更新,那些不需要,这一步至关重要!
optimizer.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-5)