import os,json,time,re,copy,random
from collections import Counter
import numpy as np
import torch
import torch.nn as nn
def load_raw_data(filename):
data=[]
with open(filename,encoding='utf-8') as f:
lines=f.readlines()#lines是一个列表,每一个元素是文件中的一行,文件中的每7行组成一条训练样本
json_string=''
#每7行是一个样本
for line_id,line_str in enumerate(lines):
json_string+=line_str
line_id+=1
if line_id%7==0:
example=json.loads(json_string)#json.loads可以将字典形式的字符串转换成一个字典
#example是一个字典,key值有'id','original_text','segmented_text','equation','ans'
if '千米/小时' in example['equation']:
example['equation']=example['equation'][:-5]#有些等式中含有(千米/小时)这个单位,把这个单位去掉
data.append(example)
json_string=''
return data
这个函数是数据处理部分的核心函数
def transfer_num(data):
'''
将数据集中的每一个样本对应的文本问题中的数字替换成NUM
'''
#正则表达式中: +表示出现一次或多次,*表示出现零次或多次
#\d*\(\d+/\d+\)\d* 这个正则是为了匹配行如 (3/5)、2(3/5)、2(3/5)12 这类的数字(也就是带有括号的分数)
#\d+\.\d+%? 这个正则是为了匹配行如 3.5 3.5% 这类的数字(也就是小数或者带有百分号的小数)
#\d+%? 这个正则是为了匹配整数以及3%这类的带有百分号的整数
pattern=re.compile("\d*\(\d+/\d+\)\d*|\d+\.\d+%?|\d+%?")#pattern用来匹配问题文本中的所有数字
pairs,generate_nums,generate_nums_dict=[],[],{
}
copy_nums=0#copy_nums用来记录数据集中的所有问题中,哪一个问题中出现的数字次数最多,copy_nums用来记录这个次数
#copy_nums的数值影响着decoder端的词汇空间
for example in data:
#example行如:{'id': '5001','original_text': '某电视机厂原来每天生产116台电视机,现在每天生产的台数是原来的12倍,现在每天能生产多少台电视机?',
#'segmented_text': '某 电视机厂 原来 每天 生产 116 台 电视机 , 现在 每天 生产 的 台数 是 原来 的 12 倍 , 现在 每天 能 生产 多少 台 电视机 ?',
#'equation': 'x=116*12','ans': '1392'}
idx=example['id']
nums=[]#nums用来记录根据pattern匹配出的问题文本中的所有数字
input_seq=[]#input_seq用来将问题文本中的所有数字替换成NUM
seg=example['segmented_text'].strip().split(' ')
#seg行如: ['某', '电视机厂', '原来', '每天', '生产', '116', '台', '电视机', ',', '现在', '每天', '生产', '的', '台数', '是', '原来', '的', '12', '倍', ',', '现在', '每天', '能', '生产', '多少', '台', '电视机', '?']
equation=example['equation'][2:]
#equations的形式行如: x=(25+14)/(1-(1/5)-(1/5));x=(11-1)*2;x=116*12等
for token in seg:
pos=re.search(pattern=pattern,string=token)#如果token是数字,那么pos返回的不是None
if pos and pos.start()==0:
nums.append(token[pos.start():pos.end()])
input_seq.append('NUM')#将所有数字替换成NUM
if pos.end()<len(token):
#说明此时的token不仅仅含有数字,eg: 116千克,那么input_seq中要添加千克这个单词
input_seq.append(token[pos.end():])
elif token!='':
#此时的token中没有数字
input_seq.append(token)
if copy_nums<len(nums):
copy_nums=len(nums)#copy_nums用来记录所有问题中出现数字次数最多的那个问题出现的数字的次数
nums_fraction=[]#nums_fraction用来记录这个问题中出现的行如(2/5)这种带括号的分数数字
for num in nums:
if re.search('\d*\(\d+/\d+\)\d*',num):
nums_fraction.append(num)#num行如 5(2/5) (2/5) 5(2/5)5 这种,
nums_fraction=sorted(nums_fraction,key=lambda x:len(x),reverse=True)#将nums_fraction中的带括号的分数数字按照长度排序
#实验表明,排序或者不排序一点关系没有
def seg_and_tag(equation):
'''
seg_and_tag函数的作用是将equation,也就是表达式中的字符分割开,例如:equation='(25+14)/(1-(1/5)-(1/5))'
那么返回的表达式应该是['(', '25', '+', '14', ')', '/', '(', '1', '-', '(1/5)', '-', '(1/5)', ')']
同时要将各个数字替换成Ni,i代表这个数字在问题文本中出现的顺序
这也是为什么前面要用nums_fraction专门保存带括号的分数,这样才能使得整个括号和分数看成一个整体
'''
res=[]
for num in nums_fraction:
#如果nums_fraction是空列表,也就是说当前问题没有带括号的分数,那么这个for循环自然不会执行
if num in equation:
#从equation中找到这个带括号的分数的位置
p_start=equation.find(num)
p_end=p_start+len(num)
if p_start>0:
#以上面的equation为例子,显然此时num等于(1/5),所以p_start>0,此时我们需要处理(25+14)/(1-
res+=seg_and_tag(equation[:p_start])
if nums.count(num)==1:
#也就是说这个数字仅在问题文本中出现过一次,那么此时就可以用Ni代替这个数字,
#i表示的是这个数字在文本中出现的顺序
res.append('N'+str(nums.index(num)))
else:
#说明这个数字在问题文本中出现了多次,那么此时直接记录这个数字,而不用Ni替代
res.append(num)
if p_end<len(equation):
res+=seg_and_tag(equation[p_end:])#递归右边的部分
return res
#现在已经将这类括号带分数的数字处理完毕,接下来处理整数、小数、百分数
number_position=re.search(pattern='\d+\.\d+%?|\d+%?',string=equation)
if number_position:
p_start=number_position.start()
p_end=number_position.end()
if p_start>0:
#类似的,递归左边
res+=seg_and_tag(equation[:p_start])
number=equation[p_start:p_end]
if nums.count(number)==1:
res.append('N'+str(nums.index(number)))
else:
res.append(number)
if p_end<len(equation):
res+=seg_and_tag(equation[p_end:])
return res
#上面的代码是用来处理数字的,如:带有括号的分数、小数、整数、百分数等
#下面的for循环处理equation中的 括号和+-/*
for rest_op in equation:
#rest_op要么是括号(),要么是+-/*
res.append(rest_op)
return res
output_seq=seg_and_tag(equation=equation)#output_seq就是decoder端要生成的表达式标签
for token in output_seq:
if token[0].isdigit() and token not in generate_nums and token not in nums:
#说明此时这是一个数字,并且这个数字没有出现在问题中,这类数字包括1或者3.14这种常数
generate_nums.append(token)
generate_nums_dict[token]=1
if token in generate_nums and token not in nums:
generate_nums_dict[token]+=1
num_pos=[]#num_pos用来记录每一个数字的位置将equation
for i,j in enumerate(input_seq):
#input_seq是将问题中的所有数字替换成NUM后的变量
if j=='NUM':
num_pos.append(i)
assert len(nums)==len(num_pos)
#nums记录的是每一个数字,num_pos记录的是每一个数字的位置
pairs.append((idx,input_seq,output_seq,nums,num_pos))
#结束for循环后,我们就已经处理了所有的问题,接下来统计数据集中频繁出现的常数
temp_g=[]#用来记录数据集中频繁出现的常数,比如3.14
for g in generate_nums:
if generate_nums_dict[g]>=5:
temp_g.append(g)
return pairs,temp_g,copy_nums
这里用到了哈工大的pyltp工具包,我们通过两幅图来看
需要注意的是Root默认占用0,所以其它单词的索引是需要id-1的,这也是为什么源码中有arc.head-1这行代码,不过由于版本问题,此时的arc是一个元祖tuple,不过含义是一样的。
也就是说postagger用来标注每一个单词的词性(名词、动词等),parser用来提取整个句子中各个单词的依存句法关系,关于上面的具体的细节以及ATT,SBV,WP这都是什么玩意,不在详细介绍。
关于原理,请参考中缀表达式转后缀表达式
首先设置两个栈,操作数栈和运算符栈
def from_infix_to_prefix(expression):
operator_stack=[]#运算符栈
operand_stack=[]#操作数栈
operator_priority={
'+':0,'-':0,'*':1,'/':1,'^':2}
expression=deepcopy(expression)#deepcopy是深拷贝
expression.reverse()#转前缀的过程是从右至左扫描
for e in expression:
if e in [')',']']:
#当遇到右括号时,直接进栈
operator_stack.append(e)
elif e =='(':
#弹出栈中的运算符,直到遇到)为止
temp=operator_stack.pop()
while temp!=')':
operand_stack.append(temp)
temp=operator_stack.pop()
elif e=='[':
#弹出栈中的运算符,直到遇到]为止
temp=operator_stack.pop()
while temp!=']':
operand_stack.append(temp)
temp=operator_stack.pop()
elif e in operator_priority:
#此时是运算符,需要比较优先级,当栈顶运算符的优先级大于e的优先级时,就一直弹栈
#不过需要注意的是,如果栈顶是右括号,那么就不能再弹了,因为右括号要等到左括号来了才能弹栈
while len(operator_stack)>0 and operator_stack[-1] not in [')',']'] and operator_priority[e]<operator_priority[operator_stack[-1]]:
operand_stack.append(operator_stack.pop())
operator_stack.append(e)
else:
#说明此时的e是操作数
operand_stack.append(e)
#将运算符栈中的剩余运算符全部弹出到操作数栈中
while len(operator_stack)>0:
operand_stack.append(operator_stack.pop())
operand_stack.reverse()
return operand_stack
思路是一样的,只不过有几个不同点:
def from_infix_to_postfix(expression):
operator_stack=[]
operand_stack=[]
expression=deepcopy(expression)
operator_priority={
'+':0,'-':0,'*':1,'/':1,'^':2}
for e in expression:
if e in ['(','[']:
operator_stack.append(e)
elif e ==')':
temp=operator_stack.pop()
while temp!='(':
operand_stack.append(temp)
temp=operator_stack.pop()
elif e ==']':
temp=operator_stack.pop()
while temp!='[':
operand_stack.append(temp)
temp=operator_stack.pop()
elif e in operator_priority:
while len(operator_stack)>0 and operator_stack[-1] not in ['(','['] and operator_priority[e]<operator_priority[operator_stack[-1]]:
operand_stack.append(operator_stack.pop())
operator_stack.append(e)
else:
operand_stack.append(e)
while len(operator_stack)>0:
operand_stack.append(operator_stack.pop())
return operand_stack
def generate_train_test(math23k_file):
data=load_raw_data(math23k_file)#data的每一个元素是一个dict,字段有:id,original_text,segmented_text,equation,ans
pairs,generate_nums,copy_nums=transfer_num(data)
#pairs是将data的每一个数据里面的segmented_text中的数字转换成NUM,将equation中的数字转换成Ni,其中i
#代表这个数字在问题中出现的顺序,pairs还有两个元素,分别记录问题对应的所有数字和数字的位置
pre_temp_pairs=[]
for p in pairs:
#p[0]是id,p[1]是行如['新世纪', '百货', '开展', '“', '庆', 'NUM', '一', '”', '促销', '活动', ', '再', '降', 'NUM', '?'],
#这样的问题
postags=postagger.post(p[1])#也就是标注问题中的每一个单词的词性
arcs=parser.parse(p[1],postags)#提取整个句子的句法
parse_tree=[arc[0]-1 for arc in arcs]#其中arc是一个元祖(id,relation),id代表的就是当前这个单词与哪一个单词有关联,
#id表示的就是那个单词在整个句子中的索引,但是由于ROOT这个单词默认占据0,所以单词的实际位置需要-1
#relation表示的就是句法关系
pre_temp_pairs.append((p[0],p[1],postags,parse_tree,
from_infix_to_prefix(p[2]),from_infix_to_postfix(p[2]),p[3],p[4]))
#其中p[3]和p[4]分别是nums和nums_pos,也就是这个问题中所有的数字和数字的位置
#p[2]就是中缀表达式,现在已经转换成前缀和后缀了
pairs=pre_temp_pairs
#接下来构造5折交叉验证的数据集
fold_size=int(len(pairs)*0.2)#fold_size也就是每一折的测试集合大小,在math23k上约等于4632
fold_pairs=[]
for split_fold in range(4):
fold_start=fold_size*split_fold
fold_end=fold_size*(split_fold+1)
fold_pairs.append(pairs[fold_start:fold_end])
#split_fold==0,1,2,3
#fold_pairs==[pairs[0:4632],pairs[4632:9264],pairs[9264:13896],pairs[13896:18528]]
fold_pairs.append(pairs[fold_size*4:])#fold_pairs==[pairs[0:4632],pairs[4632:9264],pairs[9264:13896],pairs[13896:18528],pairs[18528:23162]]
for fold in range(5):
pairs_tested=[]
pairs_trained=[]
for fold_t in range(5):
if fold_t==fold:
#当fold==0时,就用fold_pairs[0]作为测试集,其它四个作为训练集
pairs_tested+=fold_pairs[fold_t]
else:
pairs_trained+=fold_pairs[fold_t]
with open("data/train"+str(fold)+".json",'w') as f:
json.dump(pairs_trained,f,ensure_ascii=False,indent=4)
with open("data/test"+str(fold)+".json","w") as f:
json.dump(pairs_tested,f,ensure_ascii=False,indent=4)
train_example=pairs_trained[10]
print("example id : ",train_example[0])
print("example input seq : ",train_example[1])
print("example question pos(pos指的是词性) : ",train_example[2])
print("example syntatic parser(句法分析) : ",train_example[3])
print("example prefix expression : ",train_example[4])
print("example postfix expression : ",train_example[5])
print("example question nums : ",train_example[6])
print("example question nums_pos : ",train_example[7])
我们已经清楚了pairs_trained中每一个数据的结构
PAD_token=0#默认pad位置用0填充
class Lang:
def __init__(self):
self.word2index={
}#词到id的转换字典
self.word2count={
}#词到词频的转换字典
self.index2word=[]
self.n_words=0
self.num_start=0
def add_sen_to_vocab(self,sentence):
#传进来的sentence有多种形式 第一种是问题文本,行如:['要', '修', '一段', '长', 'NUM', '千米', '的', '路', ',', '第一天', '修', '了', 'NUM', '千米', ',', '第', '二', '天', '修', '了', '余下', '的', 'NUM', ',', '还', '剩下', '多少', '千米', '没有', '修', '完', '?']
#第二种是句子的标注词性,行如['v', 'v', 'm', 'a', 'ws', 'q', 'u', 'n', 'wp', 'nt', 'v', 'u', 'ws', 'q', 'wp', 'm', 'm', 'q', 'v', 'u', 'v', 'u', 'ws', 'wp', 'd', 'v', 'r', 'q', 'd', 'v', 'v', 'wp']
#这是因为论文有两个encoder,之前的论文只有一个encoder,只需要问题文本作为输入
#第三种是前缀表达式,行如['-', '-', 'N0', '*', '-', 'N0', 'N1', 'N2', 'N1']
#第四种是后缀表达式,行如['N0', 'N0', 'N1', '-', 'N2', '*', '-', 'N1', '-']
for word in sentence:
if re.search(pattern='N\d+|NUM|\d+',string=word):
continue#数字和特殊字符NUM不作为encoder端的词汇
if word not in self.index2word:
self.word2index[word]=self.n_words
self.word2count[word]=1
self.index2word.append(word)
self.n_words+=1
else:
self.word2count[word]+=1
def trim(self,min_count):
'''
根据min_count去除词典中的单词,缩小词典的空间
'''
keep_words=[]
for word,freq in self.word2count.items():
if freq>=min_count:
#词频高的词保留
keep_words.append(word)
self.word2index={
}
self.word2count={
}
self.index2word=[]
self.n_words=0
for word in keep_words:
self.word2index[word]=self.n_words
self.index2word.append(word)
self.n_words+=1
def build_input_lang(self,trim_min_count):
if trim_min_count>0:
self.trim(min_count=trim_min_count)
self.index2word=['PAD','NUM','UNK']+self.index2word#因为删除了一些单词后,在训练集中自然会出现一些没有见过的单词
else:
self.index2word=['PAD','NUM']+self.index2word
#重置word2index,因为要考虑PAD和NUM以及UNK等特殊字符
self.word2index={
word:index for index,word in enumerate(self.index2word)}
def build_input_lang_for_pos(self):
#对于词性标注的输入,没有NUM需要考虑,而且不需要删除不常见单词
self.index2word=['PAD','UNK']+self.index2word#需要注意的是,调用这个函数的对象一定是词性标注输入的对象
self.n_words=len(self.index2word)
self.word2index={
word:index for index,word in enumerate(self.index2word)}
def build_output_lang(self,generate_nums,copy_nums):
'''
generate_nums代表的是常数,如: 1,3.14
copy_nums代表的是出现数字次数最多的那个问题出现的数字次数,copy_nums决定了decoder端最多可以预测多少个不同数字
'''
self.index2word+=['PAD','EOS']+generate_nums+['N'+str(i) for i in range(copy_nums)]+['SOS','UNK']
self.n_words=len(self.index2word)
self.word2index={
word:index for index,word in enumerate(self.index2word)}
def build_output_lang_for_tree(self,generate_nums,copy_nums):
'''
树形结构的decoder和sequence结构的decoder是不同的,因为tree结构不是序列式的生成表达式,所以不考虑PAD和EOS,SOS等
'''
self.num_start=len(self.index2word)
self.index2word+=generate_nums+['N'+str(i) for i in range(copy_nums)]+['UNK']
self.n_words=len(self.index2word)
self.word2index={
word:index for index,word in enumerate(self.index2word)}
验证一下
input1_lang = Lang()
input2_lang = Lang()
output1_lang = Lang()
output2_lang = Lang()
for pair in pairs_trained:
if pair[-1]:
input1_lang.add_sen_to_vocab(pair[1])#pair[1]是问题文本
input2_lang.add_sen_to_vocab(pair[2])#pair[2]是问题句子的词性
output1_lang.add_sen_to_vocab(pair[4])#pair[4]是前缀表达式
output2_lang.add_sen_to_vocab(pair[5])#pair[5]是后缀表达式
trim_min_count=5
input1_lang.build_input_lang(trim_min_count)
input2_lang.build_input_lang_for_pos()
output1_lang.build_output_lang_for_tree(generate_nums, copy_nums)
output2_lang.build_output_lang(generate_nums, copy_nums)
def indexes_from_sentence(lang,sentence,tree=False):
'''
根据lang中的word2index将sentence中的每一个token转为对应的id
这里面的sentence不一定是句子,也可能是词性标注序列,或者输出的前缀后缀表达式
'''
res=[]
unk_token=lang.word2index['UNK']
for token in sentence:
if len(token)==0:
continue
res.append(lang.word2index.get(token,unk_token))
if 'EOS' in lang.index2word and not tree:
#输出端有两个decoder,其中一个是sequence式结构,另一个是tree结构
#sequence结构中需要有'EOS'
res.append(lang.word2index['EOS'])
return res
def texts_from_sentence(lang, sentence, tree=False):
'''
函数的目的是将sentence中出现的词汇如果不在lang.word2index中,那么就换成UNK
'''
res = []
for word in sentence:
if len(word) == 0:
continue
if word in lang.word2index:
res.append(word)
else:
res.append("UNK")
if "EOS" in lang.index2word and not tree:
res.append(lang.word2index["EOS"])
return res
def num_list_processed(num_list):
'''
num_list代表的是一个问题中所有的数字
函数的目的是将num_list中的数字进一步换算成对应的值,同时将百分号等数字替换成对应的小数
将分数也同样计算成对应的小数
'''
st = []
for p in num_list:
pos1 = re.search("\d+\(", p)
pos2 = re.search("\)\d+", p)
if pos1:
st.append(eval(p[pos1.start(): pos1.end() - 1] + "+" + p[pos1.end() - 1:]))
elif pos2:
st.append(eval(p[:pos2.start() + 1] + "+" + p[pos2.start() + 1: pos2.end()]))
elif p[-1] == "%":
st.append(float(p[:-1]) / 100)
else:
st.append(eval(p))
return st
def num_order_processed(num_list):
'''
由于论文中提出要比较一个问题中所有数字的大小,所以这个函数的作用就是用整数来表达一个数字在当前这个问题中的所有
数字的大小,数值的大小代表的是这个数字大于多少个数字
'''
num_order = []
num_array = np.asarray(num_list)
for num in num_array:
num_order.append(sum(num>num_array)+1)
return num_order
def prepare_data(pairs_trained,pairs_tested,trim_min_count,generate_nums,copy_nums):
'''
pairs[0]-->id,问题样本id
pairs[1]-->input seq,问题文本
pairs[2]-->pos,问题单词的词性标注
pairs[3]-->parser,句法分析的结果
pairs[4]-->prefix expression
pairs[5]-->postfix expression
pairs[6]-->nums
pairs[7]-->nums_pos
'''
input1_lang = Lang()
input2_lang = Lang()
output1_lang = Lang()
output2_lang = Lang()
train_pairs = []
test_pairs = []
print("Indexing words...")
for pair in pairs_trained:
if pair[-1]:
input1_lang.add_sen_to_vocab(pair[1])
input2_lang.add_sen_to_vocab(pair[2])
output1_lang.add_sen_to_vocab(pair[4])
output2_lang.add_sen_to_vocab(pair[5])
input1_lang.build_input_lang(trim_min_count)
input2_lang.build_input_lang_for_pos()
output1_lang.build_output_lang_for_tree(generate_nums, copy_nums)
output2_lang.build_output_lang(generate_nums, copy_nums)
for pair in pairs_trained:
num_stack = []
for word in pair[4]:
#pair[4]是前缀表达式,行如['/', '*', 'N1', 'N2', '5']
temp_num = []
flag_not = True
#output1_lang是树形结构decoder的词空间
if word not in output1_lang.index2word:
#这种情况是因为前缀表达式中出现了数字,而我们知道,数字是不作为词空间中的元素的
#表达式中按理说所有的数字都已经被转为对应的Ni了,出现数字的原因是这个数字在问题中出现了多次
flag_not = False
for i, j in enumerate(pair[6]):
#pair[6]是nums,也就是每一个数字,行如 ['5', '16.5', '2.1', '5']
if j == word:
temp_num.append(i)#temp==[0,3],temp记录的是表达式中出现的重复的数字在nums中的位置
if not flag_not and len(temp_num) != 0:
num_stack.append(temp_num)
if not flag_not and len(temp_num) == 0:
num_stack.append([_ for _ in range(len(pair[6]))])
#num_stack.reverse()#实验表明,这行代码没有用
input1_cell = indexes_from_sentence(input1_lang, pair[1])#pair[1] is input_seq
texts_cell = texts_from_sentence(input1_lang, pair[1])
input2_cell = indexes_from_sentence(input2_lang, pair[2])#pair[2] is input seq pos
output1_cell = indexes_from_sentence(output1_lang, pair[4], True)#pair[4] is prefix_expression, used for tree-decoder
output2_cell = indexes_from_sentence(output2_lang, pair[5], False)#pair[5] is postfix expression,
num_list = num_list_processed(pair[6])#pair[6] is nums
num_order = num_order_processed(num_list)
train_pairs.append((pair[0], texts_cell, input1_cell, input2_cell, pair[3], len(input1_cell),
output1_cell, len(output1_cell), output2_cell, len(output2_cell),
pair[6], pair[7], num_stack, num_order))
print('Indexed %d words in input language, %d words in output1, %d words in output2' %
(input1_lang.n_words, output1_lang.n_words, output2_lang.n_words))
print('Number of training data %d' % (len(train_pairs)))
for pair in pairs_tested:
num_stack = []
for word in pair[4]:
temp_num = []
flag_not = True
if word not in output1_lang.index2word:
flag_not = False
for i, j in enumerate(pair[6]):
if j == word:
temp_num.append(i)
if not flag_not and len(temp_num) != 0:
num_stack.append(temp_num)
if not flag_not and len(temp_num) == 0:
num_stack.append([_ for _ in range(len(pair[6]))])
num_stack.reverse()
input1_cell = indexes_from_sentence(input1_lang, pair[1])
texts_cell = texts_from_sentence(input1_lang, pair[1])
input2_cell = indexes_from_sentence(input2_lang, pair[2])
output1_cell = indexes_from_sentence(output1_lang, pair[4], True)
output2_cell = indexes_from_sentence(output2_lang, pair[5], False)
num_list = num_list_processed(pair[6])
num_order = num_order_processed(num_list)
test_pairs.append((pair[0], texts_cell, input1_cell, input2_cell, pair[3], len(input1_cell),
output1_cell, len(output1_cell), output2_cell, len(output2_cell),
pair[6], pair[7], num_stack, num_order))
print('Number of testind data %d' % (len(test_pairs)))
return input1_lang, input2_lang, output1_lang, output2_lang, train_pairs, test_pairs
input1_lang, input2_lang, output1_lang, output2_lang, train_pairs, test_pairs = prepare_data(pairs_trained, pairs_tested, 5, generate_nums, copy_nums)
train_example=train_pairs[500]
print("example id : ",train_example[0])
print("example input seq (词频少的单词已经被替换成UNK): ",train_example[1])
print("将所有的单词替换成对应的id : ",train_example[2])
print("将标注的词性替换成对应的id : ",train_example[3])
print("句法分析的结构: ",train_example[4])
print("句子长度 : ",train_example[5])
print("将前缀表达式中的运算符替换成对应的id : ",train_example[6])
print("前缀表达式的长度 : ",train_example[7])
print("将后缀表达式中的运算符替换成对应的id : ",train_example[8])
print("后缀表达式的长度(后缀表达式是作为sequence decoder的标签,所以包含EOS,长度要更长一些) : ",train_example[9])
print('这个问题对应的所有的数字 : ',train_example[10])
print('这个问题中数字的位置 : ',train_example[11])
print('这个问题是否包含有重复数字,如果有,重复数字出现的位置 : ',train_example[12])
print('这个问题中所有数字的大小关系 : ',train_example[13])
def prepare_train_batch(pairs_to_batch,batch_size):
'''
这个函数用来构造输入数据
对于pairs_to_batch中的每一个元素example,都有14个字段,分别是
example id;example input seq (词频少的单词已经被替换成UNK);
example input_seq_id(所有的单词替换成对应的id);example pos_id(将标注的词性替换成对应的id);
example parse(句法分析的结构);example_length(句子长度);
example prefix_expression_id(将前缀表达式中的运算符替换成对应的id);prefix_expression length(前缀表达式的长度);
example postfix_expression_id(将后缀表达式中的运算符替换成对应的id);postfix_expression length(后缀表达式的长度);
example question nums(这个问题对应的所有的数字);example question nums_pos(这个问题中数字的位置);
example question num_stack(这个问题是否包含有重复数字,如果有,重复数字出现的位置);example question num_order(这个问题中所有数字的大小关系)
每一个example有14个字段
'''
pairs=deepcopy(pairs_to_batch)
random.shuffle(pairs)#随机打乱训练数据,因为我们要保证各个数据样本之间是相互独立的,满足iid条件
id_batches=[]#存储各个样本的id
input1_batches=[]#存储各个样本中问题对应的id(将问题文本中的单词转成id)
input2_batches=[]#存储各个样本中问题的每一个单词对应的词性标注对应的id
#input1和input2都是sequence encoder的输入
input_lengths=[]#存储各个样本中问题的长度
output1_lengths=[]#存储各个样本中问题对应的前缀表达式的长度
output2_lengths=[]#存储各个样本中问题对应的后缀表达式的长度
nums_batches=[]#存储各个样本中问题中出现的数字个数,也就是len(nums)
num_pos_batches=[]#对应的,存储各个样本中问题中出现的数字在问题中的索引
num_order_batches=[]#存储每一个问题中各个数字之间的大小关系
num_stack_batches=[]#如果问题中出现了重复数字,记录重复数字在nums中的位置,否则是[]
num_size_batches=[]
output1_batches = []
output2_batches = []
parse_graph_batches = []#存储句法解析
batches=[]#按照批次来存储数据,每一批数据为一个单词
num_of_batch=0
print()
print('一共有{}个训练数据样本,按照{}为批次大小,所以一共有{}个训练批次'.format(len(pairs),batch_size,len(pairs)//batch_size+1))
while num_of_batch+batch_size<len(pairs):
batches.append(pairs[num_of_batch:num_of_batch+batch_size])
num_of_batch+=batch_size
batches.append(pairs[num_of_batch:])
for batch in batches:
#在每一个批次中,按照这个批次的每一个句子的长度排序,句子长的放在前面,这样有助于后面的RNN编码
batch=sorted(batch,key=lambda example:example[5],reverse=True)#example[5]是句子长度
input_length=[]
output1_length=[]
output2_length=[]
for id_,input_seq,seq_id,pos_id,parse,seq_len,prefix_id,prefix_len,postfix_id,postfix_len,nums,nums_pos,num_stack,num_order in batch:
input_length.append(seq_len)
output1_length.append(prefix_len)
output2_length.append(postfix_len)
input_lengths.append(input_length)
output1_lengths.append(output1_length)
output2_lengths.append(output2_length)
input_len_max = input_length[0]#当前这个批次中所有问题长度的最大值
output1_len_max = max(output1_length)
output2_len_max = max(output2_length)
id_batch = []
input1_batch = []
input2_batch = []
output1_batch = []
output2_batch = []
num_batch = []
num_stack_batch = []
num_pos_batch = []
num_order_batch = []
num_size_batch = []
parse_tree_batch = []
for idx,input_seq,seq_id,pos_id,parse,seq_len,prefix_id,prefix_len,postfix_id,postfix_len,num,num_pos,num_stack,num_order in batch:
id_batch.append(idx)
seq_id+=[PAD_token for _ in range(input_len_max-seq_len)]#pad
pos_id+=[PAD_token for _ in range(input_len_max-seq_len)]#pad
input1_batch.append(seq_id)
input2_batch.append(pos_id)
prefix_id+=[PAD_token for _ in range(output1_len_max-prefix_len)]
postfix_id+=[PAD_token for _ in range(output2_len_max-postfix_len)]
#表达式同样需要pad
output1_batch.append(prefix_id)
output2_batch.append(postfix_id)
num_batch.append(len(num))#这个问题出现了多少个数字
num_stack_batch.append(num_stack)#是否有重复数字
num_pos_batch.append(num_pos)#数字的位置
num_order_batch.append(num_order)#数字之间的大小关系
num_size_batch.append(len(num_pos))
assert len(num)==len(num_pos)
parse_tree_batch.append(parse)
id_batches.append(id_batch)
input1_batches.append(input1_batch)
input2_batches.append(input2_batch)
output1_batches.append(output1_batch)
output2_batches.append(output2_batch)
nums_batches.append(num_batch)
num_stack_batches.append(num_stack_batch)
num_pos_batches.append(num_pos_batch)
num_order_batches.append(num_order_batch)
num_size_batches.append(num_size_batch)
parse_g=get_parse_graph_batch(input_length, parse_tree_batch)
assert type(parse_g)==np.ndarray
assert parse_g.shape==(len(batch),3,input_len_max,input_len_max)
parse_graph_batches.append(parse_g)
return id_batches, input1_batches, input2_batches, input_lengths, output1_batches, output1_lengths, output2_batches, output2_lengths, \
nums_batches, num_stack_batches, num_pos_batches, num_order_batches, num_size_batches, parse_graph_batches