一、前期工作
二、模型加载
三、模型训练
四、模型测试
大家好,我是微学AI,今天给大家带来一个基于BERT模型做文本分类的实战案例,在BERT模型基础上做微调,训练自己的数据集,相信之前大家很多都是套用别人的模型直接训练,或者直接用于预训练模型进行预测,没有训练和微调过大模型,因为像BERT这种大模型一般人是训练不了的,我们只能在大模型的基础上进行微调,或者做下游任务改造。
下面来介绍一下BERT模型,BERT是基于transfomer的预训练语言模型,它利用了transfomer中的编码器,进行数据编码,将文本数据转化为词向量。BERT核心内容是利用transfomer中的多头自注意力机制进行编码,关于transfomer的多头自注意力机制详细可以观看网络上的资料。
BERT模型是以两个NLP任务进行训练的,第一个任务是文本中词的预测,将已知训练文本隐掉词的信息,用MASK进行隐码,让模型去预测。第二个任务是在训练数据中随机抽取上下文关系句子或非上下文关系句子,让机器判断是否为上下文关系。BERT模型训练优势是无需进行标注数据。
我们可以利用BERT预训练模型进行下游任务改造,做自己相关任务,比如中文分词、文本分类,命名实体识别,阅读理解,情感分析,文本相似度、信息抽取等任务。
import torch
from datasets import load_dataset
import torch.nn.functional as F
from transformers import BertTokenizer
#加载字典和分词工具
token = BertTokenizer.from_pretrained('bert-base-chinese')
#定义数据集
class Dataset(torch.utils.data.Dataset):
def __init__(self, split):
self.dataset = load_dataset(path='data', split=split)
def __len__(self):
return len(self.dataset)
def __getitem__(self, i):
text = self.dataset[i]['text']
label = self.dataset[i]['label']
return text, label
dataset = Dataset('train')
print(len(dataset), dataset[0])
def collate_fn(data):
sents = [i[0] for i in data]
labels = [i[1] for i in data]
#编码
data = token.batch_encode_plus(batch_text_or_text_pairs=sents,
truncation=True,
padding='max_length',
max_length=500,
return_tensors='pt',
return_length=True)
#input_ids:编码之后的数字
#attention_mask:是补零的位置是0,其他位置是1
input_ids = data['input_ids']
attention_mask = data['attention_mask']
token_type_ids = data['token_type_ids']
labels = torch.LongTensor(labels)
#print(data['length'], data['length'].max())
return input_ids, attention_mask, token_type_ids, labels
#数据加载器
loader = torch.utils.data.DataLoader(dataset=dataset,
batch_size=10,
collate_fn=collate_fn,
shuffle=True,
drop_last=True)
for i, (input_ids, attention_mask, token_type_ids,
labels) in enumerate(loader):
break
print(len(loader))
print(input_ids.shape, attention_mask.shape, token_type_ids.shape, labels)
这里代码需要在同级文件夹下创建data 文件夹, 放入train.csv、test.csv数据集。
数据集格式如下:
我们可以在BERT输出端接入一个全连接层,输出2分类问题,也可加入CNN卷积层,这些可以自行操作。
from transformers import BertModel
#加载预训练模型
pretrained = BertModel.from_pretrained('bert-base-chinese')
#不训练,不需要计算梯度
for param in pretrained.parameters():
param.requires_grad_(False)
#模型试算
out = pretrained(input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids)
print(out.last_hidden_state.shape)
#定义下游任务模型
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.fc = torch.nn.Linear(768, 2)
# 可加入CNN卷积层,可以自行操作
# self.conv1D = torch.nn.Conv1d(in_channels=500, out_channels=500, kernel_size=1)
# self.MaxPool1D = torch.nn.MaxPool1d(4, stride=2)
# self.Dropout = torch.nn.Dropout(p=0.5, inplace=False)
def forward(self, input_ids, attention_mask, token_type_ids):
with torch.no_grad():
out = pretrained(input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids)
out = self.fc(out.last_hidden_state[:, 0])
out = out.softmax(dim=1)
print(out.shape)
return out
model = Model()
print(model)
#model.summary()
model(input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids).shape
from transformers import AdamW
#训练
optimizer = AdamW(model.parameters(), lr=5e-4)
criterion = torch.nn.CrossEntropyLoss()
model.train()
epochs = 30
for i, (input_ids, attention_mask, token_type_ids,
labels) in enumerate(loader):
out = model(input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids)
loss = criterion(out, labels)
loss.backward()
optimizer.step()
optimizer.zero_grad()
if i % 1 == 0:
out = out.argmax(dim=1)
accuracy = (out == labels).sum().item() / len(labels)
print('epochs:',i, 'loss:',loss.item(),'accuracy:', accuracy)
if i == epochs:
torch.save(model, 'text_classfiy.model')
#model_load = torch.load('model/命名实体识别_中文.model')
break
#测试函数
def test():
model.eval()
correct = 0
total = 0
loader_test = torch.utils.data.DataLoader(dataset=Dataset('validation'),
batch_size=10,
collate_fn=collate_fn,
shuffle=True,
drop_last=True)
for i, (input_ids, attention_mask, token_type_ids,
labels) in enumerate(loader_test):
if i == 5:
break
with torch.no_grad():
out = model(input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids)
out = out.argmax(dim=1)
correct += (out == labels).sum().item()
total += len(labels)
print(correct / total)
可以调用测试函数进行测试,看看模型训练效果。
欢迎继续关注 深度学习实战案例,持续更新。获取数据可私聊。