Fine-tuning BERT for text classification with Huggingface

 

Since BERT took off, many variants of it have appeared. Here we use the Huggingface library to implement a simple text-classification task, which is also a good way to see how BERT is handled on the engineering side through Huggingface.

1、load data

 
import pandas as pd

train_df = pd.read_csv('../data/train.tsv', delimiter='\t', names=['text', 'label'])
print(train_df.shape)
train_df.head()

sentences = list(train_df['text'])
targets = train_df['label'].values
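The code assumes train.tsv is a tab-separated file with two columns, text and label (the path and column layout are assumptions; adjust them to your own data). A quick sanity check of what was loaded:

# check class distribution and a sample row
print(train_df['label'].value_counts())
print(train_df.iloc[0]['text'])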

2、token encoding

 
import torch
from transformers import BertTokenizer

# if tokenization is wrapped inside a custom model class, max_length has to be fixed in advance
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
max_length = 32
sentences_tokened = tokenizer(sentences, padding=True, truncation=True, max_length=max_length, return_tensors='pt')
targets = torch.tensor(targets)
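It can help to look at what the tokenizer actually returns; a minimal sketch (the printed shapes depend on your data):

# sentences_tokened behaves like a dict of tensors
print(sentences_tokened.keys())              # input_ids, token_type_ids, attention_mask
print(sentences_tokened['input_ids'].shape)  # (num_sentences, max_length)
# tokens of the first sentence, including [CLS]/[SEP]/[PAD]
print(tokenizer.convert_ids_to_tokens(sentences_tokened['input_ids'][0].tolist()))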

3、wrap data into Dataset and DataLoader

 
# from torchvision import transforms, datasets
from torch.utils.data import Dataset, DataLoader, random_split

class DataToDataset(Dataset):
    def __init__(self, encoding, labels):
        self.encoding = encoding
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        return self.encoding['input_ids'][index], self.encoding['attention_mask'][index], self.labels[index]

# wrap the encoded tensors into a Dataset
datasets = DataToDataset(sentences_tokened, targets)
train_size = int(len(datasets) * 0.8)
test_size = len(datasets) - train_size
print([train_size, test_size])
train_dataset, val_dataset = random_split(dataset=datasets, lengths=[train_size, test_size])

BATCH_SIZE = 64
# num_workers here should be greater than 0
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=5)
val_loader = DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=5)
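As a quick check (a sketch, not part of the original flow), pull one batch out of the loader and confirm the shapes match what the model will expect:

batch_ids, batch_mask, batch_labels = next(iter(train_loader))
print(batch_ids.shape, batch_mask.shape, batch_labels.shape)  # e.g. [64, 32], [64, 32], [64]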

4、create model

 
from torch import nn
from transformers import BertModel

class BertTextClassficationModel(nn.Module):
    def __init__(self):
        super(BertTextClassficationModel, self).__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.dense = nn.Linear(768, 2)  # 768-dim hidden state in, 2 classes out

    def forward(self, ids, mask):
        # newer transformers return a ModelOutput (older versions returned a tuple);
        # take the [CLS] position of last_hidden_state as the sentence representation
        out = self.bert(input_ids=ids, attention_mask=mask)
        out = self.dense(out.last_hidden_state[:, 0, :])
        return out


mymodel = BertTextClassficationModel()

# pick the GPU if available, otherwise the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device=", device)
if torch.cuda.device_count() > 1:
    print("Let's use ", torch.cuda.device_count(), "GPUs!")
    mymodel = nn.DataParallel(mymodel)
mymodel.to(device)
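Before training, it is worth pushing a single batch through the model to confirm the output shape is (batch_size, 2); a minimal sketch reusing the loader from step 3:

# one forward pass on one batch, no gradients needed
with torch.no_grad():
    ids, mask, _ = next(iter(train_loader))
    logits = mymodel(ids.to(device), mask.to(device))
print(logits.shape)  # expected: torch.Size([BATCH_SIZE, 2])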

5、train model

 
import numpy as np
from torch import optim
from sklearn.metrics import accuracy_score

loss_func = nn.CrossEntropyLoss()
optimizer = optim.Adam(mymodel.parameters(), lr=0.0001)

def flat_accuracy(preds, labels):
    pred_flat = np.argmax(preds, axis=1).flatten()
    labels_flat = labels.flatten()
    return accuracy_score(labels_flat, pred_flat)

epochs = 3
for epoch in range(epochs):
    mymodel.train()
    train_loss = 0.0
    train_acc = 0.0
    for i, data in enumerate(train_loader):
        input_ids, attention_mask, labels = [elem.to(device) for elem in data]
        # reset the gradients
        optimizer.zero_grad()
        # forward pass
        out = mymodel(input_ids, attention_mask)
        # compute the loss
        loss = loss_func(out, labels)
        train_loss += loss.item()
        # backpropagate
        loss.backward()
        # update the parameters
        optimizer.step()
        # compute the accuracy (move tensors back to the CPU before converting to numpy)
        out = out.detach().cpu().numpy()
        labels = labels.detach().cpu().numpy()
        train_acc += flat_accuracy(out, labels)

    print("train %d/%d epochs Loss:%f, Acc:%f" % (epoch + 1, epochs, train_loss / (i + 1), train_acc / (i + 1)))
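The post stops at training, but once the loop finishes you would normally persist the fine-tuned weights; a minimal sketch (the file name is an assumption):

# save the fine-tuned weights; unwrap DataParallel if it was used
model_to_save = mymodel.module if hasattr(mymodel, 'module') else mymodel
torch.save(model_to_save.state_dict(), 'bert_text_classification.pt')  # hypothetical file name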

6、evaluate

 
print("evaluate...")
val_loss = 0
val_acc = 0
mymodel.eval()
for j, batch in enumerate(val_loader):
    val_input_ids, val_attention_mask, val_labels = [elem.to(device) for elem in batch]
    with torch.no_grad():
        pred = mymodel(val_input_ids, val_attention_mask)
        val_loss += loss_func(pred, val_labels).item()
        pred = pred.detach().cpu().numpy()
        val_labels = val_labels.detach().cpu().numpy()
        val_acc += flat_accuracy(pred, val_labels)
print("evaluate loss:%f, Acc:%f" % (val_loss / len(val_loader), val_acc / len(val_loader)))
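Finally, to run the fine-tuned model on new text, the same tokenizer plus an argmax over the logits is enough; a minimal sketch with a made-up example sentence:

mymodel.eval()
new_sentences = ["this movie is great"]  # hypothetical input
enc = tokenizer(new_sentences, padding=True, truncation=True, max_length=max_length, return_tensors='pt')
with torch.no_grad():
    logits = mymodel(enc['input_ids'].to(device), enc['attention_mask'].to(device))
print(torch.argmax(logits, dim=1))  # predicted class index for each sentence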

 

