import numpy as np


def mask(batch_tokens,
         seg_labels,
         mask_word_tags,
         total_token_num,
         vocab_size,
         CLS=1,
         SEP=2,
         MASK=3):
    '''
    :param batch_tokens: token ids of each sentence in the batch, [batch_size, seq_len]
    :param seg_labels: word-boundary labels: 0 marks the first token of a word, 1 a non-initial token, -1 a separator ([CLS]/[SEP]), [batch_size, seq_len]
    :param mask_word_tags: list of booleans, one per sentence, choosing word-level vs. char-level masking, [batch_size]
    :param total_token_num: total number of tokens, batch_size * seq_len
    :param vocab_size: vocabulary size
    Add masks to batch_tokens and return out, mask_label, mask_pos.
    Note: mask_pos refers to positions in the padded batch_tokens.
    '''
    max_len = max([len(sent) for sent in batch_tokens])
    mask_label = []
    mask_pos = []
    # replacement probability for every token, drawn once for the whole batch
    prob_mask = np.random.rand(total_token_num)
    # random replacement token ids; Note: the first token is [CLS], so low=1
    replace_ids = np.random.randint(1, high=vocab_size, size=total_token_num)
    # length of the previously processed sentence
    pre_sent_len = 0
    # cumulative token count so far (offset into prob_mask / replace_ids)
    prob_index = 0
    for sent_index, sent in enumerate(batch_tokens):  # iterate over the sentences
        mask_flag = False
        mask_word = mask_word_tags[sent_index]
        prob_index += pre_sent_len
        if mask_word:  # word-level masking (the phrase/entity-level masking described in the paper)
            beg = 0  # start index of the current word
            for token_index, token in enumerate(sent):  # iterate over the tokens of the sentence
                seg_label = seg_labels[sent_index][token_index]
                if seg_label == 1:  # non-initial token of a word, skip
                    continue
                if beg == 0:  # still looking for the start of a word
                    if seg_label != -1:
                        beg = token_index
                    continue
                prob = prob_mask[prob_index + beg]
                # replace the whole word with probability 15%
                if prob > 0.15:
                    pass
                else:
                    for index in range(beg, token_index):  # one word, or a 0,1,1-style phrase
                        prob = prob_mask[prob_index + index]
                        base_prob = 1.0
                        if index == beg:
                            base_prob = 0.15
                        # replace with [MASK] with probability 80%
                        if base_prob * 0.2 < prob <= base_prob:
                            mask_label.append(sent[index])
                            sent[index] = MASK
                            mask_flag = True
                            mask_pos.append(sent_index * max_len + index)
                        # replace with a random token with probability 10%
                        elif base_prob * 0.1 < prob <= base_prob * 0.2:
                            mask_label.append(sent[index])
                            sent[index] = replace_ids[prob_index + index]
                            mask_flag = True
                            mask_pos.append(sent_index * max_len + index)
                        # keep the original token with probability 10%
                        else:
                            mask_label.append(sent[index])
                            mask_pos.append(sent_index * max_len + index)
                if seg_label == -1:
                    beg = 0
                else:
                    beg = token_index
        else:  # mask char: character-level replacement
            for token_index, token in enumerate(sent):
                prob = prob_mask[prob_index + token_index]
                if prob > 0.15:
                    continue
                elif 0.03 < prob <= 0.15:
                    # mask (80% of the selected 15%)
                    if token != SEP and token != CLS:
                        mask_label.append(sent[token_index])
                        sent[token_index] = MASK
                        mask_flag = True
                        mask_pos.append(sent_index * max_len + token_index)
                elif 0.015 < prob <= 0.03:
                    # random replace (10% of the selected 15%)
                    if token != SEP and token != CLS:
                        mask_label.append(sent[token_index])
                        sent[token_index] = replace_ids[prob_index +
                                                        token_index]
                        mask_flag = True
                        mask_pos.append(sent_index * max_len + token_index)
                else:
                    # keep the original token (10% of the selected 15%)
                    if token != SEP and token != CLS:
                        mask_label.append(sent[token_index])
                        mask_pos.append(sent_index * max_len + token_index)
        pre_sent_len = len(sent)
    mask_label = np.array(mask_label).astype("int64").reshape([-1, 1])
    mask_pos = np.array(mask_pos).astype("int64").reshape([-1, 1])
    return batch_tokens, mask_label, mask_pos
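

# A minimal usage sketch (not part of the original code): the toy batch below
# is invented here only to show the expected input shapes and the calling
# convention, assuming the default special-token ids CLS=1, SEP=2, MASK=3.
if __name__ == "__main__":
    batch_tokens = [
        [1, 15, 16, 17, 18, 2],   # [CLS] w1 w1 w2 w2 [SEP]
        [1, 25, 26, 27, 28, 2],
    ]
    seg_labels = [
        [-1, 0, 1, 0, 1, -1],     # -1 for [CLS]/[SEP], 0 word start, 1 continuation
        [-1, 0, 1, 0, 1, -1],
    ]
    mask_word_tags = [True, False]  # sentence 0: word-level, sentence 1: char-level
    total_token_num = sum(len(sent) for sent in batch_tokens)

    out, mask_label, mask_pos = mask(batch_tokens, seg_labels, mask_word_tags,
                                     total_token_num, vocab_size=100)
    print(out)         # tokens with roughly 15% of words/chars masked or replaced in place
    print(mask_label)  # original ids of the selected positions, shape [num_masked, 1]
    print(mask_pos)    # flattened positions: sent_index * max_len + token_index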