系列目录:
未完待续 … …
上一篇文章对模型的结构进行了介绍,本文开始介绍训练中的数据准备,数据经过预处理后,到真正输入模型进行训练还需要进一步的处理。
首先来看一下训练的主函数,主函数train如下:
def train(args):
"""
训练阅读理解模型
"""
logger = logging.getLogger("brc")
logger.info('Load data_set and vocab...')
# 加载字典
with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin:
vocab = pickle.load(fin)
# 加载数据
brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len,
args.train_files, args.dev_files)
logger.info('Converting text into ids...')
# 将数据转换为数字索引ids
brc_data.convert_to_ids(vocab)
logger.info('Initialize the model...')
# 初始化模型
rc_model = RCModel(vocab, args)
logger.info('Training the model...')
# 训练模型
rc_model.train(brc_data, args.epochs, args.batch_size, save_dir=args.model_dir,
save_prefix=args.algo,
dropout_keep_prob=args.dropout_keep_prob)
logger.info('Done with model training!')
有代码可以看到,训练主函数包含了加载词典、加载数据、将数据转换为索引、构建模型、训练模型几部分,本文重点介绍下其中加载数据部分。
BRCDataset函数在准备数据部分简单介绍过,回顾一下:
类名 BRCDataset
功能:实现加载使用百度阅读理解数据集的APIs
类属性:
self.max_p_num:最大段落数量
self.max_p_len:最大段落长度
self.max_q_len:最大问题长度
self.train_set, self.dev_set, self.test_set:训练、验证、测试数据集
类主要方法:
_load_dataset():加载数据,数据集初始化时会自动调用这个函数加载数据
_one_mini_batch:生成一个batch的数据
_dynamic_padding:动态填充
word_iter:遍历数据集中所有单词
convert_to_ids:将数据集中的文本(问题、文档)转化为ids
gen_mini_batches:为特定数据集生成batch数据
下面简单介绍其中数据处理的关键函数,其余的大家可以自行阅读源代码。
_load_dataset函数是在BRCDataset类初始化时自动运行,加载训练、验证、测试数据集数据,其代码如下:
def _load_dataset(self, data_path, train=False):
"""
加载数据集
Args:
data_path: 需要加载的数据集的路径
"""
with open(data_path) as fin:
data_set = []
for lidx, line in enumerate(fin):
# 开始处理单个样本
sample = json.loads(line.strip())
if train:
if len(sample['answer_spans']) == 0:
continue
if sample['answer_spans'][0][1] >= self.max_p_len:
continue
# 答案所在的文档,后面在_one_mini_batch函数中用于计算答案范围的偏置
if 'answer_docs' in sample:
sample['answer_passages'] = sample['answer_docs']
# 问题
sample['question_tokens'] = sample['segmented_question']
# 文档
sample['passages'] = []
# 遍历每个样本中的文档
for d_idx, doc in enumerate(sample['documents']):
if train:
# 如果是训练集,处理相对简单,只取预处理中计算的每个文档的最相关段落将其作为
#`passage_tokens`与`is_selected`组成的字典插入`passages`
most_related_para = doc['most_related_para']
sample['passages'].append(
{'passage_tokens': doc['segmented_paragraphs'][most_related_para],
'is_selected': doc['is_selected']}
)
else:
# 如果不是训练集,则遍历每个段落,计算段落与问题的recall值,
#并按照recall和段落长度排序(短的在前),取前几个段落作为passage_tokens
para_infos = []
for para_tokens in doc['segmented_paragraphs']:
question_tokens = sample['segmented_question']
# 计算段落与问题的recall值
common_with_question = Counter(para_tokens) & Counter(question_tokens)
correct_preds = sum(common_with_question.values())
if correct_preds == 0:
recall_wrt_question = 0
else:
recall_wrt_question = float(correct_preds) / len(question_tokens)
para_infos.append((para_tokens, recall_wrt_question, len(para_tokens)))
para_infos.sort(key=lambda x: (-x[1], x[2]))
fake_passage_tokens = []
# 取第一个段落作为passage_tokens
for para_info in para_infos[:1]:
fake_passage_tokens += para_info[0]
sample['passages'].append({'passage_tokens': fake_passage_tokens})
data_set.append(sample)
return data_set
由代码可见,_load_dataset函数在加载数据的同时对数据集(尤其是校验集和测试集)进行了进一步处理,为样本添加了answer_passages
、question_tokens
、passages
字段,其中passages
对于训练集是每个文档中与答案最相关段落的列表,对其他数据集是与问题最相关段落的列表。
gen_mini_batches可以为设定的数据集(train/dev/test)生成数据批次,训练中训练代码会调用这个函数来生成训练数据。
def gen_mini_batches(self, set_name, batch_size, pad_id, shuffle=True):
"""
为设定的数据集(train/dev/test)生成数据批次
参数:
set_name: 数据集名称,使用train/dev/test 指明数据集
batch_size: 每个批次样本的数量
pad_id: 填充字符索引
shuffle: 如果值为真,将数据打乱.
返回值:
所有批次的生成器
"""
if set_name == 'train':
data = self.train_set
elif set_name == 'dev':
data = self.dev_set
elif set_name == 'test':
data = self.test_set
else:
raise NotImplementedError('No data set named as {}'.format(set_name))
data_size = len(data)
indices = np.arange(data_size)
if shuffle:
np.random.shuffle(indices)
for batch_start in np.arange(0, data_size, batch_size):
batch_indices = indices[batch_start: batch_start + batch_size]
# 根据索引生成一个样本批次
yield self._one_mini_batch(data, batch_indices, pad_id)
由代码可见,这个函数主要的功能是选择数据集、打乱数据、确定每个批次样本索引,最终每一个批次数据的生成是调用了_one_mini_batch
函数。
_one_mini_batch根据输入的数据和所选索引生成一个数据批次,生成时还根据本批次的最长样本和设置的最大长度对这个批次的样本进行填充。
def _one_mini_batch(self, data, indices, pad_id):
"""
生成一个批次
参数:
data: 所有数据
indices: 所选样本的索引the indices of the samples to be selected
pad_id:填充字符索引
返回值:
一个数据批次
"""
batch_data = {'raw_data': [data[i] for i in indices],
'question_token_ids': [],
'question_length': [],
'passage_token_ids': [],
'passage_length': [],
'start_id': [],
'end_id': []}
# 最大段落数量
max_passage_num = max([len(sample['passages']) for sample in batch_data['raw_data']])
max_passage_num = min(self.max_p_num, max_passage_num)
for sidx, sample in enumerate(batch_data['raw_data']):
# 遍历1到`max_passage_num`
for pidx in range(max_passage_num):
# 如果pidx小于段落数量,即有样本,将样本值赋给batch_data的对应字段
if pidx < len(sample['passages']):
batch_data['question_token_ids'].append(sample['question_token_ids'])
batch_data['question_length'].append(len(sample['question_token_ids']))
passage_token_ids = sample['passages'][pidx]['passage_token_ids']
batch_data['passage_token_ids'].append(passage_token_ids)
batch_data['passage_length'].append(min(len(passage_token_ids), self.max_p_len))
# 如果没有样本,插入空样本
else:
batch_data['question_token_ids'].append([])
batch_data['question_length'].append(0)
batch_data['passage_token_ids'].append([])
batch_data['passage_length'].append(0)
# 动态填充批次数据,返回样本长度对齐的批次,及填充后的段落、问题长度
batch_data, padded_p_len, padded_q_len = self._dynamic_padding(batch_data, pad_id)
for sample in batch_data['raw_data']:
if 'answer_passages' in sample and len(sample['answer_passages']):
# 计算答案所在段落偏移,sample['answer_passages'][0]在_load_dataset中创建,是答案所在文档的索引
gold_passage_offset = padded_p_len * sample['answer_passages'][0]
# 根据偏移计算答案的起始索引和终止索引
batch_data['start_id'].append(gold_passage_offset + sample['answer_spans'][0][0])
batch_data['end_id'].append(gold_passage_offset + sample['answer_spans'][0][1])
else:
# 如果没有答案插入0
batch_data['start_id'].append(0)
batch_data['end_id'].append(0)
return batch_data
由代码可以看到,这个函数功能如下:
raw_data
字段。_dynamic_padding
函数对每一个段落进行填充操作,根据最大段落长度,截取或填充。batch_data
。import sys
import pickle
from run import *
WARNING:tensorflow:
The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
* https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
* https://github.com/tensorflow/addons
* https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.
sys.argv = []
args = parse_args()
print(args)
Namespace(algo='BIDAF', batch_size=32, brc_dir='../data/baidu', dev_files=['../data/demo/devset/search.dev.json'], dropout_keep_prob=1, embed_size=300, epochs=10, evaluate=False, gpu='0', hidden_size=150, learning_rate=0.001, log_path=None, max_a_len=200, max_p_len=500, max_p_num=5, max_q_len=60, model_dir='../data/models/', optim='adam', predict=False, prepare=False, result_dir='../data/results/', summary_dir='../data/summary/', test_files=['../data/demo/testset/search.test.json'], train=False, train_files=['../data/demo/trainset/search.train.json'], vocab_dir='../data/vocab/', weight_decay=0)
# 创建数据集
brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len,
args.train_files, args.dev_files)
# 打开词典
with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin:
vocab = pickle.load(fin)
# 将样本文本转化为索引ids,并添加到数据集
brc_data.convert_to_ids(vocab)
# 准备参数,生成一个大小为4的批次
import numpy as np
data = brc_data.train_set
data_size = len(data)
indices = np.arange(data_size)
pad_id = vocab.get_id(vocab.pad_token)
batch_start = 0
batch_size = 4
batch_indices = indices[batch_start: batch_start + batch_size]
batch = brc_data._one_mini_batch(data, batch_indices,pad_id)
batch.keys()
dict_keys(['raw_data', 'question_token_ids', 'question_length', 'passage_token_ids', 'passage_length', 'start_id', 'end_id'])
由输出可见batch包含了以下字段:
其具体值如下:
print(batch['question_token_ids'])
print(np.shape(batch['question_token_ids']))
print(batch['question_length'])
print(np.shape(batch['passage_token_ids']))
print(batch['passage_length'])
print(batch['start_id'])
print(batch['end_id'])
[[2, 3, 4, 5, 6], [2, 3, 4, 5, 6], [2, 3, 4, 5, 6], [2, 3, 4, 5, 6], [2, 3, 4, 5, 6], [158, 31, 159, 26, 160], [158, 31, 159, 26, 160], [158, 31, 159, 26, 160], [158, 31, 159, 26, 160], [158, 31, 159, 26, 160], [437, 26, 438, 439, 440], [437, 26, 438, 439, 440], [437, 26, 438, 439, 440], [437, 26, 438, 439, 440], [437, 26, 438, 439, 440], [619, 1, 0, 0, 0], [619, 1, 0, 0, 0], [619, 1, 0, 0, 0], [619, 1, 0, 0, 0], [619, 1, 0, 0, 0]]
(20, 5)
[5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 2, 2, 2, 2, 2]
(20, 443)
[96, 147, 17, 51, 114, 31, 226, 12, 51, 443, 29, 204, 82, 279, 57, 404, 328, 118, 133, 27]
[1772, 4, 5, 445]
[1882, 11, 28, 601]
有输出可见,对于一个样本数量为4的批次,问题数据维度为(20, 5),文档数据维度为(20, 443),所以问题与文档的数量都是20,答案数量为4。这是因为加载数据时,代码根据预先设定的最大文档数量5,将每个样本的文档数量填充(空文档)为5个,同时将每个问题复制了5次,因此每个样本对应5个文档及问题。
另外可以看到一个批次中,所有问题与文档都被填充成相同的长度,长度大小取该批次所有文档(问题)的最大长度与预先设定的文档(文本)最大长度中较小的值。
参考文献:
DuReader数据集
DuReader Baseline Systems (基线系统)
BiDAF
Match-LSTM
Match-LSTM & BiDAF