This post contains my reading notes on "Attention-over-Attention Neural Networks for Reading Comprehension". The paper addresses the cloze-style question answering task in reading comprehension. Its architecture builds on "Text Understanding with the Attention Sum Reader Network", which first proposed applying attention to the cloze task; this paper adds an additional attention layer on top of it, which removes the need for heuristic merging functions and some hyper-parameter tuning. The two papers are introduced together below.
First, the datasets. The commonly used large-scale datasets are CNN/Daily Mail and the Children's Book Test (CBTest). The first two are news datasets: an entire news article is used as the cloze text (Document), a sentence from its summary with one word removed is used as the query (Query), and the removed word is the answer (Answer). Named entities in the Document are replaced with identifiers such as @entity1, @entity2, and so on. In the raw data files, the first line is the article URL (not used), the third line is the Document, the fifth line is the Query, and the seventh line is the Answer, with all named entities replaced by these identifiers.
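Purely as an illustration of that raw file layout, a made-up sample might look like the following (the URL, entity ids, and text are invented, not taken from the real dataset):

http://example.com/news/2015/some-article

@entity1 met @entity2 in @entity3 on monday to discuss the agreement ...

@entity1 will travel to @placeholder next week for further talks

@entity3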
The CBT dataset is built from children's books. Since these have no summaries, each sample is constructed by taking 21 consecutive sentences as the Document and the 22nd sentence as the Query. The data is further split into four subsets according to the part of speech of the answer: named entities (NE), common nouns (CN), verbs, and prepositions. Because the last two answer types are not closely tied to the text (people can often fill in a preposition correctly without reading the passage), only the first two subsets are commonly used.
In the end, each sample is organized as the following triple:
<D, Q, A>
Let us first look at the model architecture proposed in "Text Understanding with the Attention Sum Reader Network", shown in the figure below:
As the figure shows, the model first uses an embedding matrix V to look up a word vector e(w) for every word in the Document and the Query. Two encoder networks, implemented as bidirectional GRUs, then produce a contextual embedding for each word in the Document and a single representation vector for the whole Query. The Query vector is multiplied (dot product) with each word's contextual embedding, and the result can be viewed as that word's weight with respect to the query, i.e. its attention. Finally, a softmax turns these weights into normalized probabilities, and the word with the highest probability is taken as the answer to the query.
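To make this concrete, here is a minimal NumPy sketch of the attention step just described, assuming the encoders have already produced a contextual embedding matrix h_doc (one row per document word) and a single query vector q_vec; the function name and inputs are hypothetical, not from the paper's code. It also includes the "attention sum" over repeated occurrences of the same word that gives the AS Reader its name (the tf.unsorted_segment_sum call in the code later in this post plays the same role):

import numpy as np

def attention_sum(h_doc, q_vec, doc_ids, vocab_size):
    # h_doc: (doc_len, 2h) contextual embeddings, q_vec: (2h,) query vector,
    # doc_ids: (doc_len,) word id at each document position.
    scores = h_doc @ q_vec            # dot product: one score per document position
    att = np.exp(scores - scores.max())
    att /= att.sum()                  # softmax over document positions
    probs = np.zeros(vocab_size)
    np.add.at(probs, doc_ids, att)    # sum the attention of repeated words ("attention sum")
    return int(probs.argmax())        # id of the predicted answer word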
Next, let us look at the model architecture proposed in this paper, shown in the figure below:
The first half of the model is exactly the same as above. The difference is the "attention over attention" mechanism proposed here: after obtaining the Document and Query representations, the Query words are not merged into a single vector. Instead, the Query matrix is multiplied directly with the Document matrix, and the resulting matrix is softmax-normalized along its two dimensions separately, giving a document-level attention matrix (column-wise) and a query-level attention matrix (row-wise). The elements in each column of the query-level attention matrix are then summed and divided by the document length, i.e. averaged, to serve as weights, and the dot product of the document-level attention matrix with this weight vector gives the final attention over the document.
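As a rough sketch of these two softmax operations and the final weighting (hypothetical NumPy code; it mirrors the M / alpha / beta / query_importance computation in the TensorFlow model below, without the masking):

import numpy as np

def softmax(x, axis):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def aoa_attention(h_doc, h_query):
    # h_doc: (doc_len, 2h), h_query: (query_len, 2h)
    M = h_doc @ h_query.T                 # pairwise matching matrix, (doc_len, query_len)
    alpha = softmax(M, axis=0)            # column-wise softmax: document-level attention
    beta = softmax(M, axis=1)             # row-wise softmax: query-level attention
    query_importance = beta.mean(axis=0)  # average query-level attention over document positions
    return alpha @ query_importance       # one final attention weight per document position

The resulting per-position weights are then fed into the same attention-sum step as in the AS Reader to score the candidate answer words.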
The model itself is actually quite simple to implement in TensorFlow: you can just use the GRUCell from tf.contrib.rnn. The hard part is processing and reading the data. Two implementations on GitHub are worth looking at: OlavHN and marshmelloX. The first uses TensorFlow's built-in data-reading API and the code is very concise; when I have time I want to study how it works and write it up as a separate post. The second uses a more traditional data-processing approach and is also worth a look; you should also be able to find preprocessing code for the CNN and other datasets on GitHub to study alongside it. Both implementations, however, target older TensorFlow versions, so with TF 1.0 or later you may run into some incompatible functions. I made some modifications based on the first implementation so that it runs on 1.0; I will put the code on my GitHub later, feel free to take a look. Training takes around four or five days on the server and has not finished yet. The figure below is a screenshot of the results so far:
The four values printed on each line are the step, the training loss, the accuracy, and the elapsed time. The accuracy is not very stable, but it roughly matches the results reported in the paper. Here is my modified model code; the model-construction part in particular is quite simple and takes only a few lines:
import os
import time
import random
import numpy as np
import tensorflow as tf
from tensorflow.python.ops import sparse_ops
from util import softmax, orthogonal_initializer
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_integer('vocab_size', 119662, 'Vocabulary size')
flags.DEFINE_integer('embedding_size', 384, 'Embedding dimension')
flags.DEFINE_integer('hidden_size', 256, 'Hidden units')
flags.DEFINE_integer('batch_size', 32, 'Batch size')
flags.DEFINE_integer('epochs', 2, 'Number of epochs to train/test')
flags.DEFINE_boolean('training', True, 'Training or testing a model')
flags.DEFINE_string('name', 'lc_model', 'Model name (used for statistics and model path)')
flags.DEFINE_float('dropout_keep_prob', 0.9, 'Keep prob for embedding dropout')
flags.DEFINE_float('l2_reg', 0.0001, 'l2 regularization for embeddings')
model_path = 'models/' + FLAGS.name
if not os.path.exists(model_path):
    os.makedirs(model_path)
def read_records(index=0):
    # Build a shuffled-batch input pipeline from the TFRecord files; `index`
    # selects between the training, validation and test queues.
    train_queue = tf.train.string_input_producer(['training.tfrecords'], num_epochs=FLAGS.epochs)
    validation_queue = tf.train.string_input_producer(['validation.tfrecords'], num_epochs=FLAGS.epochs)
    test_queue = tf.train.string_input_producer(['test.tfrecords'], num_epochs=FLAGS.epochs)
    queue = tf.QueueBase.from_list(index, [train_queue, validation_queue, test_queue])

    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'document': tf.VarLenFeature(tf.int64),
            'query': tf.VarLenFeature(tf.int64),
            'answer': tf.FixedLenFeature([], tf.int64)
        })

    document = sparse_ops.serialize_sparse(features['document'])
    query = sparse_ops.serialize_sparse(features['query'])
    answer = features['answer']

    document_batch_serialized, query_batch_serialized, answer_batch = tf.train.shuffle_batch(
        [document, query, answer], batch_size=FLAGS.batch_size,
        capacity=2000,
        min_after_dequeue=1000)

    sparse_document_batch = sparse_ops.deserialize_many_sparse(document_batch_serialized, dtype=tf.int64)
    sparse_query_batch = sparse_ops.deserialize_many_sparse(query_batch_serialized, dtype=tf.int64)

    document_batch = tf.sparse_tensor_to_dense(sparse_document_batch)
    document_weights = tf.sparse_to_dense(sparse_document_batch.indices, sparse_document_batch.dense_shape, 1)

    query_batch = tf.sparse_tensor_to_dense(sparse_query_batch)
    query_weights = tf.sparse_to_dense(sparse_query_batch.indices, sparse_query_batch.dense_shape, 1)

    return document_batch, document_weights, query_batch, query_weights, answer_batch
def inference(documents, doc_mask, query, query_mask):
    embedding = tf.get_variable('embedding',
                                [FLAGS.vocab_size, FLAGS.embedding_size],
                                initializer=tf.random_uniform_initializer(minval=-0.05, maxval=0.05))
    regularizer = tf.nn.l2_loss(embedding)

    doc_emb = tf.nn.dropout(tf.nn.embedding_lookup(embedding, documents), FLAGS.dropout_keep_prob)
    doc_emb.set_shape([None, None, FLAGS.embedding_size])

    query_emb = tf.nn.dropout(tf.nn.embedding_lookup(embedding, query), FLAGS.dropout_keep_prob)
    query_emb.set_shape([None, None, FLAGS.embedding_size])

    with tf.variable_scope('document', initializer=orthogonal_initializer()):
        fwd_cell = tf.contrib.rnn.GRUCell(FLAGS.hidden_size)
        back_cell = tf.contrib.rnn.GRUCell(FLAGS.hidden_size)

        doc_len = tf.reduce_sum(doc_mask, reduction_indices=1)
        h, _ = tf.nn.bidirectional_dynamic_rnn(
            fwd_cell, back_cell, doc_emb, sequence_length=tf.to_int64(doc_len), dtype=tf.float32)
        # h_doc = tf.nn.dropout(tf.concat(2, h), FLAGS.dropout_keep_prob)
        h_doc = tf.concat(h, 2)

    with tf.variable_scope('query', initializer=orthogonal_initializer()):
        fwd_cell = tf.contrib.rnn.GRUCell(FLAGS.hidden_size)
        back_cell = tf.contrib.rnn.GRUCell(FLAGS.hidden_size)

        query_len = tf.reduce_sum(query_mask, reduction_indices=1)
        h, _ = tf.nn.bidirectional_dynamic_rnn(
            fwd_cell, back_cell, query_emb, sequence_length=tf.to_int64(query_len), dtype=tf.float32)
        # h_query = tf.nn.dropout(tf.concat(2, h), FLAGS.dropout_keep_prob)
        h_query = tf.concat(h, 2)

    # Pairwise matching matrix M, shape [batch, doc_len, query_len]
    M = tf.matmul(h_doc, h_query, adjoint_b=True)
    M_mask = tf.to_float(tf.matmul(tf.expand_dims(doc_mask, -1), tf.expand_dims(query_mask, 1)))

    alpha = softmax(M, 1, M_mask)  # column-wise softmax: document-level attention
    beta = softmax(M, 2, M_mask)   # row-wise softmax: query-level attention

    # query_importance = tf.expand_dims(tf.reduce_mean(beta, reduction_indices=1), -1)
    # Average the query-level attention over document positions
    query_importance = tf.expand_dims(tf.reduce_sum(beta, 1) / tf.to_float(tf.expand_dims(doc_len, -1)), -1)

    # Attention over attention: weight the document-level attention by the query importance
    s = tf.squeeze(tf.matmul(alpha, query_importance), [2])

    # Attention sum: accumulate the attention of repeated word ids over the vocabulary
    unpacked_s = zip(tf.unstack(s, FLAGS.batch_size), tf.unstack(documents, FLAGS.batch_size))
    y_hat = tf.stack([tf.unsorted_segment_sum(attentions, sentence_ids, FLAGS.vocab_size) for (attentions, sentence_ids) in unpacked_s])

    return y_hat, regularizer
def train(y_hat, regularizer, document, doc_weight, answer):
    # Trick while we wait for tf.gather_nd - https://github.com/tensorflow/tensorflow/issues/206
    # This unfortunately causes us to expand a sparse tensor into the full vocabulary
    index = tf.range(0, FLAGS.batch_size) * FLAGS.vocab_size + tf.to_int32(answer)
    flat = tf.reshape(y_hat, [-1])
    relevant = tf.gather(flat, index)

    # mean cause reg is independent of batch size
    loss = -tf.reduce_mean(tf.log(relevant)) + FLAGS.l2_reg * regularizer

    global_step = tf.Variable(0, name="global_step", trainable=False)

    accuracy = tf.reduce_mean(tf.to_float(tf.equal(tf.argmax(y_hat, 1), answer)))

    optimizer = tf.train.AdamOptimizer()
    grads_and_vars = optimizer.compute_gradients(loss)
    capped_grads_and_vars = [(tf.clip_by_value(grad, -5, 5), var) for (grad, var) in grads_and_vars]
    train_op = optimizer.apply_gradients(capped_grads_and_vars, global_step=global_step)

    tf.summary.scalar('loss', loss)
    tf.summary.scalar('accuracy', accuracy)
    return loss, train_op, global_step, accuracy
def main():
    dataset = tf.placeholder_with_default(0, [])
    document_batch, document_weights, query_batch, query_weights, answer_batch = read_records(dataset)

    y_hat, reg = inference(document_batch, document_weights, query_batch, query_weights)
    loss, train_op, global_step, accuracy = train(y_hat, reg, document_batch, document_weights, answer_batch)
    summary_op = tf.summary.merge_all()

    with tf.Session() as sess:
        summary_writer = tf.summary.FileWriter(model_path, sess.graph)

        saver_variables = tf.all_variables()
        if not FLAGS.training:
            saver_variables = filter(lambda var: var.name != 'input_producer/limit_epochs/epochs:0', saver_variables)
            saver_variables = filter(lambda var: var.name != 'smooth_acc:0', saver_variables)
            saver_variables = filter(lambda var: var.name != 'avg_acc:0', saver_variables)
        saver = tf.train.Saver(saver_variables)

        sess.run([
            tf.initialize_all_variables(),
            tf.initialize_local_variables()])
        model = tf.train.latest_checkpoint(model_path)
        if model:
            print('Restoring ' + model)
            saver.restore(sess, model)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        start_time = time.time()
        accumulated_accuracy = 0
        try:
            if FLAGS.training:
                while not coord.should_stop():
                    loss_t, _, step, acc = sess.run([loss, train_op, global_step, accuracy], feed_dict={dataset: 0})
                    elapsed_time, start_time = time.time() - start_time, time.time()
                    print(step, loss_t, acc, elapsed_time)
                    if step % 100 == 0:
                        summary_str = sess.run(summary_op)
                        summary_writer.add_summary(summary_str, step)
                    if step % 1000 == 0:
                        saver.save(sess, model_path + '/aoa', global_step=step)
            else:
                step = 0
                while not coord.should_stop():
                    acc = sess.run(accuracy, feed_dict={dataset: 2})
                    step += 1
                    accumulated_accuracy += (acc - accumulated_accuracy) / step
                    elapsed_time, start_time = time.time() - start_time, time.time()
                    print(accumulated_accuracy, acc, elapsed_time)
        except tf.errors.OutOfRangeError:
            print('Done!')
        finally:
            coord.request_stop()

        coord.join(threads)
'''
import pickle
with open('counter.pickle', 'r') as f:
counter = pickle.load(f)
word, _ = zip(*counter.most_common())
'''
if __name__ == "__main__":
    main()