实现继承BERT预训练模型的分类任务类
import torch.nn as nn
from transformers import BertPreTrainedModel, BertModel, BertConfig
# 构建基于BERT的微调模型类
class Model(nn.Module):
def __init__(self, config):
super(Model, self).__init__()
# 导入参数设置对象
model_config = BertConfig.from_pretrained(config.bert_path,
num_labels=config.num_classes)
# 导入基于bert-base-chinese的预训练模型
self.bert = BertModel.from_pretrained(config.bert_path, config=model_config)
# 此处用于调节是否将BERT纳入微调训练, 建议数据量+算力充足的情况下置为True
# 如果设置为False, 则保持整个BERT网络参数不变, 微调仅仅针对最后的全连接层进行训练
for param in self.bert.parameters():
param.requires_grad = True
# 全连接层的出口维度, 取决于具体的任务
self.fc = nn.Linear(config.hidden_size, config.num_classes)
def forward(self, x):
# x[0]是输入的具体文本信息
context = x[0]
# x[1]是经过tokenizer处理后返回的attention mask张量
# mask的尺寸size和输入相同, padding部分用0遮掩, 比如[1, 1, 1, 0, 0]
mask = x[1]
# x[2]是字符类型id
token_type_ids = x[2]
# 利用BERT模型得到输出张量, 并且只保留BertPooler的输出, 即第一个字符CLS对应的输出张量
_, pooled = self.bert(context, attention_mask=mask, token_type_ids=token_type_id)
# 再利用微调网络进一步提取特征, 并利用全连接层对特征张量进行维度变换
out = self.fc(pooled)
return out
对BERT模型的参数执行微调
展示BERT模型中的参数命名:
class Model(nn.Module):
def __init__(self, config):
super(Model, self).__init__()
self.bert = BertModel.from_pretrained(config.bert_path,config=config.bert_config)
# 将BERT中所有的参数层名字打印出来
for name, param in self.bert.named_parameters():
print(name)
self.fc = nn.Linear(config.hidden_size, config.num_classes)
针对BERT模型中的embedding层, 让其中的参数不参与微调
class Model(nn.Module):
def __init__(self, config):
super(Model, self).__init__()
self.bert = BertModel.from_pretrained(config.bert_path,config=config.bert_config)
# 希望锁定embeddings层的参数, 不参与更新
for name, param in self.bert.embeddings.named_parameters():
print(name)
param.requires_grad = False
self.fc = nn.Linear(config.hidden_size, config.num_classes)
BERT中的全连接层, 让其中的weight参数不参与微调
class Model(nn.Module):
def __init__(self, config):
super(Model, self).__init__()
self.bert = BertModel.from_pretrained(config.bert_path,config=config.bert_config)
# 希望将全连接层中的.weight部分参数锁定
for name, param in self.bert.named_parameters():
if name.endswith('weight'):
print(name)
param.requires_grad = False
self.fc = nn.Linear(config.hidden_size, config.num_classes)
BERT中指定的若干层, 让其中的参数不参与微调
class Model(nn.Module):
def __init__(self, config):
super(Model, self).__init__()
self.bert = BertModel.from_pretrained(config.bert_path,config=config.bert_config)
# 封闭BERT中的第1, 3, 5层参数, 不参与微调
index_array = [1, 3, 5]
for name, param in self.bert.named_parameters():
new_x = name.split('.')[2]
if new_x in index_array:
print(name)
param.requires_grad = False
self.fc = nn.Linear(config.hidden_size, config.num_classes)