菜鸟笔记-DuReader阅读理解基线模型代码阅读笔记(三)—— 预处理

系列目录:

  1. 菜鸟笔记-DuReader阅读理解基线模型代码阅读笔记(一)——
    数据
  2. 菜鸟笔记-DuReader阅读理解基线模型代码阅读笔记(二)——
    介绍及分词
    未完待续 … …

DuReader数据集为每个用户问题提供了大量的文档,对于常见的RC模型这些文档太长。基线系统中对于训练集和校验集选择了与答案最相关的段落,在推理时,选择与问题最相关的段落推理。另外,由于基线系统选用的模型是抽取型模型,也就是需要从原文中寻找答案的模型,所以预处理代码选取了F1值最大的答案、段落词块对作为为答案用于训练,处理策略在utils/preprocess.py中实现。

precision、recall、f1-score

预处理用到了精确度、召回率、f1分数指标,preprocess.py 文件在 precision_recall_f1 函数中实现了precision、recall、f1-score,基本思路是将预测答案文本与参考答案文本进行比较,如果预测答案文本中单词与参考答案文本中单词相同的数量越多说明预测的越准确。其具体计算公式为:

common = Counter(prediction_tokens) & Counter(ground_truth_tokens)

p r e c i s i o n = n u m s a m e n u m p r e d i c t i o n precision= \frac {num_{same}} {num_{prediction}} precision=numpredictionnumsame
r e c a l l = n u m s a m e n u m t r u t h recall= \frac {num_{same}} {num_{truth}} recall=numtruthnumsame
f 1 s c o r e = 2 ∗ p r e c i s i o n ∗ r e c a l l p r e c i s i o n + r e c a l l f1_score= \frac {2*precision*recall} {precision+recall} f1score=precision+recall2precisionrecall
式中: n u m s a m e num_{same} numsame 为预测答案与参考答案相同词语数量, n u m p r e d i c t i o n num_{prediction} numprediction 为预测答案词语数量, n u m t r u t h num_{truth} numtruth 为参考答案词语数量。具体代码如下:·

def precision_recall_f1(prediction, ground_truth):
    """
    计算并返回精确度precision, 召回率recall 和 F1分数f1-score
    Args:
        prediction: 预测答案字符串或词语列表
        ground_truth: 参考答案字符串或词语列表
    Returns:
       返回精确度p, 召回率r, F1分数f1
    Raises:
        None
    """
    #判断输入时字符串还是列表,如果是字符串则将其切分为词语列表。
    if not isinstance(prediction, list):
        prediction_tokens = prediction.split()
    else:
        prediction_tokens = prediction
    if not isinstance(ground_truth, list):
        ground_truth_tokens = ground_truth.split()
    else:
        ground_truth_tokens = ground_truth
    #计算预测答案与参考答案相同词语数量
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    #计算p,r,f1
    if num_same == 0:
        return 0, 0, 0
    p = 1.0 * num_same / len(prediction_tokens)
    r = 1.0 * num_same / len(ground_truth_tokens)
    f1 = (2 * p * r) / (p + r)
    return p, r, f1

调用结果如下:
菜鸟笔记-DuReader阅读理解基线模型代码阅读笔记(三)—— 预处理_第1张图片
由图可以看出,预测一与参考答案相似度更高,准确度、召回率、f1分数也更高。

选取最相关段落

代码片段如下:

#遍历所有文档
for doc in sample['documents']:
    most_related_para = -1
    most_related_para_len = 999999
    max_related_score = 0
    # 遍历所有段落
    for p_idx, para_tokens in enumerate(doc['segmented_paragraphs']):
        if len(sample['segmented_answers']) > 0:
        	# 计算段落与多个参考答案的召回率并返回其中的最大值
            related_score = metric_max_over_ground_truths(recall,
                                                          para_tokens,
                                                          sample['segmented_answers'])
        else:
            continue
        #如果召回率最大,更新召回率与最相关段落,如果出现recall值相等的情况,取段落短的为最优段落。
        if related_score > max_related_score \
                or (related_score == max_related_score
                    and len(para_tokens) < most_related_para_len):
            most_related_para = p_idx
            most_related_para_len = len(para_tokens)
            max_related_score = related_score
    doc['most_related_para'] = most_related_para

基线系统中,选取最相关段落与截取伪答案在一个函数内,后面一块展示运行结果。

截取伪答案

由于基线系统选用的是抽取式模型,其输出结果需要从文档中截取答案。因此,训练之前需要首先根据参考答案从文档中截取伪答案来进行训练。伪答案的生成是对所选的最优段落,遍历词块,计算词块与答案集的F1值,得到F1值最大的词块,记录该词块所在的文档和该文档最优段落的起始和结束位置,作为伪答案范围。具体代码为:

#定义新增字段
sample['answer_docs'] = [] #答案所在文档索引
sample['answer_spans'] = [] #答案范围
sample['fake_answers'] = [] #伪答案
sample['match_scores'] = [] #伪答案分数

best_match_score = 0
best_match_d_idx, best_match_span = -1, [-1, -1]
best_fake_answer = None
#建立答案标记组
answer_tokens = set()
for segmented_answer in sample['segmented_answers']:
    answer_tokens = answer_tokens | set([token for token in segmented_answer])
#遍历文档
for d_idx, doc in enumerate(sample['documents']):
    if not doc['is_selected']:
        continue
    if doc['most_related_para'] == -1:
        doc['most_related_para'] = 0
    most_related_para_tokens = doc['segmented_paragraphs'][doc['most_related_para']][:1000]
    #遍历最相关段落,寻找f1_score最高文本块作为伪答案
    for start_tidx in range(len(most_related_para_tokens)):
        if most_related_para_tokens[start_tidx] not in answer_tokens:
            continue
        for end_tidx in range(len(most_related_para_tokens) - 1, start_tidx - 1, -1):
            span_tokens = most_related_para_tokens[start_tidx: end_tidx + 1]
            if len(sample['segmented_answers']) > 0:
                match_score = metric_max_over_ground_truths(f1_score, span_tokens,
                                                            sample['segmented_answers'])
            else:
                match_score = 0
            if match_score == 0:
                break
            if match_score > best_match_score:
                best_match_d_idx = d_idx
                best_match_span = [start_tidx, end_tidx]
                best_match_score = match_score
                best_fake_answer = ''.join(span_tokens)
if best_match_score > 0:
    sample['answer_docs'].append(best_match_d_idx)
    sample['answer_spans'].append(best_match_span)
    sample['fake_answers'].append(best_fake_answer)
    sample['match_scores'].append(best_match_score)

以上完整代码见DuReader基线代码库utils/processor.py文件。

调用展示

使用find_fake_answer(sample)调用后结果如下图所示:
菜鸟笔记-DuReader阅读理解基线模型代码阅读笔记(三)—— 预处理_第2张图片
我们可以发现预处理后的数据相比于原始数据增加了分词结果,并且在每篇文档中增加了与问题最相关的段落“most_related_para”字段;由于目前阅读理解框架都是基于Span抽取的,因此增加了“fake_answers”字段,表示伪答案,“answer_docs”字段表示伪答案来自于哪一篇文档,“answer_spans”字段表示伪答案所在文档的起始终止索引信息,“match_scores”表示伪答案的f1_score评分值。
参考文献:
DuReader数据集
DuReader Baseline Systems (基线系统)
DuReader:百度大规模的中文机器阅读理解数据集
DuReader数据集之数据预处理代码解析

你可能感兴趣的:(NLP,#,机器阅读理解)