菜鸟笔记-DuReader阅读理解基线模型代码阅读笔记(十)—— 模型前向计算数据流动

系列目录:

  1. 菜鸟笔记-DuReader阅读理解基线模型代码阅读笔记(一)——数据
  2. 菜鸟笔记-DuReader阅读理解基线模型代码阅读笔记(二)——
    介绍及分词
  3. 菜鸟笔记-DuReader阅读理解基线模型代码阅读笔记(三)—— 预处理
  4. 菜鸟笔记-DuReader阅读理解基线模型代码阅读笔记(四)—— 段落抽取
  5. 菜鸟笔记-DuReader阅读理解基线模型代码阅读笔记(五)—— 准备数据
  6. 菜鸟笔记-DuReader阅读理解基线模型代码阅读笔记(六)—— 模型构建
  7. 菜鸟笔记-DuReader阅读理解基线模型代码阅读笔记(七)—— 模型训练-数据准备
  8. 菜鸟笔记-DuReader阅读理解基线模型代码阅读笔记(八)—— 模型训练-训练
  9. 菜鸟笔记-DuReader阅读理解基线模型代码阅读笔记(九)—— 预测与校验

到上一篇文章就完成了DuReader的基线模型代码的简要介绍,本文主要是展示一下数据在模型中前向计算流动时数据维度的变化,希望帮助大家对模型有个更深入的理解。

变量准备

import sys
import pickle
from run import *
import logging
import tensorflow as tf

# 准备参数
sys.argv = []
args = parse_args()

# 设定日志输出
logger = logging.getLogger("brc")
logger.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
if args.log_path:
    file_handler = logging.FileHandler(args.log_path)
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
else:
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

logger.info('Running with args : {}'.format(args))

#设定运行环境变量
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
WARNING:tensorflow:
The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.

2020-03-28 23:31:05,607 - brc - INFO - Running with args : Namespace(algo='BIDAF', batch_size=32, brc_dir='../data/baidu', dev_files=['../data/demo/devset/search.dev.json'], dropout_keep_prob=1, embed_size=300, epochs=10, evaluate=False, gpu='0', hidden_size=150, learning_rate=0.001, log_path=None, max_a_len=200, max_p_len=500, max_p_num=5, max_q_len=60, model_dir='../data/models/', optim='adam', predict=False, prepare=False, result_dir='../data/results/', summary_dir='../data/summary/', test_files=['../data/demo/testset/search.test.json'], train=False, train_files=['../data/demo/trainset/search.train.json'], vocab_dir='../data/vocab/', weight_decay=0)

准备一个batch的数据

为模型前向计算准备一个批次的数据,这里我们将数据集的batch_size改为4方便计算及数据展示。

# 加载词典
with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin:
    vocab = pickle.load(fin)

# 加载数据集
brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len,
                      args.train_files, args.dev_files)
brc_data.convert_to_ids(vocab)

pad_id = vocab.get_id(vocab.pad_token)
# 生成数据集批次生成器,batch_size设置为4
batch_size = 4
train_batches = brc_data.gen_mini_batches('train', batch_size, pad_id, shuffle=False)
# 从批次生成器中取一个批次的数据
for batch in train_batches:
    break
2020-03-28 23:31:05,680 - brc - INFO - Train set size: 95 questions.
2020-03-28 23:31:05,797 - brc - INFO - Dev set size: 100 questions.

构建模型、创建feed_dict

# 构建模型
rc_model = RCModel(vocab, args)

total_num, total_loss = 0, 0
# 创建feed_dict作为模型的输入
feed_dict = {rc_model.p: batch['passage_token_ids'],
             rc_model.q: batch['question_token_ids'],
             rc_model.p_length: batch['passage_length'],
             rc_model.q_length: batch['question_length'],
             rc_model.start_label: batch['start_id'],
             rc_model.end_label: batch['end_id'],
             rc_model.dropout_keep_prob: 1.0}
2020-03-28 23:31:09,393 - brc - INFO - Time to build graph: 3.3495678901672363 s
2020-03-28 23:31:17,457 - brc - INFO - There are 4995603 parameters in the model

emb层输出

