This post supplements the previous one, which did not use attention. Attention essentially imitates how a human translator works: while translating a sentence you sometimes have to judge the meaning of the current word from its context, which means repeatedly looking back at the source sentence, because some parts of it strongly influence the prediction of the current word. Attention is the mechanism that feeds exactly this information into the prediction of each word.
Suppose the encoder input is [X1, ..., Xj, ..., XTx], where Tx is the maximum source length (shorter sentences are padded up to Tx) and Xj is the embedding of the j-th word. Feeding this into the encoder gives enc_output = [H1, ..., Hj, ..., HTx]. On the decoder side, the states [S1, ..., S_i-1] have already been produced and we are now predicting the i-th word (the i-th output yi). We first compute a relevance score between S_i-1 and every element of enc_output, e_ij = a(S_i-1, Hj); the attention weights α_ij (j = 1...Tx) are the softmax of the vector [e_i1, ..., e_iTx]; the context vector Ci is the weighted average of enc_output under these weights, Ci = Σ_j α_ij * Hj; finally Si is computed from S_i-1, the previous output y_i-1, and Ci.
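To make the computation concrete, here is a minimal NumPy sketch of one decoder step of additive (Bahdanau-style) attention. The function name bahdanau_attention_step and the parameters W_s, W_h, v are illustrative stand-ins for the weights the real layer learns (they roughly correspond to the query_layer/kernel, memory_layer/kernel and attention_v tensors that show up in the checkpoint listing later in this post); this only shows the e_ij -> α_ij -> Ci pipeline, not how tf.contrib.seq2seq implements it.

import numpy as np

def bahdanau_attention_step(s_prev, enc_output, W_s, W_h, v):
    """One decoder step of additive attention (a sketch, not the TF implementation).

    s_prev:     previous decoder state S_i-1, shape [dec_dim]
    enc_output: encoder outputs [H1, ..., HTx], shape [Tx, enc_dim]
    W_s, W_h:   projection matrices, shapes [dec_dim, att_dim] and [enc_dim, att_dim]
    v:          scoring vector, shape [att_dim]
    """
    # e_ij = v^T tanh(W_s S_i-1 + W_h Hj), one score per source position
    scores = np.tanh(s_prev @ W_s + enc_output @ W_h) @ v   # [Tx]
    # α_ij = softmax over the Tx scores
    alpha = np.exp(scores - scores.max())
    alpha /= alpha.sum()                                    # [Tx]
    # Ci = Σ_j α_ij * Hj, the context vector used to compute Si
    context = alpha @ enc_output                            # [enc_dim]
    return context, alpha

# Toy shapes: Tx=7 source positions, 4-dim encoder outputs, 3-dim decoder state
rng = np.random.RandomState(0)
enc_output = rng.randn(7, 4)
s_prev = rng.randn(3)
ctx, alpha = bahdanau_attention_step(
    s_prev, enc_output,
    W_s=rng.randn(3, 5), W_h=rng.randn(4, 5), v=rng.randn(5))
print(alpha.sum())  # ~1.0: the weights form a distribution over source positions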
Below are the training and inference scripts with attention added. One thing to watch out for: when restoring the model you may hit a "kernel" NotFoundError (a key not found in the checkpoint). In that case the inference code needs an explicit tf.variable_scope, and the right scope name has to be read off the structure of the saved model. cat_net.py is a small helper script for inspecting the network structure stored in a checkpoint.
################
# This script checks the tensors stored in a ckpt.
# Works well with tensorflow version 'v1.3.0-rc2-20-g0787eee'.
################
import os
from tensorflow.python import pywrap_tensorflow

# Code for the latest ckpt:
# checkpoint_path = os.path.join('~/tensorflowTraining/ResNet/model', "model.ckpt")
# Code for a designated ckpt; change 9000 to your own step number.
checkpoint_path = "./seq2seq_attention_ckpt-9000"
# Read data from the checkpoint file.
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
# Print each tensor's name and values.
for key in var_to_shape_map:
    print("tensor_name: ", key)
    print(reader.get_tensor(key))
Run the following command:
python cat_net.py > net.log
Then open net.log; its contents look like this:
tensor_name: nmt_model/trg_emb
[[-0.0108806 0.06109285 0.00103101 ... 0.04327971 -0.0353515
0.05592098]
[ 0.05511344 0.10315175 -0.03260459 ... 0.01503692 -0.05173029
0.00153936]
[ 0.0991787 -0.06907252 0.21693856 ... 0.0805049 -0.01262149
-0.01293714]
...
[-0.12422938 0.01872133 -0.08084115 ... 0.03637449 -0.0386718
0.0702277 ]
[-0.09422345 -0.0029713 -0.00904827 ... -0.03110654 -0.00099467
-0.0079839 ]
[ 0.15127543 -0.10549527 0.00927421 ... 0.00116051 0.11979865
0.02227078]]
tensor_name: nmt_model/softmax_bias
[ 9.702319 -1.8773284 4.8174777 ... -0.48294842 -0.35829535
-0.4816328 ]
tensor_name: encoder/bidirectional_rnn/fw/basic_lstm_cell/kernel
[[-0.03092953 -0.03392044 -0.02407785 ... 0.02163492 -0.0049458
0.02973264]
[ 0.10684907 0.04035901 0.01169399 ... 0.02350369 0.02541667
-0.0220029 ]
[ 0.01336335 0.0200959 0.00845157 ... -0.01780637 0.01966156
0.00902852]
...
[-0.01916422 -0.0131671 0.0082262 ... -0.01099342 -0.00506847
0.0146405 ]
[ 0.0169474 0.02184602 -0.01979198 ... -0.00957554 -0.01252236
0.03171452]
[-0.03693858 0.01639441 -0.02785428 ... -0.02872299 0.01957132
-0.02001939]]
tensor_name: decoder/memory_layer/kernel
[[ 0.04521498 -0.00092734 0.00987301 ... -0.01601705 -0.01625223
-0.00826636]
[ 0.03350661 -0.01258853 0.03047631 ... -0.01902125 -0.01759247
0.01519862]
[-0.02057176 -0.01262629 -0.00525282 ... -0.03981094 0.03607614
0.00477269]
...
[-0.0367771 -0.02705046 0.01810684 ... -0.03925494 0.03783213
-0.01419215]
[-0.01111888 0.00990444 0.02161855 ... 0.0041062 0.02929579
-0.00364193]
[ 0.02131032 0.00671287 0.00193167 ... -0.02134871 -0.00051426
0.02360947]]
tensor_name: decoder/rnn/attention_wrapper/multi_rnn_cell/cell_1/basic_lstm_cell/bias
[-1.6907989 -0.80868345 -1.1245108 ... -0.7377462 -1.0939049
-1.2807418 ]
tensor_name: nmt_model/src_emb
[[ 0.06317458 -0.05404264 -0.00954251 ... -0.14450565 -0.11939629
-0.05514779]
[ 0.00680785 0.04471309 -0.0104601 ... -0.03551793 -0.04758103
0.01540864]
[ 0.32627714 0.0827579 -0.11642702 ... -0.03501745 -0.27873012
-0.04998838]
...
[-0.0220207 -0.03215215 -0.01608298 ... -0.03651857 -0.04046999
-0.02552509]
[ 0.00540233 0.03604389 0.06067114 ... 0.05810086 0.03965386
0.06954922]
[ 0.02887495 -0.02881782 0.05515011 ... 0.03075846 0.00961011
-0.02850782]]
tensor_name: decoder/rnn/attention_wrapper/attention_layer/kernel
[[-0.09316745 -0.07995477 -0.0146741 ... 0.0717198 0.02371014
-0.05503882]
[-0.00638354 -0.05642074 -0.12752905 ... 0.07572 0.02780477
0.02916634]
[-0.0532836 0.01808308 -0.01555931 ... -0.08836221 -0.05027555
0.01292556]
...
[-0.03378733 0.01676184 -0.01945874 ... 0.04151832 -0.04257954
-0.00394057]
[-0.04521075 0.02617629 -0.01065068 ... 0.06043241 0.02765347
-0.03455104]
[-0.02321909 -0.0051408 0.02175523 ... 0.00103944 0.03563083
0.04527191]]
tensor_name: decoder/rnn/attention_wrapper/multi_rnn_cell/cell_0/basic_lstm_cell/bias
[-0.47364303 -0.43505263 -0.2991495 ... -0.34608215 -0.3425427
-0.41822633]
tensor_name: decoder/rnn/attention_wrapper/multi_rnn_cell/cell_1/basic_lstm_cell/kernel
[[-0.06655399 -0.033209 0.00741314 ... -0.03744704 0.16143945
-0.04238527]
[-0.09054025 -0.05978451 -0.0919419 ... -0.05676661 -0.03161845
0.11375111]
[-0.01762006 -0.01342999 0.00538671 ... -0.07151254 0.00439914
0.0617904 ]
...
[ 0.01361352 -0.00989851 -0.01075909 ... 0.02791671 0.0204173
0.03272137]
[-0.02172133 0.01065003 0.02755076 ... 0.01163509 0.00617506
0.02474814]
[-0.02055892 -0.0032329 -0.01226626 ... -0.03111863 0.04921816
-0.01788351]]
tensor_name: encoder/bidirectional_rnn/fw/basic_lstm_cell/bias
[-1.0207075 -0.7382192 -0.75269985 ... -0.7253135 -0.83074564
-0.71001625]
tensor_name: decoder/rnn/attention_wrapper/bahdanau_attention/attention_v
[ 0.00795415 -0.00872286 -0.02835944 ... 0.02541727 -0.0316006
-0.01547218]
tensor_name: decoder/rnn/attention_wrapper/bahdanau_attention/query_layer/kernel
[[-0.02799078 0.00915903 -0.00178415 ... -0.01649223 -0.02163657
0.01371716]
[-0.0445041 0.00936891 0.02943462 ... -0.04068676 -0.00589912
-0.05063123]
[ 0.01968101 0.03777748 0.01904894 ... -0.04097166 0.05280968
0.04113906]
...
[ 0.01412237 0.02355416 0.03901715 ... 0.01330961 0.01638247
0.00222727]
[ 0.02915935 0.00618351 0.01156276 ... 0.04674264 0.04458835
0.01011846]
[-0.00728581 0.04162799 -0.01898116 ... -0.03135163 -0.04987657
0.03854783]]
tensor_name: decoder/rnn/attention_wrapper/multi_rnn_cell/cell_0/basic_lstm_cell/kernel
[[-0.0266049 0.06239759 0.03370405 ... 0.00847407 0.02729598
-0.02040454]
[-0.04149583 -0.03149587 -0.01089299 ... -0.03426768 0.0172292
-0.05368057]
[ 0.01183772 0.09243455 -0.02107698 ... -0.05690235 0.0284145
-0.0332344 ]
...
[-0.02697257 -0.06419387 -0.04755762 ... 0.09542636 -0.01003412
-0.04204182]
[-0.04266602 -0.045127 0.02201566 ... -0.08180676 -0.01398551
-0.00633448]
[ 0.01584598 0.01223975 0.03658367 ... 0.02622196 -0.00311522
-0.00781288]]
tensor_name: encoder/bidirectional_rnn/bw/basic_lstm_cell/bias
[-0.80520725 -0.88913244 -1.0078353 ... -0.7981011 -0.65148497
-0.9233699 ]
tensor_name: encoder/bidirectional_rnn/bw/basic_lstm_cell/kernel
[[-0.05317495 0.00402797 -0.04864402 ... 0.04332062 0.02639003
-0.00492012]
[ 0.03982998 0.00540096 0.09128776 ... -0.03405574 0.00860246
0.01108253]
[ 0.00027926 0.00077254 0.08196697 ... 0.03171543 0.03697995
0.00165045]
...
[ 0.01058249 0.00307607 0.00184137 ... 0.00661535 0.01547921
-0.02362307]
[ 0.00757162 0.0162105 -0.01197527 ... -0.0082445 -0.00365599
0.03383213]
[-0.02791 0.00413945 -0.06630697 ... -0.0176604 0.01094399
-0.03434239]]
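The full value dump above is more verbose than necessary; to find the scope names we only need the variable names and shapes. A lighter alternative is sketched below, assuming your TF 1.x build exposes tf.train.list_variables (older builds ship the same helper as tf.contrib.framework.list_variables):

import tensorflow as tf

# Print (name, shape) for every variable in the checkpoint,
# without loading the tensor values themselves.
checkpoint_path = "./seq2seq_attention_ckpt-9000"
for name, shape in tf.train.list_variables(checkpoint_path):
    print(name, shape)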
From the contents of net.log we can determine which scope names to use. We need this mainly because we never set explicit scopes when declaring these ops, and we did not know which default scopes TensorFlow would assign. Below is train_attention.py:
#coding:utf-8
import tensorflow as tf
MAX_LEN = 50
SOS_ID = 1
SRC_TRAIN_DATA = "../train.tags.en-zh.en.deletehtml.segment.id"
TRG_TRAIN_DATA = "../train.tags.en-zh.zh.deletehtml.segment.id"
CHECKPOINT_PATH = "./seq2seq_attention_ckpt"
HIDDEN_SIZE = 1024
NUM_LAYERS = 2
SRC_VOCAB_SIZE = 10000
TRG_VOCAB_SIZE = 4000
BATCH_SIZE = 100
NUM_EPOCH = 5
KEEP_PROB = 0.8
MAX_GRAD_NORM = 5
SHARE_EMB_AND_SOFTMAX = True

class NMTModel(object):
    def __init__(self):
        # Bidirectional LSTM encoder and multi-layer LSTM decoder.
        self.enc_cell_fw = tf.nn.rnn_cell.BasicLSTMCell(HIDDEN_SIZE)
        self.enc_cell_bw = tf.nn.rnn_cell.BasicLSTMCell(HIDDEN_SIZE)
        self.dec_cell = tf.nn.rnn_cell.MultiRNNCell(
            [tf.nn.rnn_cell.BasicLSTMCell(HIDDEN_SIZE) for _ in range(NUM_LAYERS)])
        self.src_embedding = tf.get_variable(
            "src_emb", [SRC_VOCAB_SIZE, HIDDEN_SIZE])
        self.trg_embedding = tf.get_variable(
            "trg_emb", [TRG_VOCAB_SIZE, HIDDEN_SIZE])
        if SHARE_EMB_AND_SOFTMAX:
            self.softmax_weight = tf.transpose(self.trg_embedding)
        else:
            self.softmax_weight = tf.get_variable(
                "weight", [HIDDEN_SIZE, TRG_VOCAB_SIZE])
        self.softmax_bias = tf.get_variable("softmax_bias", [TRG_VOCAB_SIZE])

    def forward(self, src_input, src_size, trg_input, trg_label, trg_size):
        batch_size = tf.shape(src_input)[0]
        src_emb = tf.nn.embedding_lookup(self.src_embedding, src_input)
        trg_emb = tf.nn.embedding_lookup(self.trg_embedding, trg_input)
        src_emb = tf.nn.dropout(src_emb, KEEP_PROB)
        trg_emb = tf.nn.dropout(trg_emb, KEEP_PROB)
        with tf.variable_scope("encoder"):
            enc_outputs, enc_state = tf.nn.bidirectional_dynamic_rnn(
                self.enc_cell_fw, self.enc_cell_bw, src_emb, src_size,
                dtype=tf.float32)
            # Concatenate forward and backward outputs as the attention memory.
            enc_outputs = tf.concat([enc_outputs[0], enc_outputs[1]], -1)
        with tf.variable_scope("decoder"):
            self.attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(
                HIDDEN_SIZE, enc_outputs, memory_sequence_length=src_size)
            self.attention_cell = tf.contrib.seq2seq.AttentionWrapper(
                self.dec_cell, self.attention_mechanism,
                attention_layer_size=HIDDEN_SIZE)
            dec_outputs, _ = tf.nn.dynamic_rnn(
                self.attention_cell, trg_emb, trg_size, dtype=tf.float32)
        output = tf.reshape(dec_outputs, [-1, HIDDEN_SIZE])
        logits = tf.matmul(output, self.softmax_weight) + self.softmax_bias
        loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=tf.reshape(trg_label, [-1]), logits=logits)
        # Zero out the loss on padded positions.
        label_weights = tf.sequence_mask(
            trg_size, maxlen=tf.shape(trg_label)[1], dtype=tf.float32)
        label_weights = tf.reshape(label_weights, [-1])
        cost = tf.reduce_sum(loss * label_weights)
        cost_per_token = cost / tf.reduce_sum(label_weights)
        trainable_variables = tf.trainable_variables()
        grads = tf.gradients(cost / tf.to_float(batch_size), trainable_variables)
        grads, _ = tf.clip_by_global_norm(grads, MAX_GRAD_NORM)
        optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
        train_op = optimizer.apply_gradients(zip(grads, trainable_variables))
        return cost_per_token, train_op

def run_epoch(session, cost_op, train_op, saver, step):
    while True:
        try:
            cost, _ = session.run([cost_op, train_op])
            if step % 10 == 0:
                print("steps %d, per token cost is %.3f" % (step, cost))
            if step % 200 == 0:
                saver.save(session, CHECKPOINT_PATH, global_step=step)
            step += 1
        except tf.errors.OutOfRangeError:
            break
    return step

def MakeDataset(file_path):
    dataset = tf.data.TextLineDataset(file_path)
    dataset = dataset.map(lambda string: tf.string_split([string]).values)
    dataset = dataset.map(lambda string: tf.string_to_number(string, tf.int32))
    dataset = dataset.map(lambda x: (x, tf.size(x)))
    return dataset

def MakeSrcTrgDataset(src_path, trg_path, batch_size):
    src_data = MakeDataset(src_path)
    trg_data = MakeDataset(trg_path)
    dataset = tf.data.Dataset.zip((src_data, trg_data))

    def FilterLength(src_tuple, trg_tuple):
        ((src_input, src_len), (trg_label, trg_len)) = (src_tuple, trg_tuple)
        src_len_ok = tf.logical_and(
            tf.greater(src_len, 1), tf.less_equal(src_len, MAX_LEN))
        trg_len_ok = tf.logical_and(
            tf.greater(trg_len, 1), tf.less_equal(trg_len, MAX_LEN))
        return tf.logical_and(src_len_ok, trg_len_ok)
    dataset = dataset.filter(FilterLength)

    def MakeTrgInput(src_tuple, trg_tuple):
        ((src_input, src_len), (trg_label, trg_len)) = (src_tuple, trg_tuple)
        trg_input = tf.concat([[SOS_ID], trg_label[:-1]], axis=0)
        return ((src_input, src_len), (trg_input, trg_label, trg_len))
    dataset = dataset.map(MakeTrgInput)

    dataset = dataset.shuffle(10000)
    padded_shapes = (
        (tf.TensorShape([None]),
         tf.TensorShape([])),
        (tf.TensorShape([None]),
         tf.TensorShape([None]),
         tf.TensorShape([])))
    batched_dataset = dataset.padded_batch(batch_size, padded_shapes)
    return batched_dataset

def main():
    initializer = tf.random_uniform_initializer(-0.05, 0.05)
    with tf.variable_scope("nmt_model", reuse=None, initializer=initializer):
        train_model = NMTModel()
    # forward() is called outside the "nmt_model" scope, so the encoder/decoder
    # variables get top-level names "encoder/..." and "decoder/...", exactly as
    # the checkpoint listing above shows.
    data = MakeSrcTrgDataset(SRC_TRAIN_DATA, TRG_TRAIN_DATA, BATCH_SIZE)
    iterator = data.make_initializable_iterator()
    (src, src_size), (trg_input, trg_label, trg_size) = iterator.get_next()
    cost_op, train_op = train_model.forward(src, src_size, trg_input, trg_label, trg_size)
    saver = tf.train.Saver()
    step = 0
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.7, allow_growth=True)
    session = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
    with session as sess:
        tf.global_variables_initializer().run()
        for i in range(NUM_EPOCH):
            print("In iteration: %d" % (i + 1))
            sess.run(iterator.initializer)
            step = run_epoch(sess, cost_op, train_op, saver, step)

if __name__ == '__main__':
    main()
Below is eval_attention.py, which the original book does not provide. Pay particular attention to the extra variable scope: during training tf.nn.dynamic_rnn creates the AttentionWrapper's variables under decoder/rnn/attention_wrapper/..., but the inference code drives the cell step by step inside a tf.while_loop instead of through dynamic_rnn, so without an explicit tf.variable_scope("decoder/rnn/attention_wrapper") the variable names would not match the checkpoint and restoring would fail with the NotFoundError mentioned above.
#coding:utf-8
import tensorflow as tf
CHECKPOINT_PATH = "./seq2seq_attention_ckpt-9800"
HIDDEN_SIZE = 1024
NUM_LAYERS = 2
SRC_VOCAB_SIZE = 10000
TRG_VOCAB_SIZE = 4000
BATCH_SIZE = 100
SHARE_EMB_AND_SOFTMAX = True
SOS_ID = 1
EOS_ID = 2

class NMTModel(object):
    def __init__(self):
        self.enc_cell_fw = tf.nn.rnn_cell.BasicLSTMCell(HIDDEN_SIZE)
        self.enc_cell_bw = tf.nn.rnn_cell.BasicLSTMCell(HIDDEN_SIZE)
        self.dec_cell = tf.nn.rnn_cell.MultiRNNCell(
            [tf.nn.rnn_cell.BasicLSTMCell(HIDDEN_SIZE) for _ in range(NUM_LAYERS)])
        self.src_embedding = tf.get_variable(
            "src_emb", [SRC_VOCAB_SIZE, HIDDEN_SIZE])
        self.trg_embedding = tf.get_variable(
            "trg_emb", [TRG_VOCAB_SIZE, HIDDEN_SIZE])
        if SHARE_EMB_AND_SOFTMAX:
            self.softmax_weight = tf.transpose(self.trg_embedding)
        else:
            self.softmax_weight = tf.get_variable(
                "weight", [HIDDEN_SIZE, TRG_VOCAB_SIZE])
        self.softmax_bias = tf.get_variable("softmax_bias", [TRG_VOCAB_SIZE])

    def inference(self, src_input):
        src_size = tf.convert_to_tensor([len(src_input)], dtype=tf.int32)
        src_input = tf.convert_to_tensor([src_input], dtype=tf.int32)
        src_emb = tf.nn.embedding_lookup(self.src_embedding, src_input)
        with tf.variable_scope("encoder"):
            enc_outputs, enc_state = tf.nn.bidirectional_dynamic_rnn(
                self.enc_cell_fw, self.enc_cell_bw, src_emb, src_size,
                dtype=tf.float32)
            enc_outputs = tf.concat([enc_outputs[0], enc_outputs[1]], -1)
        MAX_DEC_LEN = 100
        init_array = tf.TensorArray(dtype=tf.int32, size=0,
                                    dynamic_size=True, clear_after_read=False)
        init_array = init_array.write(0, SOS_ID)
        with tf.variable_scope("decoder"):
            # The memory_layer is created here, matching "decoder/memory_layer/kernel".
            self.attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(
                HIDDEN_SIZE, enc_outputs, memory_sequence_length=src_size)
        # Pin the scope so the wrapper's variables get the same names that
        # tf.nn.dynamic_rnn produced during training ("decoder/rnn/attention_wrapper/...").
        with tf.variable_scope("decoder/rnn/attention_wrapper"):
            self.attention_cell = tf.contrib.seq2seq.AttentionWrapper(
                self.dec_cell, self.attention_mechanism,
                attention_layer_size=HIDDEN_SIZE)
            state = self.attention_cell.zero_state(batch_size=1, dtype=tf.float32)
            init_loop_var = (state, init_array, 0)

            def continue_loop_condition(state, trg_ids, step):
                return tf.reduce_all(tf.logical_and(
                    tf.not_equal(trg_ids.read(step), EOS_ID),
                    tf.less(step, MAX_DEC_LEN - 1)))

            def loop_body(state, trg_ids, step):
                trg_input = [trg_ids.read(step)]
                trg_emb = tf.nn.embedding_lookup(self.trg_embedding, trg_input)
                dec_outputs, next_state = self.attention_cell.call(
                    state=state, inputs=trg_emb)
                output = tf.reshape(dec_outputs, [-1, HIDDEN_SIZE])
                logits = tf.matmul(output, self.softmax_weight) + self.softmax_bias
                next_id = tf.argmax(logits, axis=1, output_type=tf.int32)
                trg_ids = trg_ids.write(step + 1, next_id[0])
                return next_state, trg_ids, step + 1

            state, trg_ids, step = tf.while_loop(
                continue_loop_condition, loop_body, init_loop_var)
            return trg_ids.stack()

def main():
    from stanfordcorenlp import StanfordCoreNLP
    nlp = StanfordCoreNLP("../../stanford-corenlp-full-2018-10-05", lang='en')
    with tf.variable_scope("nmt_model", reuse=None):
        model = NMTModel()
    vocab_file = "../train.tags.en-zh.en.deletehtml.vocab"
    sentence = "It doesn't belong to mine!"
    with open(vocab_file, 'r') as f:
        data = f.readlines()
        words = [w.strip() for w in data]
    word_to_id = {k: v for (k, v) in zip(words, range(len(words)))}
    # Tokenize, append the end-of-sentence token, and map unknown words to <unk>.
    wordlist = nlp.word_tokenize(sentence.strip()) + ["<eos>"]
    # print(wordlist)
    idlist = [str(word_to_id[w]) if w in word_to_id else str(word_to_id["<unk>"])
              for w in wordlist]
    idlist = [int(i) for i in idlist]
    # print(idlist)
    output_op = model.inference(idlist)
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.7, allow_growth=True)
    session = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
    saver = tf.train.Saver()
    saver.restore(session, CHECKPOINT_PATH)
    output = session.run(output_op)
    vocab_file2 = "../train.tags.en-zh.zh.deletehtml.vocab"
    with open(vocab_file2, 'r') as f2:
        data2 = f2.readlines()
        words = [w.strip() for w in data2]
    id_to_word = {k: v for (k, v) in zip(range(len(words)), words)}
    print([id_to_word[i] for i in output])
    session.close()
    nlp.close()

if __name__ == '__main__':
    main()
The output is:
['<sos>', '这', '不', '是', '我', '的', '<unk>', '<eos>']