有时候,我们加载一些预训练模型可能不想更新其中的某些参数,如在bert预训练模型中,我加载了 ‘bert-base-uncased’ 预训练模型(其是为下游任务提供特征提取的模型),但是这里用的是BertForSequenceClassification类去加载的话,它会在’bert-base-uncased’ 模型基础上加个分类层。假若我们不想fine-tune这个bert层,只想更新分类层,那就需要冻结bert层。
model = BertForSequenceClassification.from_pretrained('bert-base-uncased',num_labels=2)
for p in model.bert.parameters():
p.requires_grad = False
optim = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=2e-5)
___________________________________________________________________________
import torch.nn as nn
class Model(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.nn.Linear(100, 128)
self.fc2 = nn.nn.Linear(128, 64)
self.fc3 = nn.nn.Linear(64, 10)
def forward(self,x):x
pass
⭐假如我们要冻结其中的 self.fc2 层的参数:
model = Model()
# 这里是一般情况,共享层往往不止一层,所以做一个for循环
for para in model.fc2.parameters():
para.requires_grad = False
⭐然后使用filter()函数过滤掉冻结的参数,传入优化器需要反向传播更新的参数:
import torch.optim as optim
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.1)
⭐filter() 函数用于过滤序列,过滤掉不符合条件的元素,返回由符合条件元素组成的新列表:
#function -- 判断函数。
#iterable -- 可迭代对象。
filter(function, iterable)
该接收两个参数,第一个为函数,第二个为序列,序列的每个元素作为参数传递给函数进行判断,然后返回 True 或 False,最后将返回 True 的元素放到新列表中。