This project converts natural-language questions into MySQL statements.
Link: https://github.com/ZhuiyiTechnology/nl2sql_baseline
The first Chinese NL2SQL challenge: https://tianchi.aliyun.com/competition/entrance/231716/introduction?spm=5176.12281949.1003.8.6f802448KX0Rys
Every target query follows this template:

select $agg{0:"", 1:"AVG", 2:"MAX", 3:"MIN", 4:"COUNT", 5:"SUM"}
$column
where
$column $op{0:">", 1:"<", 2:"==", 3:"!="} $value
conn_sql_dict{0:"", 1:"and", 2:"or"}
$column $op{0:">", 1:"<", 2:"==", 3:"!="} $value
...
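To make the template concrete, here is a minimal sketch that renders the dataset's label format into a SQL string. `render_sql`, the placeholder table name `t`, and the sample header are illustrative, not part of the baseline:

```python
# Sketch: render the dataset's label dicts into a SQL string.
# render_sql and the sample table are illustrative, not from the baseline.
agg_sql_dict = {0: "", 1: "AVG", 2: "MAX", 3: "MIN", 4: "COUNT", 5: "SUM"}
op_sql_dict = {0: ">", 1: "<", 2: "==", 3: "!="}
conn_sql_dict = {0: "", 1: "and", 2: "or"}

def render_sql(sql, header):
    sel = ", ".join(
        f"{agg_sql_dict[a]}({header[c]})" if a != 0 else header[c]
        for c, a in zip(sql["sel"], sql["agg"])
    )
    conds = f" {conn_sql_dict[sql['cond_conn_op']]} ".join(
        f"{header[c]} {op_sql_dict[o]} {v!r}" for c, o, v in sql["conds"]
    )
    return f"SELECT {sel} FROM t WHERE {conds}" if conds else f"SELECT {sel} FROM t"

sql = {"sel": [2], "agg": [5], "cond_conn_op": 2,
       "conds": [[0, 2, "大黄蜂"], [0, 2, "密室逃生"]]}
header = ["影片名称", "周票房(万)", "票房占比(%)", "场均人次"]
print(render_sql(sql, header))
# SELECT SUM(票房占比(%)) FROM t WHERE 影片名称 == '大黄蜂' or 影片名称 == '密室逃生'
```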
The task can be decomposed into four sub-tasks:
The baseline project's results are:
As the results show, most sub-tasks already score well; the remaining weak spots are W-Col (the WHERE column), W-Op (the WHERE operator), and W-Val (the WHERE value).
Breaking the model down into its four parts:
TODO: detailed implementation log
Inputs: the question text and the table columns.
Outputs:
cond_num_score: the number of conditions
cond_col_score: the condition columns
cond_op_score: the condition operators
cond_str_score: the condition values
In short, the question text and the column names are embedded, and the model outputs each of the scores above.
Oddly, the baseline does not use the column-name embeddings here, even though every other decoder does. Why? Could this be a reason for the low accuracy?
The decoding step is implemented as follows:
sel_num = np.argmax(sel_num_score[b])
max_col_idxes = np.argsort(-sel_score[b])[:sel_num]
# find the most-probable columns' indexes
max_agg_idxes = np.argsort(-agg_score[b])[:sel_num]
cur_query['sel'].extend([int(i) for i in max_col_idxes])
cur_query['agg'].extend([i[0] for i in max_agg_idxes])
cur_query['cond_conn_op'] = np.argmax(where_rela_score[b])
# unpack the four condition scores
cond_num_score, cond_col_score, cond_op_score, cond_str_score = \
    [x.data.cpu().numpy() for x in cond_score]
# pick the number of WHERE conditions
cond_num = np.argmax(cond_num_score[b])
all_toks = ['<BEG>'] + q[b] + ['<END>']
max_idxes = np.argsort(-cond_col_score[b])[:cond_num]
for idx in range(cond_num):
    cur_cond = []
    cur_cond.append(max_idxes[idx])  # where-col
    cur_cond.append(np.argmax(cond_op_score[b][idx]))  # where-op
    cur_cond_str_toks = []
    for str_score in cond_str_score[b][idx]:
        str_tok = np.argmax(str_score[:len(all_toks)])
        str_val = all_toks[str_tok]
        if str_val == '<END>':
            break
        cur_cond_str_toks.append(str_val)
    cur_cond.append(merge_tokens(cur_cond_str_toks, raw_q[b]))  # where-val
    cur_query['conds'].append(cur_cond)
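`merge_tokens` comes from the SQLNet codebase and stitches the predicted value tokens back into text that matches the raw question. A simplified stand-in (not the original implementation) might look like:

```python
def merge_tokens(tok_list, raw_question):
    """Simplified stand-in for SQLNet's merge_tokens: join the predicted
    tokens and, when possible, snap the result to a substring of the raw
    question so the value matches the original text exactly."""
    merged = "".join(tok_list)  # char-based tokens, so a plain join
    if merged and merged in raw_question:
        start = raw_question.index(merged)
        return raw_question[start:start + len(merged)]
    return merged

print(merge_tokens(["大", "黄", "蜂"], "二零一九年第四周大黄蜂的票房"))
# 大黄蜂
```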
The overall idea is to embed the column names and the question and predict the condition count from them. A few shapes are worth noting (from a pdb session):
col_inp_var.shape
torch.Size([599, 16, 300])
p col_name_len.shape
(599,)
p col_len.shape
(64,)
self.cond_num_name_enc
LSTM(300, 50, num_layers=2, batch_first=True, dropout=0.3, bidirectional=True)
e_num_col.shape
torch.Size([64, 17, 100])
p col_num.shape
(64,)
p self.cond_num_col_att
Linear(in_features=100, out_features=1, bias=True)
p num_col_att_val.shape
torch.Size([64, 17])
p num_col_att.shape
torch.Size([64, 17])
p num_col_att.unsqueeze(2).shape
torch.Size([64, 17, 1])
(e_num_col * num_col_att.unsqueeze(2)).shape
torch.Size([64, 17, 100])
(e_num_col * num_col_att.unsqueeze(2)).sum(1).shape
torch.Size([64, 100])
p K_num_col.shape
torch.Size([64, 100])
p cond_num_h1.shape
torch.Size([4, 64, 50])
p cond_num_h2.shape
torch.Size([4, 64, 50])
p self.cond_num_lstm
LSTM(300, 50, num_layers=2, batch_first=True, dropout=0.3, bidirectional=True)
p h_num_enc.shape
torch.Size([64, 57, 100])
p self.cond_num_att
Linear(in_features=100, out_features=1, bias=True)
p self.cond_num_att(h_num_enc).shape
torch.Size([64, 57, 1])
p self.cond_num_att(h_num_enc).squeeze().shape
torch.Size([64, 57])
p num_att_val.shape
torch.Size([64, 57])
p num_att.shape
torch.Size([64, 57])
p num_att.unsqueeze(2).shape
torch.Size([64, 57, 1])
p h_num_enc.shape
torch.Size([64, 57, 100])
p num_att.unsqueeze(2).expand_as(h_num_enc).shape
torch.Size([64, 57, 100])
p (h_num_enc * num_att.unsqueeze(2).expand_as(h_num_enc)).shape
torch.Size([64, 57, 100])
p (h_num_enc * num_att.unsqueeze(2).expand_as(h_num_enc)).sum(1).shape
torch.Size([64, 100])
p K_cond_num.shape
torch.Size([64, 100])
p self.cond_num_col2hid1
Linear(in_features=100, out_features=200, bias=True)
p self.cond_num_col2hid1(K_num_col).shape
torch.Size([64, 200])
p self.cond_num_col2hid1(K_num_col).view(B, 4, self.N_h/2).shape
torch.Size([64, 4, 50])
p self.cond_num_col2hid1(K_num_col).view(B, 4, self.N_h/2).transpose(0, 1).shape
torch.Size([4, 64, 50])
p self.cond_num_col2hid1(K_num_col).view(B, 4, self.N_h/2).transpose(0, 1).contiguous().shape
torch.Size([4, 64, 50])
p self.cond_num_out
Sequential(
(0): Linear(in_features=100, out_features=100, bias=True)
(1): Tanh()
(2): Linear(in_features=100, out_features=5, bias=True)
)
p cond_num_score.shape
torch.Size([64, 5])
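The shape walk-through above can be reproduced with a toy version of the column-attention pooling; the tensors are random stand-ins and the module names simply mirror the trace:

```python
import torch
import torch.nn as nn

B, max_col, N_h = 64, 17, 100  # batch size, max #columns, hidden size

# e_num_col: per-column encodings, as in the trace: [64, 17, 100]
e_num_col = torch.randn(B, max_col, N_h)

cond_num_col_att = nn.Linear(N_h, 1)               # Linear(100 -> 1)
att_val = cond_num_col_att(e_num_col).squeeze(-1)  # [64, 17]
att = torch.softmax(att_val, dim=1)                # [64, 17]

# attention-weighted sum over columns -> one vector per example
K_num_col = (e_num_col * att.unsqueeze(2)).sum(1)  # [64, 100]

# K_num_col is projected and reshaped into the LSTM's initial hidden state
cond_num_col2hid1 = nn.Linear(N_h, 2 * N_h)        # Linear(100 -> 200)
h0 = (cond_num_col2hid1(K_num_col)
      .view(B, 4, N_h // 2)     # [64, 4, 50]
      .transpose(0, 1)          # [4, 64, 50]
      .contiguous())            # 4 = 2 layers x 2 directions
```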
The remaining sub-modules are similar to the above.
q_seq: char-based sequence of the question
gt_sel_num: number of selected columns
col_seq: char-based column names
col_num: number of headers in one table
ans_seq: essentially everything in the SQL label except the condition values
ans_seq.append(
    (
        len(sql['sql']['agg']),                    # number of selected columns
        sql['sql']['sel'],                         # selected columns
        sql['sql']['agg'],                         # aggregate per selected column, '0' means none
        conds_num,                                 # number of conditions
        tuple(x[0] for x in sql['sql']['conds']),  # condition columns
        tuple(x[1] for x in sql['sql']['conds']),  # condition operators
        sql['sql']['cond_conn_op'],                # connector between conditions
    ))
gt_cond_seq is just the original conds: each entry is [condition column, operator, value].
gt_where_seq: the original question is wrapped with <BEG> at the front and <END> at the back; if the condition value can be located in the question, the sequence is [0, answer_start, ..., answer_end, sentence_length]; if not, it is just [0, sentence_length].
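A sketch of that construction, assuming char-level tokens (the function name and details are reconstructed for illustration, not copied from the baseline):

```python
def get_gt_where_seq(question_chars, conds):
    """Build gt_where_seq. The question is conceptually wrapped as
    [<BEG>] + chars + [<END>], so index 0 is <BEG> and index len+1 is <END>.
    For each condition value: if it occurs in the question, emit
    [0, value char indexes..., len+1]; otherwise just [0, len+1]."""
    end = len(question_chars) + 1
    text = "".join(question_chars)
    seqs = []
    for _col, _op, val in conds:
        pos = text.find(str(val))
        if pos >= 0:
            # +1 shifts past the <BEG> token at index 0
            seqs.append([0] + list(range(pos + 1, pos + 1 + len(str(val)))) + [end])
        else:
            seqs.append([0, end])
    return seqs

q = list("大黄蜂的票房是多少")
print(get_gt_where_seq(q, [[0, 2, "大黄蜂"]]))
# [[0, 1, 2, 3, 10]]
```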
gt_sel_seq = [x[1] for x in ans_seq], i.e. just the ids of the selected columns.
Appendix:
{
    "table_id": "a1b2c3d4",    # id of the corresponding table
    "question": "世茂茂悦府新盘容积率大于1,请问它的套均面积是多少?",    # natural-language question
    "sql": {    # ground-truth SQL
        "sel": [7],    # columns selected by the query
        "agg": [0],    # aggregate applied to each selected column, '0' means none
        "cond_conn_op": 0,    # connector between conditions
        "conds": [
            [1, 2, "世茂茂悦府"],    # [condition column, operator, value]: col_1 == "世茂茂悦府"
            [6, 0, 1]
        ]
    }
}
# The dictionaries used in the SQL labels are:
op_sql_dict = {0:">", 1:"<", 2:"==", 3:"!="}
agg_sql_dict = {0:"", 1:"AVG", 2:"MAX", 3:"MIN", 4:"COUNT", 5:"SUM"}
conn_sql_dict = {0:"", 1:"and", 2:"or"}
# q_seq: char-based sequence of question
# gt_sel_num: number of selected columns and aggregation functions
# col_seq: char-based column name
# col_num: number of headers in one table
# ans_seq: (sel, number of conds, sel list in conds, op list in conds)
# gt_cond_seq: ground truth of conds
question: 二零一九年第四周大黄蜂和密室逃生这两部影片的票房总占比是多少呀
sql_string: {"agg": " SUM", "sel": " 票房占比(%)", "cond_conn_op": "or", "conds": "影片名称==大黄蜂影片名称==密室逃生"}
header: ['影片名称', '周票房(万)', '票房占比(%)', '场均人次']
id_train_tabel: {"rows": [["死侍2:我爱我家", 10637.3, 25.8, 5.0], ["白蛇:缘起", 10503.8, 25.4, 7.0], ["大黄蜂", 6426.6, 15.6, 6.0], ["密室逃生", 5841.4, 14.2, 6.0], ["“大”人物", 3322.9, 8.1, 5.0], ["家和万事惊", 635.2, 1.5, 25.0], ["钢铁飞龙之奥特曼崛起", 595.5, 1.4, 3.0], ["海王", 500.3, 1.2, 5.0], ["一条狗的回家路", 360.0, 0.9, 4.0], ["掠食城市", 356.6, 0.9, 3.0]], "name": "Table_4d29d0513aaa11e9b911f40f24344a08", "title": "表3:2019年第4周(2019.01.28 - 2019.02.03)全国电影票房TOP10", "header": ["影片名称", "周票房(万)", "票房占比(%)", "场均人次"], "common": "资料来源:艺恩电影智库,光大证券研究所", "id": "4d29d0513aaa11e9b911f40f24344a08", "types": ["text", "real", "real", "real"]}
question: 你好,我要查询一下涨跌幅超过20%的证券名称以及证券代码,谢谢
sql_string: {"agg": " ", "sel": " 证券名称 证券代码", "cond_conn_op": "", "conds": "涨跌幅(%)>20"}
header: ['证券代码', '证券名称', '涨跌幅(%)']
id_train_tabel: {"rows": [["300010.SZ", "立思辰", 13.13], ["300079.SZ", "数码科技", 5.56], ["002602.SZ", "世纪华通", 5.3], ["002640.SZ", "跨境通", 5.25], ["002555.SZ", "三七互娱", 5.19], ["600652.SH", "游久游戏", 23.31], ["002354.SZ", "天神娱乐", 23.21], ["601811.SH", "新华文轩", 21.15], ["300148.SZ", "天舟文化", 20.47], ["000673.SZ", "当代东方", 20.0]], "name": "Table_4d24aa113aaa11e9baa9f40f24344a08", "title": "图表1. A股传媒板块本周涨跌幅排行(2019.01.28-2019.02.01)", "header": ["证券代码", "证券名称", "涨跌幅(%)"], "common": "资料来源:万得,中银国际证券", "id": "4d24aa113aaa11e9baa9f40f24344a08", "types": ["text", "text", "real"]}
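As a sanity check, the second example's labels can be executed directly against (a subset of) its table rows; the dict-style labels and the `ops` table below are my own framing of the dataset format, not baseline code:

```python
# Sketch: execute the second example's labels against a subset of its rows.
header = ["证券代码", "证券名称", "涨跌幅(%)"]
rows = [["300010.SZ", "立思辰", 13.13], ["600652.SH", "游久游戏", 23.31],
        ["002354.SZ", "天神娱乐", 23.21], ["000673.SZ", "当代东方", 20.0]]

# labels in dict form: sel 证券名称/证券代码, no agg, cond 涨跌幅(%) > 20
sel, conds = [1, 0], [[2, 0, 20]]  # op 0 is ">" in op_sql_dict

ops = {0: lambda a, b: a > b, 1: lambda a, b: a < b,
       2: lambda a, b: a == b, 3: lambda a, b: a != b}

result = [[row[c] for c in sel]
          for row in rows
          if all(ops[o](row[c], v) for c, o, v in conds)]
print(result)
# [['游久游戏', '600652.SH'], ['天神娱乐', '002354.SZ']]
```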