我在写代码的时候看到很多代码有的使用以下这种方式导入
from pytorch_pretrained_bert import BertTokenizer,BertModel
有的使用transformer的方式导入的,所有我就有的时候有点郁闷究竟使用那种方式导入.
from transformers import BertTokenizer,BertConfig,BertModel
根据这个博主的博文,https://blog.csdn.net/qq_43391414/article/details/118252012
知道transerformers包包又名pytorch-transformers或者pytorch-pretrained-bert”
但是根据一些了解,实际上transformers库是最新的版本(以前称为pytorch-transformers和pytorch-pretrained-bert)
所以它在前两者的基础上对一些函数与方法进行了改进,包括一些函数可能只有在transformers库里才能使用,所以使用transformers库比较方便。
它提供了一系列的STOA(最先进)模型的实现,包括(Bert、XLNet、RoBERTa等)。
所以导入bert模型的时候推荐使用以下方式比较好
from transformers import BertTokenizer,BertModel
我在写代码的时候遇到的情况是,有的人的代码直接把句子分词,处理成id的格式输入到bert模型中,有的人要把数据处理成input_ids,mask_attention,token…各种各样的格式,五花八门的,作为一个小白进入干进入深度学习领域,感觉不是很友好.数据预处理这块,一个人的代码就有一千种写法,真的不知道相信谁.
所以我就看到了bert模型的源码,看到了他的foward函数:
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
这些就是我们要使用bert模型的输入数据格式:
1.input_ids:
2.attention_mask
3.token_type_ids:
4.position_ids:
5.head_mask:
6.inputs_embeds:
7.encoder_hidden_states:
8.encoder_attention_mask:
9.past_key_values:
10.use_cache:
11.output_attentions:
12.output_hidden_states:
13.return_dict:
====================================================
在这里可以使用Dataset先将数据处理好,放到这里面,然后将Dataset放到DataLoader中,设置好批次,一个批次一个批次的加载数据.
详细的写法可以看我写的笔记,Dataset和DataLoader的使用方法.
输入Bert模型中只要输入,input_ids,attention_mask,token_type_ids就可以了,下面只是我的部分代码.
out = self.bert(x['input_ids'], x['attention_mask'], x['token_type_ids'])
输入之后的输出out包括以下四个数据