p_emb,q_emb = rc_model.sess.run([rc_model.p_emb, rc_model.q_emb], feed_dict)
print(p_emb.shape)
print(q_emb.shape)
(20, 443, 300)
(20, 5, 300)

由输出可以看到emb层的输出的维度为(批次大小序列长度嵌入层维度)即(batch_sizeseq_lengthembed_dim)为(20, 443/5, 300)。示例中,我们设定一个batch大小为4个样本,由于每个样本由五个文档组成,因此被拆分为五列,所以这里batch_size变成了20。序列长度为这一批次完成填充后每个样本的长度。

编码层输出

sep_p_encodes,sep_q_encodes = rc_model.sess.run([rc_model.sep_p_encodes, rc_model.sep_q_encodes], feed_dict)
print(sep_p_encodes.shape)
print(sep_q_encodes.shape)
(20, 443, 300)
(20, 5, 300)

由输出可以看到,经过编码层后数据维度没有变,其中最后一个维度还是300是因为模型设定的hidden_size为150,所以其隐藏状态输出维度为150,而双向lstm的输出拼接后,变成300。

match层输出

match_p_encodes = rc_model.sess.run(rc_model.match_p_encodes, feed_dict)
match_p_encodes.shape
(20, 443, 1200)

match层是bidaf算法的核心层,在该层的最后输出是由[passage_encodes, context2question_attn,passage_encodes * context2question_attn,passage_encodes * question2context_attn]四个矩阵拼接而成,因此该层的模型输出特征维度由输入的300变为1200。该层的详细介绍在模型构建部分,感兴趣的可以去了解一下。

fuse层输出

fuse_p_encodes = rc_model.sess.run(rc_model.fuse_p_encodes, feed_dict)
fuse_p_encodes.shape
(20, 443, 300)

fuse层通过调用rnn函数使用双向LSTM对包含了问题-文档融合信息的特征编码进行了进一步的融合,由于参数设置lstm层的hidden_size为150,所以双向LSTM层的输出特征维度为300。

decode层输出

decode中首先将相同样本的5个段落进行了拼接,代码中拼接后的变量为concat_passage_encodes,首先来查看一下这个变量。

batch_size = tf.shape(rc_model.start_label)[0]
concat_passage_encodes = tf.reshape(
    rc_model.fuse_p_encodes,
    [batch_size, -1, 2 * rc_model.hidden_size]
)
concat_passage_encodes = rc_model.sess.run(concat_passage_encodes, feed_dict)
concat_passage_encodes.shape
(4, 2215, 300)

由输出可以看到,concat_passage_encodesbatch_size为4,与样本数量相同。序列长度为2215,是一个样本所有段落长度的和。特征值维度为300。然后调用PointerNetDecoder对其进行处理,生成起始索引、终止索引概率分布。

由于其输入还需要问题解码数据,首先看下对其进行的处理。

sep_q_encodes.shape
(20, 5, 300)
# 去除问题编码的重复项
no_dup_question_encodes = tf.reshape(
    rc_model.sep_q_encodes,
    [batch_size, -1, tf.shape(rc_model.sep_q_encodes)[1], 2 * rc_model.hidden_size]
)[0:, 0, 0:, 0:]
# 输出结果
no_dup_question_encodes = rc_model.sess.run(no_dup_question_encodes, feed_dict)
no_dup_question_encodes.shape
(4, 5, 300)

可以看到,代码将重复的问题编码去掉,只保留一行,使问题编码的batch_size变为与问题编码一致。然后就可以用来生成起始索引、终止索引概率分布了。

start_probs,end_probs = rc_model.sess.run([rc_model.start_probs, rc_model.end_probs], feed_dict)
print(start_probs.shape)
print(end_probs.shape)
(4, 2215)
(4, 2215)

由输出可以看到,一个样本对应的模型的最终输出是两个长度为2215的向量,这两个向量长度与样本填充后所有段落长度的和一致,每个位置的值分布对应该位置为起始索引或终止索引的概率。

参考文献:

  • DuReader数据集

  • DuReader Baseline Systems (基线系统)

  • BiDAF

  • Match-LSTM

  • Match-LSTM & BiDAF

你可能感兴趣的:(NLP,#,机器阅读理解,自然语言处理,tensorflow,深度学习,神经网络)