pytorch冻结模型权重,不进行梯度反向传播
在 PyTorch 中冻结模型权重,不进行梯度反向传播,可以通过以下几个步骤实现:
requires_grad
属性为 False
。pythonCopy code
for param in model.parameters(): param.requires_grad = False
2.对于需要更新的参数,将其 requires_grad
属性设置为 True
。
pythonCopy code
for param in model.fc.parameters(): param.requires_grad = True
这里以 model.fc
为例,表示只对模型中最后一层全连接层的参数进行更新。
3.在训练过程中,将需要更新的参数传入优化器中。
pythonCopy code
optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001)
在这里使用了 filter
函数,只选择 requires_grad=True
的参数进行优化,以避免对不需要更新的参数进行无效的计算。
4.在进行训练时,只对需要更新的参数进行反向传播和优化。
pythonCopy code
for inputs, labels in dataloader:
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
这里的 model(inputs)
表示对输入 inputs
进行前向计算,只有最后一层全连接层的参数需要更新,因此只对这一层的参数进行了反向传播和优化。
通过以上步骤,即可冻结模型权重,只对需要更新的参数进行优化。
完整的代码
以下是一个简单的示例代码,展示如何固定模型权重并不进行反向传播,只对最后一层全连接层的权重进行更新。
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
# 定义模型
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(1, 32, kernel_size=3),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2),
nn.Conv2d(32, 64, kernel_size=3),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2)
)
self.fc = nn.Linear(64 * 5 * 5, 10)
# 固定卷积层的参数,不进行反向传播
for param in self.features.parameters():
param.requires_grad = False
def forward(self, x):
x = self.features(x)
x = torch.flatten(x, 1)
x = self.fc(x)
return x
# 加载数据
train_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=train_transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
# 初始化模型和优化器
model = Net()
optimizer = optim.SGD(model.fc.parameters(), lr=0.01)
# 训练模型
for epoch in range(10):
running_loss = 0.0
for i, (inputs, labels) in enumerate(train_loader):
optimizer.zero_grad()
outputs = model(inputs)
loss = nn.CrossEntropyLoss()(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
print('Epoch %d, loss: %.3f' % (epoch+1, running_loss/(i+1)))