ERNIE Masking Implementation


import numpy as np


def mask(batch_tokens,
         seg_labels,
         mask_word_tags,
         total_token_num,
         vocab_size,
         CLS=1,
         SEP=2,
         MASK=3):
    '''
    :param batch_tokens: token ids of the sentences in a batch, [batch_size, seq_len]
    :param seg_labels: word-boundary labels: 0 marks the first token of a word,
        1 marks a non-initial token, -1 marks a separator ([CLS]/[SEP]), [batch_size, seq_len]
    :param mask_word_tags: a list of booleans, True for word-level masking and
        False for char-level masking, [batch_size]
    :param total_token_num: total number of tokens in the batch, batch_size * seq_len
    :param vocab_size: vocabulary size

    Add masks to batch_tokens and return the masked batch_tokens, mask_label and mask_pos.
    Note: mask_pos indexes into batch_tokens after padding.
    '''
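    # An illustrative (hypothetical) seg_labels row, not taken from the source:
    # for the tokenized sentence [CLS] 哈 尔 滨 是 冰 城 [SEP] it could be
    # [-1, 0, 1, 1, 0, 0, 1, -1] -- "哈尔滨" and "冰城" are multi-char words,
    # "是" is a single-char word, and [CLS]/[SEP] are separators.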

    max_len = max([len(sent) for sent in batch_tokens])
    mask_label = []
    mask_pos = []
    # random probability drawn for every token in the batch
    prob_mask = np.random.rand(total_token_num)
    # random replacement token ids; Note: the first token is [CLS], so low=1
    replace_ids = np.random.randint(1, high=vocab_size, size=total_token_num)
    # length of the previously processed sentence (used to advance prob_index)
    pre_sent_len = 0
    # number of tokens accumulated before the current sentence
    prob_index = 0

    for sent_index, sent in enumerate(batch_tokens):  # iterate over the sentences
        mask_flag = False
        mask_word = mask_word_tags[sent_index]
        prob_index += pre_sent_len
        if mask_word:  # word-level masking (the phrase/entity-level masking described in the paper)
            beg = 0  # start index of the current word
            for token_index, token in enumerate(sent):  # iterate over the tokens of the sentence
                seg_label = seg_labels[sent_index][token_index]
                if seg_label == 1:  # non-initial token of a word, skip
                    continue
                if beg == 0:  # still looking for the start of a word
                    if seg_label != -1:
                        beg = token_index
                    continue
                prob = prob_mask[prob_index + beg]
                # mask the whole word with 15% probability
                if prob > 0.15:
                    pass
                else:
                    for index in range(beg, token_index):  # one word, i.e. a 0,1,1-style span
                        prob = prob_mask[prob_index + index]
                        base_prob = 1.0
                        if index == beg:
                            # the first token's prob is already <= 0.15,
                            # so its thresholds are scaled accordingly
                            base_prob = 0.15
                        # replace with [MASK] with 80% probability
                        if base_prob * 0.2 < prob <= base_prob:
                            mask_label.append(sent[index])
                            sent[index] = MASK
                            mask_flag = True
                            mask_pos.append(sent_index * max_len + index)
                        # replace with a random token with 10% probability
                        elif base_prob * 0.1 < prob <= base_prob * 0.2:
                            mask_label.append(sent[index])
                            sent[index] = replace_ids[prob_index + index]
                            mask_flag = True
                            mask_pos.append(sent_index * max_len + index)
                        # keep the original token with 10% probability
                        else:
                            mask_label.append(sent[index])
                            mask_pos.append(sent_index * max_len + index)
                if seg_label == -1:
                    beg = 0
                else:
                    beg = token_index
        else:  # char-level masking
            for token_index, token in enumerate(sent):
                prob = prob_mask[prob_index + token_index]
                if prob > 0.15:
                    continue
                elif 0.03 < prob <= 0.15:
                    # replace with [MASK] (80% of the selected 15%)
                    if token != SEP and token != CLS:
                        mask_label.append(sent[token_index])
                        sent[token_index] = MASK
                        mask_flag = True
                        mask_pos.append(sent_index * max_len + token_index)
                elif 0.015 < prob <= 0.03:
                    # replace with a random token (10% of the selected 15%)
                    if token != SEP and token != CLS:
                        mask_label.append(sent[token_index])
                        sent[token_index] = replace_ids[prob_index + token_index]
                        mask_flag = True
                        mask_pos.append(sent_index * max_len + token_index)
                else:
                    # keep the original token (10% of the selected 15%)
                    if token != SEP and token != CLS:
                        mask_label.append(sent[token_index])
                        mask_pos.append(sent_index * max_len + token_index)
        pre_sent_len = len(sent)

    mask_label = np.array(mask_label).astype("int64").reshape([-1, 1])
    mask_pos = np.array(mask_pos).astype("int64").reshape([-1, 1])
    return batch_tokens, mask_label, mask_pos
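
Below is a minimal usage sketch (all values are hypothetical and only for illustration): two equally long sentences with made-up token ids, where the first sentence uses word-level masking and the second char-level masking. With the thresholds above, roughly 15% of positions are selected for prediction; a selected position becomes [MASK] about 80% of the time, a random id about 10%, and is kept unchanged (but still predicted) about 10%.

if __name__ == "__main__":
    np.random.seed(0)
    # two sentences of the form [CLS] t1 t2 t3 t4 [SEP], using the default ids CLS=1, SEP=2
    batch_tokens = [[1, 101, 102, 103, 104, 2],
                    [1, 201, 202, 203, 204, 2]]
    # 0 = word start, 1 = non-initial token, -1 = separator
    seg_labels = [[-1, 0, 1, 0, 1, -1],
                  [-1, 0, 0, 0, 0, -1]]
    mask_word_tags = [True, False]  # word-level for the first sentence, char-level for the second
    total_token_num = sum(len(sent) for sent in batch_tokens)
    out, mask_label, mask_pos = mask(batch_tokens, seg_labels, mask_word_tags,
                                     total_token_num, vocab_size=1000)
    print(out)         # token ids with some positions replaced by MASK or random ids (in place)
    print(mask_label)  # original ids of the selected positions, shape [-1, 1]
    print(mask_pos)    # flattened positions: sent_index * max_len + token_index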
