TableQA -- Tapas模型介绍

TableQA – TAPAS模型介绍

TAPAS 是2020年谷歌在ACL中的TAPAS: Weakly Supervised Table Parsing via Pre-training提出来的。
亮点
(1) TAPAS模型在tableqa问题中不需要生成逻辑表达式。
(2) TAPAS是采用弱监督方式训练,分别得到tableqa中的cell值和聚合函数

TAPAS 模型

TAPAS 是基于BERT encoder以及额外的表格位置embedding特征,模型加了两个分类层,分别选择表cell和聚合函数类型,这里将表拉平为词序列且进行tokenizer成token并且与问题拼接作为模型输入。整体模型结构如下图:
TableQA -- Tapas模型介绍_第1张图片

cell selection

cell selection 是线性分类层+伯努利分布选择表cell。这里是取bert 模型的last hidden state 向量做一个线性分类层得到每一个token的logits,table 单元cell的logits为单元cell中的token的logits值的平均值,这个线性层输出的概率是选择表cell的概率。在模型推理的时候,设置选择表cell的阈值,大于阈值表示选该cell。

Aggregation operator prediction

Aggregation operator prediction是得到聚合函数,例如MAX,MIN,AVG,COUNT等sql中常用的聚合函数。这里取bert模型中的last hidden state向量中的CLS向量做softmax得到聚合函数,根据聚合函数与cell selection的结果进行计算得到最终得tableqa得值。

模型输入特征

模型输入特征和bert模型输入一样,例如input_ids, attention_mask,token_type_ids,在TAPAS中加了Additional embeddings,
包括Position ID,Segment ID,Column / Row ID,Rank ID,Previous Answer等。下面详细介绍这些Additional embeddings。
Previous Answer:标记问题是否是上一个问题或者答案的一部分,这个是考虑了上下文,如果包含部分标记为1

a. 表cell 这里会根据最大长度进行截断,统计出每一个表cell的最大长度和最大行号,从而将表按照字符集别tokens 拉平,记录每一个token的行号,列号和token在cell tokens中的index 表行/列特征:按照最大长度截断后的行列编码拉平,这里列index从1开始,行加入了列名一行,从而构建了row_ids, columns_ids
b. position id 这个是数据自带,如果position id 非零表示该数据的答案与上一条数据的相关
c. segement id:将问题和相关的表拼接起来,问题标记为0,表标记为1
rank:关于float类型和日期型数据进行处理
d. column_ranks:统计每一列中cell中的数字以及行号,数字按照从小到达排序,确定数字在拉平的行列中的index,column_ranks在index位置处记录数字排序后的位置
e. inv_column_ranks:统计每一列中cell中的数字以及行号,数字按照从小到达排序,确定数字在拉平的行列中的index,inv_column_ranks在index位置处记录提取到的数字集合数量与数字排序后的位置的差
f. numeric_relations: 问题中的数字与表中的数字进行比较,这里比较结果统计为小于,大于,等于,将这个结果通过行index和列index得到拉平行列中的index,从而映射到numeric_relations相应的index中
g. labels 编码,答案来自于表,将答案的坐标转化为行列中的index,从而labels在index处标记为1
h. numeric_values_scale:统计表单元格中词出现大于1的次数,并将次数标记到相应的index位置处;是表的行列坐标在拉平的行列中的index的个数作为scale
i. numeric_values: 统计表中的数字,根据拉平的行列中的index将数字映射到numeric_values中的index处
备注
(1)这里的位置标记都是根据表的坐标位置对应到Column / Row ID获取到index,根据这些index进行标记
(2)transformers中给出了tapas模型,详细参见tokenizer_tapas.py
TableQA -- Tapas模型介绍_第2张图片

模型推理

模型输入已知问题和表,进行tokenizer,得到input_ids, attention_mask, token_type_ids,注意这里的token_type_ids是添加了Additional embeddings构成的,通过模型得到logits和agg_logits,通过伯努利分布以及row_ids和columns_ids得到表单元格的概率,选取大于阈值的单元格作为预测结果,根据agg 的结果最预测的单元格进行聚合运算。

详细参考transformers中的说明https://huggingface.co/docs/transformers/main/model_doc/tapas

思考

TAPAS适用于单张表的,问题涉及到多张表关联,这里如何修改?

实验结果

wikisql的数据上测试结果如下:acc = 0.7362
accuracy采用绝对相等,即答案和聚合函数都要预测正确记为预测正确

如有理解和书写不当,欢迎指正。

你可能感兴趣的:(NLP,自然语言处理,人工智能,nlp)