数据集如下:
将下面的数据保存为 bert_example.csv(无表头,第一列为句子文本,第二列为 0/1 类别标签):
"a stirring , funny and finally transporting re imagining of beauty and the beast and 1930s horror films",1
apparently reassembled from the cutting room floor of any given daytime soap,0
"they presume their audience wo n't sit still for a sociology lesson , however entertainingly presented , so they trot out the conventional science fiction elements of bug eyed monsters and futuristic women in skimpy clothes",0
"this is a visually stunning rumination on love , memory , history and the war between art and commerce",1
jonathan parker 's bartleby should have been the be all end all of the modern office anomie films,1
campanella gets the tone just right funny in the middle of sad in the middle of hopeful,1
a fan film that for the uninitiated plays better on video with the sound turned down,0
"b art and berling are both superb , while huppert is magnificent",1
"a little less extreme than in the past , with longer exposition sequences between them , and with fewer gags to break the tedium",0
the film is strictly routine,0
a lyrical metaphor for cultural and personal self discovery and a picaresque view of a little remembered world,1
the most repugnant adaptation of a classic text since roland joff and demi moore 's the scarlet letter,0
"for something as splendid looking as this particular film , the viewer expects something special but instead gets lrb sci fi rrb rehash",0
"this is a stunning film , a one of a kind tour de force",1
"may be more genial than ingenious , but it gets the job done",1
"there is a freedom to watching stunts that are this crude , this fast paced and this insane",1
"if the tuxedo actually were a suit , it would fit chan like a 99 bargain basement special",0
"as quiet , patient and tenacious as mr lopez himself , who approaches his difficult , endless work with remarkable serenity and discipline",1
final verdict you 've seen it all before,0
"blue crush follows the formula , but throws in too many conflicts to keep the story compelling",0
import torch
from torch import nn
from torch import optim
import pandas as pd
from sklearn.model_selection import train_test_split
from transformers import BertModel, BertTokenizer
import torch.utils.data as Data
import numpy as np
from loguru import logger
def load_data(csv_path="bert_example.csv", test_size=None, random_state=None):
    """Load the labelled sentence CSV and split it into train/test sets.

    The CSV is read without a header row: column 0 holds the sentence text
    and column 1 holds its class label.

    Args:
        csv_path: Path of the CSV file. Defaults to ``bert_example.csv`` in the
            current working directory (the original hard-coded location).
        test_size: Forwarded to ``train_test_split``. ``None`` keeps
            scikit-learn's default split proportion, as before.
        random_state: Seed forwarded to ``train_test_split`` so the split can be
            made reproducible. ``None`` keeps the original unseeded behaviour.

    Returns:
        A tuple ``(train_inputs, test_inputs, train_targets, test_targets)``.
    """
    train_df = pd.read_csv(csv_path, header=None)
    sentences = train_df[0].values
    targets = train_df[1].values
    train_inputs, test_inputs, train_targets, test_targets = train_test_split(
        sentences, targets, test_size=test_size, random_state=random_state
    )
    return train_inputs, test_inputs, train_targets, test_targets
