基于bert-base-chinese的二分类任务

使用hugging-face中的预训练语言模型bert-base-chinese来完成二分类任务,整体流程为:
1.定义数据集
2.加载词表和分词器
3.加载预训练模型
4.定义下游任务模型
5.训练下游任务模型
6.测试

具体代码如下:

1.定义数据集

import torch
from datasets import load_from_disk
class Dataset(torch.utils.data.Dataset):
    def __init__(self, path):
        self.dataset = load_from_disk(path)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, i):
        text = self.dataset[i]['text']
        label = self.dataset[i]['label']

        return text,label

dataset = Dataset('./data/ChnSentiCorp/train')
# print(dataset[0])

2.加载词表和分词器

from transformers import BertTokenizer

# 每个模型都有自己的tokenizer分词器
tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path='E:/bert-base-chinese')

# 使用tokenizer编码数据
def collate_fn(data):
    sents = [i[0] for i in data]
    labels = [i[1] for i in data]

    #编码
    data = tokenizer.batch_encode_plus(batch_text_or_text_pairs=sents,
                                   truncation=True,
                                   padding='max_length',
                                   max_length=500,
                                   return_tensors='pt',
                                   return_length=True)

    #input_ids:编码之后的数字
    #attention_mask:是补零的位置是0,其他位置是1
    input_ids = data['input_ids']
    attention_mask = data['attention_mask']
    token_type_ids = data['token_type_ids']
    labels = torch.LongTensor(labels)

    #print(data['length'], data['length'].max())

    return input_ids, attention_mask, token_type_ids, labels


# 数据加载器
loader = torch.utils.data.DataLoader(dataset=dataset,
                                     batch_size=16,
                                     collate_fn=collate_fn,
                                     shuffle=True,
                                     drop_last=True)

for i, (input_ids, attention_mask, token_type_ids,
        labels) in enumerate(loader):
    break

# print(len(loader))

3.加载预训练模型

from transformers import BertModel

# 加载预训练模型
pretrained = BertModel.from_pretrained('E:/bert-base-chinese')
# 冻结bert预训练模型的参数,即不对预训练模型的参数进行训练
for param in pretrained.parameters():
    param.requires_grad_(False)
# 模型试算
out = pretrained(
    input_ids=input_ids,
    attention_mask=attention_mask,
    token_type_ids=token_type_ids
)
# 输出最后一层隐藏层的形状,输出torch.Size([16, 500, 768]),batchsize是16,每个输入句子的长度是500,每个token的向量维度是768
# print(out.last_hidden_state.shape)

4.定义下游任务模型

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(768, 2)  # 一个全连接神经网络,768是词编码维度,2是二分类

    def forward(self, input_ids, attention_mask, token_type_ids):
        with torch.no_grad():  # 使用预训练模型,抽取训练数据中的特征
            out = pretrained(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids
            )
        out = self.fc(out.last_hidden_state[:, 0])  # 把抽取到的特征放到全连接神经网络计算,获取bert最后一层隐藏层中[cls]对应的输出向量
        out = out.softmax(dim=1)  # 对out的第一个维度进行归一化

        return out

model = Model()

print(
    model(
        input_ids=input_ids,attention_mask=attention_mask,token_type_ids=token_type_ids
    ).shape
)

5.训练下游任务模型

from transformers import AdamW
# 训练,
optimizer = AdamW(model.parameters(), lr=5e-4)  # AdamW优化器
criterion = torch.nn.CrossEntropyLoss()  # 交叉熵损失函数,用于分类任务

model.train()
for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader):
    out = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)  # 模型的预测输出
    loss = criterion(out, labels)  # 用模型预测的输出和真实标签计算loss函数
    loss.backward()  # 反向传播
    optimizer.step()  # 梯度下降,
    optimizer.zero_grad()  # 梯度清零

    if i % 5 == 0:
        out = out.argmax(dim=1)
        accuracy = (out == labels).sum().item / len(labels)  # 计算预测准确率

        print(i, loss.item(), accuracy)

    if i == 300:  # 训练300轮结束
        break

6.测试

def test():
    model.eval()
    correct = 0
    total = 0

    loader_test = torch.utils.data.DataLoader(
        dataset=Dataset('./data/ChnSentiCorp/validation'),
        batch_size=32,
        collate_fn=collate_fn,
        shuffle=True,
        drop_last=True
    )

    for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader_test):
        if i == 5:  # 测试5轮
            break
        print(i)

        with torch.no_grad():
            out = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)

        out = out.argmax(dim=1)
        correct += (out == labels).sum().item()
        total += len(labels)

    print(correct/total)

你可能感兴趣的:(预训练语言模型,bert,分类,python)