Pytorch:BertModel使用

文章目录

  • 基本介绍
  • 简单例子:
  • 参考

基本介绍

  • 环境: Python 3.5+, Pytorch 0.4.1/1.0.0
  • 安装:
pip install pytorch-pretrained-bert
  • 必需参数:
    • --data_dir: "str": 数据根目录.目录下放着,train.xxx/dev.xxx/test.xxx三个数据文件.
    • --vocab_dir: "str": 词库文件地址.
    • --bert_model: "str": 存放着bert预训练好的模型. 需要是一个gz文件, 如"..x/xx/bert-base-chinese.tar.gz ", 里面包含一个bert_config.jsonpytorch_model.bin文件.
    • --task_name: "str": 用来选择对应数据集的参数,如"cola",对应着数据集.
    • --output_dir: "str": 模型预测结果和模型参数存储目录.

简单例子:

  • 导入所需包
import torch
from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM
  • 创建分词器
tokenizer = BertTokenizer.from_pretrained(--vocab_dir)
  • 需要参数: --vocab_dir, 数据样式见此
  • 拥有函数:
    • tokenize: 输入句子,根据--vocab_dir和贪心原则切词. 返回单词列表
    • convert_token_to_ids: 将切词后的列表转换为词库对应id列表.
    • convert_ids_to_tokens: 将id列表转换为单词列表.
text = '[CLS] 武松打老虎 [SEP] 你在哪 [SEP]'
tokenized_text = tokenizer.tokenize(text)
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
segments_ids = [0, 0, 0, 0, 0, 0, 0,0,0,0, 1,1, 1, 1, 1, 1, 1, 1]
tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])

这里对标记符号的切词似乎有问题([cls]/[sep]), 而且中文bert是基于字级别编码的,因此切出来的都是一个一个汉字:

['[', 'cl', '##s', ']', '武', '松', '打', '老', '虎', '[', 'sep', ']', '你', '在', '哪', '[', 'sep', ']']
  • 创建bert模型并加载预训练模型:
model = BertModel.from_pretrained(--bert_model)
  • 放入GPU:
tokens_tensor = tokens_tensor.cuda()
segments_tensors = segments_tensors.cuda()
model.cuda()
  • 前向传播:
encoded_layers, pooled_output= model(tokens_tensor, segments_tensors)
  • 参数:
    • input_ids: (batch_size, sqe_len)代表输入实例的Tensor
    • token_type_ids=None: (batch_size, sqe_len)一个实例可以含有两个句子,这个相当于句子标记.
    • attention_mask=None: (batch_size*): 传入每个实例的长度,用于attention的mask.
    • output_all_encoded_layers=True: 控制是否输出所有encoder层的结果.
  • 返回值:
    • encoded_layer:长度为num_hidden_layers的(batch_sizesequence_lengthhidden_size)的Tensor.列表
    • pooled_output: (batch_size, hidden_size), 最后一层encoder的第一个词[CLS]经过Linear层和激活函数Tanh()后的Tensor. 其代表了句子信息

参考

  • GitHub - huggingface/pytorch-pretrained-BERT
  • 一起读Bert文本分类代码
  • 如何使用BERT实现中文的文本分类(附代码)
  • BERT-Pytorch demo初探

你可能感兴趣的:(信息科学)