ERNIE is a pre-trained model developed by Baidu; its architecture is similar to BERT's. And yes, ERNIE is the model that got me into Paddle.
ERNIE 1.0 improves on BERT's masking strategy by masking whole entities and phrases instead of individual tokens, as shown in the figure; later BERT variants such as the whole-word-masking (wwm) models adopted the same idea:
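As a toy illustration of the difference (this is not ERNIE's actual pre-training code, just hand-written examples): BERT masks individual tokens at random, while ERNIE masks an entire entity or phrase at once.

# Toy illustration only: token-level masking (BERT) vs. entity-level masking (ERNIE)
tokens = ["哈", "尔", "滨", "是", "黑", "龙", "江", "的", "省", "会"]

# BERT-style: individual characters are masked independently
bert_masked = ["哈", "[MASK]", "滨", "是", "黑", "龙", "[MASK]", "的", "省", "会"]

# ERNIE-style: the whole entity "哈尔滨" is masked together
ernie_masked = ["[MASK]", "[MASK]", "[MASK]", "是", "黑", "龙", "江", "的", "省", "会"]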
ERNIE 2.0 adds even more pre-training tasks on top of BERT; it has not been open-sourced yet.
The following is my personal take:
these tasks are roughly trained in a pipeline fashion, each task first trained to its best on its own, and then fine-tuned jointly with multi-task training.
The ERNIE architecture is also a Transformer and should be the same as BERT's; the difference is simply that more tasks are added during pre-training.
Paddle already wraps many pre-trained models very nicely and they are convenient to call, but the heavy encapsulation can be a bit confusing at first. This post tries to start from scratch and run an ERNIE classification model; later posts will cover calling the wrapped models and deploying them.
As usual, there are three steps:
1. Data processing (turn the data into a format that can be fed to the model)
2. Model building (build the model you want to use)
3. Training and evaluating the model
1. Data processing
First, as usual, we need to define a reader to load and generate the data.
The difference from other classification models is that the data must be converted into the format ERNIE expects.
The context part consists of:
1. token_ids: the text converted to indices. Note that ERNIE provides its own vocabulary, so you do not need to build one yourself; you can call the convert_tokens_to_ids function in ERNIE's bundled tokenization.py to generate them.
tokenization.py provides both tokenization and convert_tokens_to_ids:
from tokenization import FullTokenizer   # ERNIE's bundled tokenization.py

vocab_file = r'/ernie/vocab.txt'   # vocabulary file shipped with ERNIE
full_tokenize = FullTokenizer(vocab_file)
tokens = full_tokenize.tokenize("我出生于1960年,湖南人")
print(tokens)                                         # tokenized text
print(full_tokenize.convert_tokens_to_ids(tokens))    # tokens mapped to vocabulary ids
2. text_type_ids: the segment ids of the input text. For a single sentence they are all 0; for a sentence pair they look like [0, 0, …, 0, 1, 1, …, 1].
3. position_ids: the absolute position of each token in the sentence, generated as follows (easier to see in code):
position_ids = list(range(len(token_ids)))
The above is everything the model needs as input; ERNIE 2.0 additionally requires a task_id (covered later).
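Putting the three pieces together, here is a minimal sketch of the inputs for a single sentence. It uses the FullTokenizer shown above, and the [CLS]/[SEP] handling mirrors what the reader does below; treat it as illustrative rather than the reader's actual code.

from tokenization import FullTokenizer

vocab_file = r'/ernie/vocab.txt'                      # ERNIE's own vocabulary
tokenizer = FullTokenizer(vocab_file)

text = "去逛街咯"
tokens = ["[CLS]"] + tokenizer.tokenize(text) + ["[SEP]"]

token_ids = tokenizer.convert_tokens_to_ids(tokens)   # text -> vocabulary indices
text_type_ids = [0] * len(token_ids)                  # single sentence: all zeros
position_ids = list(range(len(token_ids)))            # absolute positions

print(token_ids, text_type_ids, position_ids)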
ERNIE's reader does quite a lot of work (here we walk through the reader for the classification model in detail):
First there is a BaseReader, which implements the basic operations shared by classification, sequence labeling and other tasks. They mainly include:
1. Reading the file
import io
from collections import namedtuple

def csv_reader(fd, delimiter='\t'):
    def gen():
        for i in fd:
            slots = i.rstrip('\n').split(delimiter)
            if len(slots) == 1:
                yield slots,
            else:
                yield slots
    return gen()

def _read_tsv(self, input_file, quotechar=None):
    """Reads a tab separated value file."""
    with io.open(input_file, "r", encoding="utf8") as f:
        reader = csv_reader(f, delimiter="\t")
        headers = next(reader)  # e.g. [label, text_a]
        Example = namedtuple('Example', headers)  # map each column to a named field
        examples = []
        for line in reader:
            example = Example(*line)
            examples.append(example)
        return examples
The resulting examples look like this:
[Example(label='1', text_a='去 逛街 咯')]
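For reference, the input TSV is expected to have a header row followed by one example per line, with columns separated by a tab; something like the following (the contents are just an illustration matching the columns above):

label	text_a
1	去 逛街 咯
...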
2. Converting each example into a record
def _convert_example_to_record(self, example, max_seq_length, tokenizer):
    """Converts a single `Example` into a single `Record`."""

    text_a = tokenization.convert_to_unicode(example.text_a)
    tokens_a = tokenizer.tokenize(text_a)
    tokens_b = None
    if "text_b" in example._fields:
        text_b = tokenization.convert_to_unicode(example.text_b)
        tokens_b = tokenizer.tokenize(text_b)

    if tokens_b:
        # Modifies `tokens_a` and `tokens_b` in place so that the total
        # length is less than the specified length.
        # Account for [CLS], [SEP], [SEP] with "- 3"
        self._truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
    else:
        # Account for [CLS] and [SEP] with "- 2"
        if len(tokens_a) > max_seq_length - 2:
            tokens_a = tokens_a[0:(max_seq_length - 2)]

    # The convention in BERT/ERNIE is:
    # (a) For sequence pairs:
    #  tokens:   [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
    #  type_ids: 0     0  0    0    0     0       0 0     1  1  1  1   1 1
    # (b) For single sequences:
    #  tokens:   [CLS] the dog is hairy . [SEP]
    #  type_ids: 0     0   0   0  0     0 0
    #
    # Where "type_ids" are used to indicate whether this is the first
    # sequence or the second sequence. The embedding vectors for `type=0` and
    # `type=1` were learned during pre-training and are added to the wordpiece
    # embedding vector (and position vector). This is not *strictly* necessary
    # since the [SEP] token unambiguously separates the sequences, but it makes
    # it easier for the model to learn the concept of sequences.
    #
    # For classification tasks, the first vector (corresponding to [CLS]) is
    # used as the "sentence vector". Note that this only makes sense because
    # the entire model is fine-tuned.
    tokens = []
    text_type_ids = []
    tokens.append("[CLS]")
    text_type_ids.append(0)
    for token in tokens_a:
        tokens.append(token)
        text_type_ids.append(0)
    tokens.append("[SEP]")
    text_type_ids.append(0)

    if tokens_b:
        for token in tokens_b:
            tokens.append(token)
            text_type_ids.append(1)
        tokens.append("[SEP]")
        text_type_ids.append(1)

    token_ids = tokenizer.convert_tokens_to_ids(tokens)
    position_ids = list(range(len(token_ids)))

    if self.label_map:
        label_id = self.label_map[example.label]
    else:
        label_id = example.label

    Record = namedtuple(
        'Record',
        ['token_ids', 'text_type_ids', 'position_ids', 'label_id', 'qid'])

    qid = None
    if "qid" in example._fields:
        qid = example.qid

    record = Record(
        token_ids=token_ids,
        text_type_ids=text_type_ids,
        position_ids=position_ids,
        label_id=label_id,
        qid=qid)
    return record
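As a quick sanity check (illustrative only: `reader` stands for an instantiated classification reader from the repo, and `full_tokenize` is the FullTokenizer built earlier), converting the example shown above yields a record whose fields are exactly the model inputs described at the start:

record = reader._convert_example_to_record(
    examples[0],             # Example(label='1', text_a='去 逛街 咯')
    max_seq_length=128,
    tokenizer=full_tokenize)

print(record.token_ids)      # [CLS] + tokens + [SEP], as vocabulary ids
print(record.text_type_ids)  # all zeros for a single sentence
print(record.position_ids)   # 0, 1, 2, ...
print(record.label_id)       # '1' (or its mapped id if a label_map is set)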
3. Padding the data and generating batches
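Finally, the reader groups records into batches and pads token_ids, text_type_ids and position_ids to the longest sequence in the batch. The repo has its own padding helper; the sketch below uses made-up names and only shows what this step does:

import numpy as np

def pad_batch_records(batch_records, pad_id=0):
    """Pad each field to the longest sequence in the batch."""
    max_len = max(len(r.token_ids) for r in batch_records)

    def pad(seqs):
        # right-pad every sequence with pad_id up to max_len
        return np.array(
            [list(s) + [pad_id] * (max_len - len(s)) for s in seqs],
            dtype='int64')

    token_ids = pad([r.token_ids for r in batch_records])
    text_type_ids = pad([r.text_type_ids for r in batch_records])
    position_ids = pad([r.position_ids for r in batch_records])
    labels = np.array([int(r.label_id) for r in batch_records], dtype='int64')
    return token_ids, text_type_ids, position_ids, labels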