data_dir holds the raw data; main() converts each split into TFRecords under data_dir/tfrecords:
def main(unused_argv):
  del unused_argv  # Unused

  corpus = get_lm_corpus(FLAGS.data_dir, FLAGS.dataset)

  save_dir = os.path.join(FLAGS.data_dir, "tfrecords")
  if not tf.gfile.Exists(save_dir):
    tf.gfile.MakeDirs(save_dir)

  # test mode
  if FLAGS.per_host_test_bsz > 0:
    corpus.convert_to_tfrecords("test", save_dir, FLAGS.per_host_test_bsz,
                                FLAGS.tgt_len, FLAGS.num_core_per_host,
                                FLAGS=FLAGS)
    return

  for split, batch_size in zip(
      ["train", "valid"],
      [FLAGS.per_host_train_bsz, FLAGS.per_host_valid_bsz]):
    if batch_size <= 0: continue
    print("Converting {} set...".format(split))
    corpus.convert_to_tfrecords(split, save_dir, batch_size, FLAGS.tgt_len,
                                FLAGS.num_core_per_host, FLAGS=FLAGS)
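For context, the FLAGS above are ordinary command-line flags. A minimal sketch of how they could be declared with TF 1.x's tf.app.flags follows; only the flag names come from the code above, while the defaults, help strings, and the tf.app.run entry point are assumptions.

import tensorflow as tf

flags = tf.app.flags
flags.DEFINE_string("data_dir", None, "Directory holding the raw data.")
flags.DEFINE_string("dataset", "zhihu", "Dataset name.")
flags.DEFINE_integer("tgt_len", 128, "Number of tokens per training step.")
flags.DEFINE_integer("num_core_per_host", 8, "Cores (GPU/TPU) per host.")
flags.DEFINE_integer("per_host_train_bsz", 32, "Train batch size per host.")
flags.DEFINE_integer("per_host_valid_bsz", 32, "Valid batch size per host.")
flags.DEFINE_integer("per_host_test_bsz", 0,
                     "Test batch size per host; > 0 switches to test mode.")
FLAGS = flags.FLAGS

if __name__ == "__main__":
  tf.app.run(main)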
The vocabulary/corpus is serialized with pickle and cached on disk; the first time it is requested, the Corpus is built from the raw data and then saved. Building the Corpus involves the following main steps (a small worked example follows the list):
1. count_file: read each line of the raw text, strip leading/trailing whitespace and the trailing \n, split the line character by character into a list, append the <eos> marker, and record every token's frequency in counter = Counter().
2. build_vocab: build the vocabulary by sorting all counted tokens by their character code and dropping low-frequency ones.
3. add_symbol: maintain the symbol-to-index mapping sym2idx and the index-to-symbol mapping idx2sym (symbols are appended in order, so the list position is the index).
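A minimal self-contained sketch of steps 1-3 on two made-up lines of text (the data and min_freq are invented for illustration):

from collections import Counter, OrderedDict

lines = ["你好。", "你们好。"]                 # toy "raw data"
counter = Counter()
for line in lines:
  symbols = list(line.strip()) + ["<eos>"]     # step 1: split into chars, add <eos>
  counter.update(symbols)

idx2sym, sym2idx = [], OrderedDict()
min_freq = 1
for sym, cnt in sorted(counter.items()):       # step 2: sort by character code
  if cnt < min_freq:
    continue
  if sym not in sym2idx:                       # step 3: add_symbol
    idx2sym.append(sym)
    sym2idx[sym] = len(idx2sym) - 1

print(sym2idx)
# OrderedDict([('<eos>', 0), ('。', 1), ('们', 2), ('你', 3), ('好', 4)])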
def get_lm_corpus(data_dir, dataset):
  fn = os.path.join(data_dir, "cache.pkl")

  if tf.gfile.Exists(fn):
    print("Loading cached dataset...")
    with open(fn, "rb") as fp:
      corpus = pickle.load(fp)
  else:
    print("Producing dataset...")
    kwargs = {}
    kwargs["special"] = ["<eos>"]
    kwargs["lower_case"] = False
    corpus = Corpus(data_dir, dataset, **kwargs)

    print("Saving dataset...")
    with open(fn, "wb") as fp:
      pickle.dump(corpus, fp, protocol=2)

    corpus_info = {
        "vocab_size": len(corpus.vocab),
        "cutoffs": corpus.cutoffs,
        "dataset": corpus.dataset
    }
    with open(os.path.join(data_dir, "corpus-info.json"), "w") as fp:
      json.dump(corpus_info, fp)

  return corpus
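Later runs (or analysis scripts) can read both cached artifacts straight back; the data_dir path below is hypothetical, and unpickling cache.pkl requires the Corpus/Vocab classes to be importable in the calling process:

import json
import os
import pickle

data_dir = "./data/zhihu"  # hypothetical location of the preprocessed data

with open(os.path.join(data_dir, "corpus-info.json")) as fp:
  info = json.load(fp)
print(info["vocab_size"], info["cutoffs"], info["dataset"])

with open(os.path.join(data_dir, "cache.pkl"), "rb") as fp:
  corpus = pickle.load(fp)  # needs Corpus/Vocab on the import path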
class Vocab(object):
  def __init__(self, special=[], min_freq=0, max_size=None, lower_case=True,
               delimiter=None, vocab_file=None):
    self.counter = Counter()
    self.special = special
    self.min_freq = min_freq
    self.max_size = max_size
    self.lower_case = lower_case
    self.delimiter = delimiter
    self.vocab_file = vocab_file

    self.idx2sym = []
    self.sym2idx = OrderedDict()  # TODO: double-check whether this is correct

    # for the zhihu dataset
    # TODO: delete this when testing other datasets
    # self.min_freq = 100
    # self.add_symbol('<unk>')
    # self.unk_idx = self.get_idx('<unk>')
  def tokenize(self, line, add_eos=False, add_double_eos=False):
    line = line.strip()
    # character-level tokenization: split the line into individual characters
    symbols = list(line)

    if add_double_eos:  # lm1b
      # make sure '<S>' can be found in the symbol list
      self.add_symbol('<S>')
      return ['<S>'] + symbols + ['<S>']
    elif add_eos:
      return symbols + ['<eos>']
    else:
      return symbols
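Because tokenize splits with list(line), tokenization is character level rather than word level. A quick illustration on a made-up input, once Vocab is defined:

vocab = Vocab()
print(vocab.tokenize("你好,世界\n", add_eos=True))
# ['你', '好', ',', '世', '界', '<eos>']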
  # Read the sentences out of a file and count their tokens
  def count_file(self, path, verbose=False, add_eos=False):
    if verbose: print('counting file {} ...'.format(path))
    assert tf.gfile.Exists(path)

    sents = []
    with open(path, 'r', encoding='UTF-8') as f:
      # read the file line by line
      for idx, line in enumerate(f):
        if verbose and idx > 0 and idx % 500000 == 0:
          print('    line {}'.format(idx))
        symbols = self.tokenize(line, add_eos=True)
        self.counter.update(symbols)
        sents.append(symbols)

    return sents
  # Update the token counts in counter from pre-tokenized sentences
  def count_sents(self, sents, verbose=False):
    """
      sents : a list of sentences, each a list of tokenized symbols
    """
    if verbose: print('counting {} sents ...'.format(len(sents)))
    for idx, symbols in enumerate(sents):
      if verbose and idx > 0 and idx % 500000 == 0:
        print('    line {}'.format(idx))
      self.counter.update(symbols)
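For example, counting two already-tokenized sentences accumulates into the shared Counter (toy data):

vocab = Vocab()
vocab.count_sents([["你", "好", "<eos>"], ["你", "们", "<eos>"]])
print(vocab.counter.most_common(3))
# [('你', 2), ('<eos>', 2), ('好', 1)]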
  def _build_from_file(self, vocab_file):
    # self.idx2sym = []
    # self.sym2idx = OrderedDict()

    with open(vocab_file, 'r') as f:
      for line in f:
        # the first whitespace-separated field on each line is the symbol
        symb = line.strip().split()[0]
        self.add_symbol(symb)
    self.unk_idx = self.sym2idx['<UNK>']
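When a vocab_file is supplied, each line is expected to start with the symbol (any further columns, such as a frequency, are ignored), and the file must contain the <UNK> token so unk_idx can be set. A hypothetical file and call:

# contents of a hypothetical vocab.txt:
#   <UNK> 0
#   你 1024
#   好 512
vocab = Vocab(vocab_file="vocab.txt")
vocab.build_vocab()               # delegates to _build_from_file
print(vocab.unk_idx, len(vocab))  # 0 3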
  # Build the vocab and store the symbols
  def build_vocab(self):
    if self.vocab_file:
      print('building vocab from {}'.format(self.vocab_file))
      self._build_from_file(self.vocab_file)
      print('final vocab size {}'.format(len(self)))
    else:
      print('building vocab with min_freq={}, max_size={}'.format(
          self.min_freq, self.max_size))
      self.add_special("<eos>")  # register the special token; "<eos>" assumed here
      # TODO: big gotcha here! The commented-out loop below is the original
      # behaviour: iterate in frequency order and stop at the first symbol
      # whose count drops below min_freq.
      # for sym, cnt in self.counter.most_common(self.max_size):
      #   if cnt < self.min_freq:
      #     break
      # Instead, sort the symbols by character code and skip rare ones.
      tmp = sorted(self.counter.items(), key=lambda item: item[0])
      for sym, cnt in tmp:
        if cnt < self.min_freq:
          continue
        self.add_symbol(sym)

      print('final vocab size {} from {} unique tokens'.format(
          len(self), len(self.counter)))
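The gotcha flagged above is easiest to see side by side. A standalone sketch with an invented counter and min_freq:

from collections import Counter

counter = Counter({"b": 5, "a": 5, "c": 1})
min_freq, max_size = 2, None

# original behaviour: frequency order, stop at the first rare symbol
upstream = []
for sym, cnt in counter.most_common(max_size):
  if cnt < min_freq:
    break
  upstream.append(sym)

# behaviour of the build_vocab above: symbol order, skip rare symbols
modified = [sym for sym, cnt in sorted(counter.items()) if cnt >= min_freq]

print(upstream)  # ['b', 'a']  (frequency order; ties follow insertion order)
print(modified)  # ['a', 'b']  (deterministic, sorted by symbol)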
  # The core is convert_to_nparray, i.e. mapping each symbol to its index
  def encode_file(self, path, ordered=False, verbose=False,
                  add_double_eos=False):
    if verbose: print('encoding file {} ...'.format(path))
    assert tf.gfile.Exists(path)

    encoded = []
    with open(path, 'r', encoding="utf-8") as f:
      for idx, line in enumerate(f):
        if verbose and idx > 0 and idx % 500000 == 0:
          print('    line {}'.format(idx))
        symbols = self.tokenize(line, add_eos=True,
                                add_double_eos=add_double_eos)
        encoded.append(self.convert_to_nparray(symbols))

    if ordered:
      encoded = np.concatenate(encoded)

    return encoded
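Putting count_file, build_vocab and encode_file together on a throw-away two-line file (contents invented; tensorflow must be importable because of the tf.gfile checks):

import os
import tempfile

tmp_dir = tempfile.mkdtemp()
path = os.path.join(tmp_dir, "toy.txt")
with open(path, "w", encoding="utf-8") as f:
  f.write("你好\n你们好\n")

vocab = Vocab(min_freq=0)
vocab.count_file(path)
vocab.build_vocab()

per_line = vocab.encode_file(path, ordered=False)  # list of per-line int64 arrays
stream = vocab.encode_file(path, ordered=True)     # one concatenated 1-D array
print(len(per_line), stream.shape)                 # 2 (7,)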
  # Encode sentences that are already tokenized (each a list of symbols)
  def encode_sents(self, sents, ordered=False, verbose=False):
    if verbose: print('encoding {} sents ...'.format(len(sents)))

    encoded = []
    for symbols in sents:
      encoded.append(self.convert_to_nparray(symbols))

    if ordered:
      encoded = np.concatenate(encoded)

    return encoded
  def add_special(self, sym):
    if sym not in self.sym2idx:
      self.idx2sym.append(sym)
      self.sym2idx[sym] = len(self.idx2sym) - 1
      # also expose e.g. '<eos>' as self.eos_idx
      setattr(self, '{}_idx'.format(sym.strip('<>')), self.sym2idx[sym])

  def add_symbol(self, sym):
    if sym not in self.sym2idx:
      self.idx2sym.append(sym)
      self.sym2idx[sym] = len(self.idx2sym) - 1

  def get_sym(self, idx):
    assert 0 <= idx < len(self.idx2sym), 'Index {} out of range'.format(idx)
    return self.idx2sym[idx]

  def get_idx(self, sym):
    if sym in self.sym2idx:
      return self.sym2idx[sym]
    else:
      # fall back to the unknown-token index for out-of-vocabulary symbols
      assert hasattr(self, 'unk_idx')
      return self.sym2idx.get(sym, self.unk_idx)

  def get_symbols(self, indices):
    return [self.get_sym(idx) for idx in indices]

  def get_indices(self, symbols):
    return [self.get_idx(sym) for sym in symbols]

  # characters -> indices
  def convert_to_nparray(self, symbols):
    nparray = np.array(self.get_indices(symbols), dtype=np.int64)
    return nparray

  # indices -> characters
  def convert_to_sent(self, indices, exclude=None):
    if exclude is None:
      return ' '.join([self.get_sym(idx) for idx in indices])
    else:
      return ' '.join([self.get_sym(idx) for idx in indices
                       if idx not in exclude])

  def __len__(self):
    return len(self.idx2sym)
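Finally, a small round trip through the conversion helpers on toy symbols:

vocab = Vocab()
for ch in ["你", "好", "<eos>"]:
  vocab.add_symbol(ch)

ids = vocab.convert_to_nparray(["你", "好", "<eos>"])
print(ids)                         # [0 1 2]
print(vocab.convert_to_sent(ids))  # 你 好 <eos>
print(vocab.convert_to_sent(ids, exclude=[vocab.sym2idx["<eos>"]]))  # 你 好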