Reposted from https://kexue.fm/archives/6771, with my own annotations on the code added.
import json
from keras_bert import load_trained_model_from_checkpoint, Tokenizer
import codecs
from keras.layers import *
from keras.models import Model
import keras.backend as K
from keras.optimizers import Adam
from keras.callbacks import Callback
from tqdm import tqdm
import jieba
import editdistance
import re
import numpy as np
import tensorflow as tf
import keras
import pandas as pd
print(tf.__version__)
print(keras.__version__)
1.13.1
2.2.4
'''
Data format (one JSON example per line in the training file):
{
    "table_id": "a1b2c3d4",                  # id of the corresponding table
    "question": "世茂茂悦府新盘容积率大于1,请问它的套均面积是多少?",  # natural-language question
    "sql": {                                 # ground-truth SQL
        "sel": [7],                          # columns selected by the SQL
        "agg": [0],                          # aggregate function for each selected column, 0 means none
        "cond_conn_op": 0,                   # connector between conditions
        "conds": [
            [1, 2, "世茂茂悦府"],             # condition column, condition operator, condition value: col_1 == "世茂茂悦府"
            [6, 0, "1"]
        ]
    }
}
# The condition operators, aggregate functions and connectors are encoded as follows:
op_sql_dict = {0:">", 1:"<", 2:"==", 3:"!="}
agg_sql_dict = {0:"", 1:"AVG", 2:"MAX", 3:"MIN", 4:"COUNT", 5:"SUM"}
conn_sql_dict = {0:"", 1:"and", 2:"or"}
'''
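As a quick illustration (not part of the original pipeline), the label dict above can be rendered into a readable SQL string with the three mappings; the helper name to_sql and the table placeholder T below are made up for this sketch.
op_sql_dict = {0: ">", 1: "<", 2: "==", 3: "!="}
agg_sql_dict = {0: "", 1: "AVG", 2: "MAX", 3: "MIN", 4: "COUNT", 5: "SUM"}
conn_sql_dict = {0: "", 1: "and", 2: "or"}
def to_sql(sql, headers):
    # SELECT list: wrap a column in its aggregate function when agg != 0.
    sels = [('%s(%s)' % (agg_sql_dict[a], headers[c])) if a else headers[c]
            for c, a in zip(sql['sel'], sql['agg'])]
    # WHERE clause: each cond is (column, operator, value).
    conds = ['%s %s "%s"' % (headers[c], op_sql_dict[o], v) for c, o, v in sql['conds']]
    conn = ' %s ' % conn_sql_dict[sql['cond_conn_op']] if sql['cond_conn_op'] else ' '
    return 'SELECT %s FROM T WHERE %s' % (', '.join(sels), conn.join(conds))
# After train_data/train_tables are loaded below, for example:
# to_sql(train_data[0]['sql'], train_tables[train_data[0]['table_id']]['headers'])
# should produce something like:
# SELECT SUM(票房占比(%)) FROM T WHERE 影片名称 == "大黄蜂" or 影片名称 == "密室逃生"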
maxlen = 160
num_agg = 7  # the 6 aggregate functions in agg_sql_dict plus one "column not selected" class
num_op = 5   # the 4 condition operators in op_sql_dict plus one "not a condition value" class
num_cond_conn_op = 3
learning_rate = 5e-5
min_learning_rate = 1e-5
config_path = 'E:\\zym_test\\test\\nlp\\chinese_wwm_ext_L-12_H-768_A-12\\bert_config.json'
checkpoint_path = 'E:\\zym_test\\test\\nlp\\chinese_wwm_ext_L-12_H-768_A-12\\bert_model.ckpt'
dict_path = 'E:\\zym_test\\test\\nlp\\chinese_wwm_ext_L-12_H-768_A-12\\vocab.txt'
def read_data(data_file, table_file):
    data, tables = [], {}
    with open(data_file, encoding='UTF-8') as f:
        for l in f:
            data.append(json.loads(l))
    with open(table_file, encoding='UTF-8') as f:
        for l in f:
            l = json.loads(l)
            d = {}
            d['headers'] = l['header']
            d['header2id'] = {j: i for i, j in enumerate(d['headers'])}
            d['content'] = {}
            d['all_values'] = set()
            rows = np.array(l['rows'])
            for i, h in enumerate(d['headers']):
                d['content'][h] = set(rows[:, i])
                d['all_values'].update(d['content'][h])
            d['all_values'] = set([i for i in d['all_values'] if hasattr(i, '__len__')])
            tables[l['id']] = d
    return data, tables
train_data, train_tables = read_data('E:/zym_test/test/nlp/data/train/train.json','E:/zym_test/test/nlp/data/train/train.tables.json')
valid_data, valid_tables = read_data('E:/zym_test/test/nlp/data/val/val.json','E:/zym_test/test/nlp/data/val/val.tables.json')
test_data, test_tables = read_data('E:/zym_test/test/nlp/data/test/test.json','E:/zym_test/test/nlp/data/test/test.tables.json')
train_data[0:4]
[{'table_id': '4d29d0513aaa11e9b911f40f24344a08',
'question': '二零一九年第四周大黄蜂和密室逃生这两部影片的票房总占比是多少呀',
'sql': {'agg': [5],
'cond_conn_op': 2,
'sel': [2],
'conds': [[0, 2, '大黄蜂'], [0, 2, '密室逃生']]}},
{'table_id': '4d29d0513aaa11e9b911f40f24344a08',
'question': '你好,你知道今年第四周密室逃生,还有那部大黄蜂它们票房总的占比吗',
'sql': {'agg': [5],
'cond_conn_op': 2,
'sel': [2],
'conds': [[0, 2, '大黄蜂'], [0, 2, '密室逃生']]}},
{'table_id': '4d29d0513aaa11e9b911f40f24344a08',
'question': '我想你帮我查一下第四周大黄蜂,还有密室逃生这两部电影票房的占比加起来会是多少来着',
'sql': {'agg': [5],
'cond_conn_op': 2,
'sel': [2],
'conds': [[0, 2, '大黄蜂'], [0, 2, '密室逃生']]}},
{'table_id': '4d25e6403aaa11e9bdbbf40f24344a08',
'question': '有几家传媒公司16年为了融资收购其他资产而进行定增的呀',
'sql': {'agg': [4],
'cond_conn_op': 1,
'sel': [1],
'conds': [[6, 2, '2016'], [7, 2, '融资收购其他资产']]}}]
train_tables[ '4d29d0513aaa11e9b911f40f24344a08' ]
{'headers': ['影片名称', '周票房(万)', '票房占比(%)', '场均人次'],
'header2id': {'影片名称': 0, '周票房(万)': 1, '票房占比(%)': 2, '场均人次': 3},
'content': {'影片名称': {'“大”人物',
'一条狗的回家路',
'大黄蜂',
'家和万事惊',
'密室逃生',
'掠食城市',
'死侍2:我爱我家',
'海王',
'白蛇:缘起',
'钢铁飞龙之奥特曼崛起'},
'周票房(万)': {'10503.8',
'10637.3',
'3322.9',
'356.6',
'360.0',
'500.3',
'5841.4',
'595.5',
'635.2',
'6426.6'},
'票房占比(%)': {'0.9',
'1.2',
'1.4',
'1.5',
'14.2',
'15.6',
'25.4',
'25.8',
'8.1'},
'场均人次': {'25.0', '3.0', '4.0', '5.0', '6.0', '7.0'}},
'all_values': {'0.9',
'1.2',
'1.4',
'1.5',
'10503.8',
'10637.3',
'14.2',
'15.6',
'25.0',
'25.4',
'25.8',
'3.0',
'3322.9',
'356.6',
'360.0',
'4.0',
'5.0',
'500.3',
'5841.4',
'595.5',
'6.0',
'635.2',
'6426.6',
'7.0',
'8.1',
'“大”人物',
'一条狗的回家路',
'大黄蜂',
'家和万事惊',
'密室逃生',
'掠食城市',
'死侍2:我爱我家',
'海王',
'白蛇:缘起',
'钢铁飞龙之奥特曼崛起'}}
train_tables['4d25e6403aaa11e9bdbbf40f24344a08']
{'headers': ['证券代码',
'证券简称',
'最新收盘价',
'定增价除权后至今价格',
'增发价格',
'倒挂率',
'定增年度',
'增发目的'],
'header2id': {'证券代码': 0,
'证券简称': 1,
'最新收盘价': 2,
'定增价除权后至今价格': 3,
'增发价格': 4,
'倒挂率': 5,
'定增年度': 6,
'增发目的': 7},
'content': {'证券代码': {'300148.SZ', '300182.SZ', '300269.SZ'},
'证券简称': {'天舟文化', '捷成股份', '联建光电'},
'最新收盘价': {'4.09', '4.69', '5.48'},
'定增价除权后至今价格': {'11.16', '11.29', '12.48', '21.88', '23.07', '9.91'},
'增发价格': {'14.78', '15.09', '16.34', '16.988', '22.09', '23.3004'},
'倒挂率': {'23.75', '25.05', '36.65', '37.58', '41.26', '41.54'},
'定增年度': {'2016.0'},
'增发目的': {'融资收购其他资产', '配套融资'}},
'all_values': {'11.16',
'11.29',
'12.48',
'14.78',
'15.09',
'16.34',
'16.988',
'2016.0',
'21.88',
'22.09',
'23.07',
'23.3004',
'23.75',
'25.05',
'300148.SZ',
'300182.SZ',
'300269.SZ',
'36.65',
'37.58',
'4.09',
'4.69',
'41.26',
'41.54',
'5.48',
'9.91',
'天舟文化',
'捷成股份',
'联建光电',
'融资收购其他资产',
'配套融资'}}
token_dict = {}
with codecs.open(dict_path, 'r', 'utf8') as reader:
    for line in reader:
        token = line.strip()
        token_dict[token] = len(token_dict)
token_dict
{'[PAD]': 0,
 '[unused1]': 1,
 '[unused2]': 2,
 ...
 '[unused99]': 99,
 '[UNK]': 100,
 '[CLS]': 101,
 '[SEP]': 102,
 '[MASK]': 103,
 '!': 106,
 '"': 107,
 ...
 '0': 121,
 '1': 122,
 ...
 'a': 143,
 'b': 144,
 ...
 '一': 671,
 '丁': 672,
 ...}
class OurTokenizer(Tokenizer):
    def _tokenize(self, text):
        # Tokenize character by character: keep characters that are in the
        # vocab, map whitespace to [unused1] and everything else to [UNK],
        # so token positions stay aligned with the original question string.
        R = []
        for c in text:
            if c in self._token_dict:
                R.append(c)
            elif self._is_space(c):
                R.append('[unused1]')
            else:
                R.append('[UNK]')
        return R
tokenizer = OurTokenizer(token_dict)
tokenizer
<__main__.OurTokenizer at 0x2430e637908>
def seq_padding(X, padding=0, maxlen=None):
    if maxlen is None:
        L = [len(x) for x in X]
        ML = max(L)
    else:
        ML = maxlen
    return np.array([
        np.concatenate([x[:ML], [padding] * (ML - len(x))]) if len(x[:ML]) < ML else x for x in X
    ])
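A quick sanity check of seq_padding (illustrative, not in the original): every sequence in a batch is padded with zeros to the length of the longest one, or cut/padded to maxlen when it is given.
seq_padding([[101, 753, 102], [101, 753, 7439, 671, 102]])
# -> array of shape (2, 5); the first row ends with two padding zeros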
def most_similar(s, slist):
    """Find the most similar word in a candidate list
    (used when no exact match can be found).
    """
    if len(slist) == 0:
        return s
    scores = [editdistance.eval(s, t) for t in slist]
    return slist[np.argmin(scores)]
def most_similar_2(w, s):
    """Find the span of sentence s that is most similar to w,
    using word segmentation plus ngrams to pin down the span
    boundaries as precisely as possible.
    """
    sw = jieba.lcut(s)
    sl = list(sw)
    sl.extend([''.join(i) for i in zip(sw, sw[1:])])
    sl.extend([''.join(i) for i in zip(sw, sw[1:], sw[2:])])
    return most_similar(w, sl)
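Illustrative call (requires jieba and editdistance; the exact result depends on jieba's segmentation): a labelled value that does not appear verbatim in the question is snapped to the closest word/ngram span.
most_similar_2('密室逃生2', '你知道第四周密室逃生的票房占比吗')
# expected to return '密室逃生', the closest candidate by edit distance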
d=train_data[0]
x1, x2 = tokenizer.encode('二零一九年',"我是傻子")
print(x1)
print(x2)
print(len(x1))
print(len(x2))
[101, 753, 7439, 671, 736, 2399, 102, 2769, 3221, 1004, 2094, 102]
[0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
12
12
class data_generator:
    def __init__(self, data, tables, batch_size=32):
        self.data = data
        self.tables = tables
        self.batch_size = batch_size
        self.steps = len(self.data) // self.batch_size
        if len(self.data) % self.batch_size != 0:
            self.steps += 1
    def __len__(self):
        return self.steps
    def __iter__(self):
        while True:
            idxs = list(range(len(self.data)))
            np.random.shuffle(idxs)
            X1, X2, XM, H, HM, SEL, CONN, CSEL, COP = [], [], [], [], [], [], [], [], []
            for i in idxs:
                d = self.data[i]
                t = self.tables[d['table_id']]['headers']
                x1, x2 = tokenizer.encode(d['question'])
                xm = [0] + [1] * len(d['question']) + [0]  # mask over the question characters
                h = []
                for j in t:
                    _x1, _x2 = tokenizer.encode(j)
                    h.append(len(x1))  # position of this header's [CLS] token
                    x1.extend(_x1)
                    x2.extend(_x2)
                hm = [1] * len(h)  # header mask
                sel = []
                for j in range(len(h)):
                    if j in d['sql']['sel']:
                        j = d['sql']['sel'].index(j)
                        sel.append(d['sql']['agg'][j])
                    else:
                        sel.append(num_agg - 1)  # last class means "column not selected"
                conn = [d['sql']['cond_conn_op']]
                csel = np.zeros(len(d['question']) + 2, dtype='int32')  # condition column per question char
                cop = np.zeros(len(d['question']) + 2, dtype='int32') + num_op - 1  # default: not a condition value
                for j in d['sql']['conds']:
                    if j[2] not in d['question']:
                        j[2] = most_similar_2(j[2], d['question'])
                    if j[2] not in d['question']:
                        continue
                    k = d['question'].index(j[2])
                    csel[k + 1: k + 1 + len(j[2])] = j[0]
                    cop[k + 1: k + 1 + len(j[2])] = j[1]
                if len(x1) > maxlen:
                    continue
                X1.append(x1)
                X2.append(x2)
                XM.append(xm)
                H.append(h)
                HM.append(hm)
                SEL.append(sel)
                CONN.append(conn)
                CSEL.append(csel)
                COP.append(cop)
                if len(X1) == self.batch_size:
                    X1 = seq_padding(X1)
                    X2 = seq_padding(X2)
                    XM = seq_padding(XM, maxlen=X1.shape[1])
                    H = seq_padding(H)
                    HM = seq_padding(HM)
                    SEL = seq_padding(SEL)
                    CONN = seq_padding(CONN)
                    CSEL = seq_padding(CSEL, maxlen=X1.shape[1])
                    COP = seq_padding(COP, maxlen=X1.shape[1])
                    yield [X1, X2, XM, H, HM, SEL, CONN, CSEL, COP], None
                    X1, X2, XM, H, HM, SEL, CONN, CSEL, COP = [], [], [], [], [], [], [], [], []
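Optional smoke test (not in the original): pull one batch and check that the token ids, masks and per-character labels share the same padded length, and that the header-level arrays do as well.
_gen = iter(data_generator(train_data, train_tables, batch_size=4))
[bX1, bX2, bXM, bH, bHM, bSEL, bCONN, bCSEL, bCOP], _ = next(_gen)
assert bX1.shape == bX2.shape == bXM.shape == bCSEL.shape == bCOP.shape
assert bH.shape == bHM.shape == bSEL.shape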
def seq_gather(x):
    """seq has shape [None, seq_len, s_size] and idxs has shape [None, n];
    for the i-th sequence in seq, pick out the vectors at positions idxs[i],
    giving an output of shape [None, n, s_size].
    """
    seq, idxs = x
    idxs = K.cast(idxs, 'int32')
    return K.tf.batch_gather(seq, idxs)
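What tf.batch_gather does here, sketched in plain NumPy (illustrative only): for every sample in the batch, pick the rows of seq at the positions listed in idxs, i.e. the encoder vectors sitting at each header's [CLS] token.
seq_demo = np.arange(24).reshape(2, 4, 3)   # [batch=2, seq_len=4, s_size=3]
idx_demo = np.array([[0, 2], [1, 3]])       # header [CLS] positions per sample
np.stack([s[i] for s, i in zip(seq_demo, idx_demo)])  # shape (2, 2, 3)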
bert_model = load_trained_model_from_checkpoint(config_path, checkpoint_path, seq_len=None)
WARNING:tensorflow:From E:\Anaconda\anaconda\envs\tensorflow1\lib\site-packages\tensorflow\python\framework\op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Colocations handled automatically by placer.
WARNING:tensorflow:From E:\Anaconda\anaconda\envs\tensorflow1\lib\site-packages\keras\backend\tensorflow_backend.py:3445: calling dropout (from tensorflow.python.ops.nn_ops) with keep_prob is deprecated and will be removed in a future version.
Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.
for l in bert_model.layers:
    l.trainable = True
x1_in = Input(shape=(None,), dtype='int32')
x2_in = Input(shape=(None,))
xm_in = Input(shape=(None,))
h_in = Input(shape=(None,), dtype='int32')
hm_in = Input(shape=(None,))
sel_in = Input(shape=(None,), dtype='int32')
conn_in = Input(shape=(1,), dtype='int32')
csel_in = Input(shape=(None,), dtype='int32')
cop_in = Input(shape=(None,), dtype='int32')
x1, x2, xm, h, hm, sel, conn, csel, cop = (
x1_in, x2_in, xm_in, h_in, hm_in, sel_in, conn_in, csel_in, cop_in
)
hm = Lambda(lambda x: K.expand_dims(x, 1))(hm)  # header mask -> (None, 1, h_len)
x = bert_model([x1_in, x2_in])  # BERT encoding of question + headers
x4conn = Lambda(lambda x: x[:, 0])(x)  # vector of the first [CLS] token
pconn = Dense(num_cond_conn_op, activation='softmax')(x4conn)  # cond_conn_op classifier
x4h = Lambda(seq_gather)([x, h])  # vectors of the [CLS] token in front of each header
psel = Dense(num_agg, activation='softmax')(x4h)  # per-header agg classifier (incl. "not selected")
pcop = Dense(num_op, activation='softmax')(x)  # per-character op classifier (incl. "not a condition")
x = Lambda(lambda x: K.expand_dims(x, 2))(x)
x4h = Lambda(lambda x: K.expand_dims(x, 1))(x4h)
pcsel_1 = Dense(256)(x)
pcsel_2 = Dense(256)(x4h)
pcsel = Lambda(lambda x: x[0] + x[1])([pcsel_1, pcsel_2])  # additive interaction between characters and headers
pcsel = Activation('tanh')(pcsel)
pcsel = Dense(1)(pcsel)
pcsel = Lambda(lambda x: x[0][..., 0] - (1 - x[1]) * 1e10)([pcsel, hm])  # mask out padded headers
pcsel = Activation('softmax')(pcsel)  # per-character distribution over condition columns
model = Model(
[x1_in, x2_in, h_in, hm_in],
[psel, pconn, pcop, pcsel]
)
train_model = Model(
[x1_in, x2_in, xm_in, h_in, hm_in, sel_in, conn_in, csel_in, cop_in],
[psel, pconn, pcop, pcsel]
)
xm = xm  # question mask, shape (None, x_len)
hm = hm[:, 0]  # header mask, shape (None, h_len)
cm = K.cast(K.not_equal(cop, num_op - 1), 'float32')  # mask over condition-value characters, shape (None, x_len)
psel_loss = K.sparse_categorical_crossentropy(sel_in, psel)
psel_loss = K.sum(psel_loss * hm) / K.sum(hm)
pconn_loss = K.sparse_categorical_crossentropy(conn_in, pconn)
pconn_loss = K.mean(pconn_loss)
pcop_loss = K.sparse_categorical_crossentropy(cop_in, pcop)
pcop_loss = K.sum(pcop_loss * xm) / K.sum(xm)
pcsel_loss = K.sparse_categorical_crossentropy(csel_in, pcsel)
pcsel_loss = K.sum(pcsel_loss * xm * cm) / K.sum(xm * cm)
loss = psel_loss + pconn_loss + pcop_loss + pcsel_loss
train_model.add_loss(loss)
train_model.compile(optimizer=Adam(learning_rate))
train_model.summary()
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) (None, None) 0
__________________________________________________________________________________________________
input_2 (InputLayer) (None, None) 0
__________________________________________________________________________________________________
model_2 (Model) (None, None, 768) 101677056 input_1[0][0]
input_2[0][0]
__________________________________________________________________________________________________
input_4 (InputLayer) (None, None) 0
__________________________________________________________________________________________________
lambda_3 (Lambda) (None, None, 768) 0 model_2[1][0]
input_4[0][0]
__________________________________________________________________________________________________
lambda_4 (Lambda) (None, None, 1, 768) 0 model_2[1][0]
__________________________________________________________________________________________________
lambda_5 (Lambda) (None, 1, None, 768) 0 lambda_3[0][0]
__________________________________________________________________________________________________
dense_4 (Dense) (None, None, 1, 256) 196864 lambda_4[0][0]
__________________________________________________________________________________________________
dense_5 (Dense) (None, 1, None, 256) 196864 lambda_5[0][0]
__________________________________________________________________________________________________
lambda_6 (Lambda) (None, None, None, 2 0 dense_4[0][0]
dense_5[0][0]
__________________________________________________________________________________________________
activation_1 (Activation) (None, None, None, 2 0 lambda_6[0][0]
__________________________________________________________________________________________________
input_5 (InputLayer) (None, None) 0
__________________________________________________________________________________________________
dense_6 (Dense) (None, None, None, 1 257 activation_1[0][0]
__________________________________________________________________________________________________
lambda_1 (Lambda) (None, 1, None) 0 input_5[0][0]
__________________________________________________________________________________________________
lambda_2 (Lambda) (None, 768) 0 model_2[1][0]
__________________________________________________________________________________________________
lambda_7 (Lambda) (None, None, None) 0 dense_6[0][0]
lambda_1[0][0]
__________________________________________________________________________________________________
dense_2 (Dense) (None, None, 7) 5383 lambda_3[0][0]
__________________________________________________________________________________________________
dense_1 (Dense) (None, 3) 2307 lambda_2[0][0]
__________________________________________________________________________________________________
dense_3 (Dense) (None, None, 5) 3845 model_2[1][0]
__________________________________________________________________________________________________
activation_2 (Activation) (None, None, None) 0 lambda_7[0][0]
==================================================================================================
Total params: 102,082,576
Trainable params: 102,082,576
Non-trainable params: 0
__________________________________________________________________________________________________
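The "- (1 - mask) * 1e10" trick used when building pcsel above, sketched with NumPy (illustrative only): logits of padded, non-existent headers are pushed to a huge negative value, so the softmax gives them essentially zero probability.
logits = np.array([2.0, 1.0, 0.5, 0.3])
mask = np.array([1.0, 1.0, 0.0, 0.0])   # only the first two headers are real
masked = logits - (1 - mask) * 1e10
probs = np.exp(masked - masked.max())
probs /= probs.sum()                     # probs[2:] are effectively 0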
def nl2sql(question, table):
    """Take a question plus the table headers and convert them into SQL.
    """
    x1, x2 = tokenizer.encode(question)
    h = []
    for i in table['headers']:
        _x1, _x2 = tokenizer.encode(i)
        h.append(len(x1))
        x1.extend(_x1)
        x2.extend(_x2)
    hm = [1] * len(h)
    psel, pconn, pcop, pcsel = model.predict([
        np.array([x1]),
        np.array([x2]),
        np.array([h]),
        np.array([hm])
    ])
    R = {'agg': [], 'sel': []}
    for i, j in enumerate(psel[0].argmax(1)):
        if j != num_agg - 1:  # this column is selected
            R['sel'].append(i)
            R['agg'].append(j)
    conds = []
    v_op = -1
    for i, j in enumerate(pcop[0, :len(question)+1].argmax(1)):
        # greedily merge consecutive characters that share the same predicted op into one value span
        if j != num_op - 1:
            if v_op != j:
                if v_op != -1:
                    v_end = v_start + len(v_str)
                    csel = pcsel[0][v_start: v_end].mean(0).argmax()
                    conds.append((csel, v_op, v_str))
                v_start = i
                v_op = j
                v_str = question[i - 1]
            else:
                v_str += question[i - 1]
        elif v_op != -1:
            v_end = v_start + len(v_str)
            csel = pcsel[0][v_start: v_end].mean(0).argmax()
            conds.append((csel, v_op, v_str))
            v_op = -1
    R['conds'] = set()
    for i, j, k in conds:
        if re.findall(r'[^\d\.]', k):
            j = 2  # a non-numeric value can only be compared with ==
        if j == 2:
            if k not in table['all_values']:
                # the value must come from the table, so snap it to the closest cell value
                k = most_similar(k, list(table['all_values']))
            h = table['headers'][i]
            # if the predicted column does not contain this value, switch to a column that does
            if k not in table['content'][h]:
                for r, v in table['content'].items():
                    if k in v:
                        i = table['header2id'][r]
                        break
        R['conds'].add((i, j, k))
    R['conds'] = list(R['conds'])
    if len(R['conds']) <= 1:  # at most one condition: no connector needed
        R['cond_conn_op'] = 0
    else:
        R['cond_conn_op'] = 1 + pconn[0, 1:].argmax()  # must be "and" or "or"
    return R
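Illustrative single-question check (meaningful only once the model has trained or loaded weights; with freshly initialised heads the prediction is essentially random):
sample = valid_data[0]
print(sample['question'])
print(nl2sql(sample['question'], valid_tables[sample['table_id']]))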
def is_equal(R1, R2):
    """Check whether two SQL dicts match exactly.
    """
    return (R1['cond_conn_op'] == R2['cond_conn_op']) &\
           (set(zip(R1['sel'], R1['agg'])) == set(zip(R2['sel'], R2['agg']))) &\
           (set([tuple(i) for i in R1['conds']]) == set([tuple(i) for i in R2['conds']]))
def evaluate(data, tables):
    right = 0.
    pbar = tqdm()
    F = open('evaluate_pred.json', 'w', encoding='utf-8')
    for i, d in enumerate(data):
        question = d['question']
        table = tables[d['table_id']]
        R = nl2sql(question, table)
        right += float(is_equal(R, d['sql']))
        pbar.update(1)
        pbar.set_description('< acc: %.5f >' % (right / (i + 1)))
        d['sql_pred'] = R
        s = json.dumps(d, ensure_ascii=False, indent=4)
        F.write(s + '\n')
    F.close()
    pbar.close()
    return right / len(data)
def test(data, tables, outfile='result.json'):
    pbar = tqdm()
    F = open(outfile, 'w', encoding='utf-8')
    for i, d in enumerate(data):
        question = d['question']
        table = tables[d['table_id']]
        R = nl2sql(question, table)
        pbar.update(1)
        s = json.dumps(R, ensure_ascii=False)
        F.write(s + '\n')
    F.close()
    pbar.close()
class Evaluate(Callback):
    def __init__(self):
        self.accs = []
        self.best = 0.
        self.passed = 0
        self.stage = 0
    def on_batch_begin(self, batch, logs=None):
        """The first epoch is used for warmup; during the second epoch the
        learning rate is annealed down to its minimum.
        """
        if self.passed < self.params['steps']:
            lr = (self.passed + 1.) / self.params['steps'] * learning_rate
            K.set_value(self.model.optimizer.lr, lr)
            self.passed += 1
        elif self.params['steps'] <= self.passed < self.params['steps'] * 2:
            lr = (2 - (self.passed + 1.) / self.params['steps']) * (learning_rate - min_learning_rate)
            lr += min_learning_rate
            K.set_value(self.model.optimizer.lr, lr)
            self.passed += 1
    def on_epoch_end(self, epoch, logs=None):
        acc = self.evaluate()
        self.accs.append(acc)
        if acc > self.best:
            self.best = acc
            train_model.save_weights('best_model.weights')
        print('acc: %.5f, best acc: %.5f\n' % (acc, self.best))
    def evaluate(self):
        return evaluate(valid_data, valid_tables)
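Shape of the schedule implemented in on_batch_begin, computed here for a hypothetical epoch of 1000 steps (illustrative only): linear warmup up to learning_rate during epoch 1, then linear decay towards min_learning_rate during epoch 2, constant afterwards.
demo_steps = 1000
warmup_lrs = [(p + 1.) / demo_steps * learning_rate for p in range(demo_steps)]
decay_lrs = [(2 - (p + 1.) / demo_steps) * (learning_rate - min_learning_rate) + min_learning_rate
             for p in range(demo_steps, 2 * demo_steps)]
# warmup_lrs rises from ~0 to 5e-5; decay_lrs falls from ~5e-5 down to 1e-5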
train_D = data_generator(train_data, train_tables)
evaluator = Evaluate()
if __name__ == '__main__':
    train_model.fit_generator(
        train_D.__iter__(),
        steps_per_epoch=len(train_D),
        epochs=15,
        callbacks=[evaluator]
    )
else:
    train_model.load_weights('best_model.weights')
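Once a best_model.weights file exists, the test-set predictions can be written out with the helper defined earlier (the output file name is whatever the submission expects):
# test(test_data, test_tables, outfile='result.json')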