pytorch如何精确冻结某一层的参数?

直接跳到后面简单举例看也行:

 有时候,我们加载一些预训练模型可能不想更新其中的某些参数,如在bert预训练模型中,我加载了 ‘bert-base-uncased’ 预训练模型(其是为下游任务提供特征提取的模型),但是这里用的是BertForSequenceClassification类去加载的话,它会在’bert-base-uncased’ 模型基础上加个分类层。假若我们不想fine-tune这个bert层,只想更新分类层,那就需要冻结bert层。

model = BertForSequenceClassification.from_pretrained('bert-base-uncased',num_labels=2)

pytorch如何精确冻结某一层的参数?_第1张图片
冻结:

for p in model.bert.parameters():
                p.requires_grad = False
optim = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=2e-5)             

___________________________________________________________________________

下面用个简单的例子举例,例如以下用torch搭建的一个模型:

import torch.nn as nn
class Model(nn.Module):
    def __init__(self):
        super().__init__()
       	self.fc1 = nn.nn.Linear(100, 128)
       	self.fc2 = nn.nn.Linear(128, 64)
       	self.fc3 = nn.nn.Linear(64, 10)
       	
    def forward(self,x):x
    	pass


 ⭐假如我们要冻结其中的 self.fc2 层的参数:

model = Model()
# 这里是一般情况,共享层往往不止一层,所以做一个for循环
for para in model.fc2.parameters():
	para.requires_grad = False


 ⭐然后使用filter()函数过滤掉冻结的参数,传入优化器需要反向传播更新的参数:

import torch.optim as optim
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.1)


 ⭐filter() 函数用于过滤序列,过滤掉不符合条件的元素,返回由符合条件元素组成的新列表:

#function -- 判断函数。
#iterable -- 可迭代对象。
filter(function, iterable)

该接收两个参数,第一个为函数,第二个为序列,序列的每个元素作为参数传递给函数进行判断,然后返回 True 或 False,最后将返回 True 的元素放到新列表中。

你可能感兴趣的:(NLP,pytorch,NLP,pytorch,深度学习)