class BertClassificationModel(nn.Module):
    """Sentence classifier: a pre-trained BERT encoder plus a linear head on [CLS].

    ``forward`` returns raw, unnormalised logits of shape
    ``(batch_size, num_labels)``. They are meant to be fed directly to
    ``nn.CrossEntropyLoss``, which applies log-softmax internally.
    """

    def __init__(self, model_path='bert_model', num_labels=2, max_length=30):
        """Load the tokenizer and encoder and build the classification head.

        Args:
            model_path: Folder (or model name) passed to ``from_pretrained`` for
                both the tokenizer and the encoder. The default ``'bert_model'``
                is a local folder; presumably it holds the vocab/config/weight
                files of the pre-trained model — confirm for your setup.
            num_labels: Number of output classes (2 for binary classification).
            max_length: Maximum token length per sentence; longer inputs are
                truncated.
        """
        super().__init__()
        self.tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path=model_path)
        self.bert = BertModel.from_pretrained(model_path)  # pre-trained encoder
        self.max_length = max_length
        # Size the head from the encoder config instead of hard-coding 768,
        # so BERT variants with a different hidden size also work.
        self.use_bert_classify = nn.Linear(self.bert.config.hidden_size, num_labels)
        # Kept only for backward compatibility with code that references this
        # attribute; forward() deliberately no longer applies it.
        self.sig_mod = nn.Sigmoid()

    def forward(self, batch_sentences):
        """Classify a batch of sentences.

        Args:
            batch_sentences: A list of raw sentence strings.

        Returns:
            Logits of shape ``(len(batch_sentences), num_labels)`` on the same
            device as the model's parameters.
        """
        sentence_tokenized = self.tokenizer(batch_sentences,
                                            truncation=True,          # cut sentences longer than max_length
                                            padding=True,             # pad shorter sentences within the batch
                                            max_length=self.max_length,
                                            add_special_tokens=True,  # adds BERT's special tokens, incl. [CLS] at position 0
                                            return_tensors='pt')
        # Build inputs on the model's device so this still works after model.to(device).
        device = self.use_bert_classify.weight.device
        input_ids = sentence_tokenized['input_ids'].to(device)
        attention_mask = sentence_tokenized['attention_mask'].to(device)
        bert_output = self.bert(input_ids, attention_mask=attention_mask)
        # Hidden state at position 0, i.e. the [CLS] token. The input [CLS] is
        # identical for every sentence, but after encoding it differs per
        # sentence, so it serves as the sentence vector.
        bert_cls_hidden_state = bert_output[0][:, 0, :]
        # No sigmoid here: nn.CrossEntropyLoss expects unnormalised logits;
        # squashing them into (0, 1) first shrinks the gradients and stalls training.
        return self.use_bert_classify(bert_cls_hidden_state)
def main(epochs=10, batch_size=5):
    """Fine-tune BertClassificationModel on the CSV data, then report test accuracy.

    Args:
        epochs: Number of passes over the training set.
        batch_size: Number of sentences per mini-batch.
    """
    train_inputs, test_inputs, train_targets, test_targets = load_data()

    # Keep every sentence paired with its label in ONE dataset so the loader
    # can shuffle without desynchronising them. Zipping two separate loaders
    # only stays aligned as long as neither of them shuffles.
    train_loader = Data.DataLoader(
        dataset=list(zip(train_inputs, train_targets)),
        batch_size=batch_size,  # sentences per batch
        shuffle=True,
    )
    test_loader = Data.DataLoader(
        dataset=list(zip(test_inputs, test_targets)),
        batch_size=batch_size,
    )

    bert_classifier_model = BertClassificationModel()
    optimizer = optim.SGD(bert_classifier_model.parameters(), lr=0.01)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):  # training loop
        # from_pretrained() hands back the encoder in eval mode (dropout off);
        # switch the whole model into training mode explicitly.
        bert_classifier_model.train()
        loss_list = []
        for sentences, labels in train_loader:
            optimizer.zero_grad()
            outputs = bert_classifier_model(sentences)
            loss = criterion(outputs, labels.to(outputs.device))
            loss.backward()
            optimizer.step()
            # .item() works on any device, unlike .detach().numpy() (CPU only).
            loss_list.append(loss.item())
        logger.info("epoch:{},loss:{}".format(epoch, np.mean(loss_list)))

    # Evaluate on the held-out split produced by load_data().
    bert_classifier_model.eval()
    correct = 0
    with torch.no_grad():
        for sentences, labels in test_loader:
            predictions = bert_classifier_model(sentences).argmax(dim=1).cpu()
            correct += (predictions == labels).sum().item()
    if len(test_inputs):
        logger.info("test accuracy:{}".format(correct / len(test_inputs)))


if __name__ == '__main__':
    main()
运行后得到的训练日志如下:
2022-03-28 15:53:21.356 | INFO | __main__:main:73 - epoch:0,loss:0.6939845681190491
2022-03-28 15:53:24.096 | INFO | __main__:main:73 - epoch:1,loss:0.6804901957511902
2022-03-28 15:53:26.815 | INFO | __main__:main:73 - epoch:2,loss:0.6670143604278564
2022-03-28 15:53:29.475 | INFO | __main__:main:73 - epoch:3,loss:0.6514456868171692
2022-03-28 15:53:32.160 | INFO | __main__:main:73 - epoch:4,loss:0.6312667727470398
2022-03-28 15:53:34.832 | INFO | __main__:main:73 - epoch:5,loss:0.604450523853302
……(后续 epoch 的日志省略)