系列目录:
准备数据包括检查数据、创建文件夹、准备词典、准备词嵌入。具体实现见基线系统tensorflow/run.py文件中prepare函数。
准备数据的过程中,代码用到了基线系统定义的BRCDataset和vocab类,先简单介绍一下。
基线系统定义了BRCDataset类,用来读取数据,构建数据集,具体实现见tensorflow\dataset.py
具体属性、方法如下:
类名 BRCDataset
功能:实现加载使用百度阅读理解数据集的APIs
类属性:
self.max_p_num:最大段落数量
self.max_p_len:最大段落长度
self.max_q_len:最大问题长度
self.train_set, self.dev_set, self.test_set:训练、验证、测试数据集
类主要方法:
_load_dataset():加载数据,数据集初始化时会自动调用这个函数加载数据
_one_mini_batch:生成一个batch的数据
_dynamic_padding:动态填充
word_iter:遍历数据集中所有单词
convert_to_ids:将数据集中的文本(问题、文档)转化为ids
gen_mini_batches:为特定数据集生成batch数据
在数据准备阶段只需要加载数据以及使用word_iter方法遍历数据集中的单词提供给vocab。
Vocab类主要用于构建数据集字典,可以用它将单词转化为数字ids,具体实现见tensorflow\vocab.py
具体属性、方法如下:
类名 Vocab
功能:实现词典,保存数据集中的单词以及其对应的词嵌入
类属性:
self.id2token:字典,id到单词的映射
self.token2id:字典,单词到id的映射
self.token_cnt:数据集中每个单词的数量
self.lower:是否将字母小写
self.embed_dim:嵌入向量维度
self.embeddings:词嵌入矩阵
类主要方法:
size:获得词典大小
load_from_file:从文件中读取词典
get_id:输入单词,获得单词对应id,如果单词不存在返回unk token的id
get_token:输入id,返回对应的单词,如果id不存在返回unk token
add:输入单词,将其加入字典
filter_tokens_by_cnt:通过词频过滤单词,输入最小词频,小于这个词频的单词被过滤掉
randomly_init_embeddings:为所有单词随机初始化嵌入向量
load_pretrained_embeddings:输入embedding_path,从中加载预训练的嵌入向量,如果没有则将其过滤掉
convert_to_ids:输入单词列表,将其转化为ids
recover_from_ids:输入ids列表,将其转化为单词,如果输入了stop_id,当其出现时停止转化
函数实现见/tensorflow/run.py。
def prepare(args):
"""
检查数据,创建文件夹,准备词典和词嵌入
"""
#设定运行日志log
logger = logging.getLogger("brc")
logger.info('Checking the data files...')
#检查数据文件是否存在
for data_path in args.train_files + args.dev_files + args.test_files:
assert os.path.exists(data_path), '{} file does not exist.'.format(data_path)
#准备目录,创建保存词典、模型、结果、训练摘要的目录
logger.info('Preparing the directories...')
for dir_path in [args.vocab_dir, args.model_dir, args.result_dir, args.summary_dir]:
if not os.path.exists(dir_path):
os.makedirs(dir_path)
#创建词典
logger.info('Building vocabulary...')
brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len,
args.train_files, args.dev_files, args.test_files)
vocab = Vocab(lower=True)
for word in brc_data.word_iter('train'):
vocab.add(word)
unfiltered_vocab_size = vocab.size()
#删除词频低于2的词
vocab.filter_tokens_by_cnt(min_cnt=2)
filtered_num = unfiltered_vocab_size - vocab.size()
logger.info('After filter {} tokens, the final vocab size is {}'.format(filtered_num,
vocab.size()))
#随机初始化词嵌入矩阵
logger.info('Assigning embeddings...')
vocab.randomly_init_embeddings(args.embed_size)
#保存词典
logger.info('Saving vocab...')
with open(os.path.join(args.vocab_dir, 'vocab.data'), 'wb') as fout:
pickle.dump(vocab, fout)
logger.info('Done with preparing!')
在demo数据上运行的结果如下图所示:
由图可知,demo数据集训练集有95个问题,校验集有100个问题,测试集有100个问题。词典过滤掉5225个单词,最后剩余5006个单词。
参考文献:
DuReader数据集
DuReader Baseline Systems (基线系统)