# Imports used throughout the snippets below.
import os
import json

import numpy as np
import tensorflow as tf

for split, batch_size in zip(
    ["train", "valid"],
    [FLAGS.per_host_train_bsz, FLAGS.per_host_valid_bsz]):
  if batch_size <= 0: continue
  print("Converting {} set...".format(split))
  corpus.convert_to_tfrecords(split, save_dir, batch_size, FLAGS.tgt_len,
                              FLAGS.num_core_per_host, FLAGS=FLAGS)
per_host_train_bsz: the per-host batch size used for training.
record_name: the name of the JSON file that stores the record info.
data: the input read from train.txt when the corpus was built, already converted into a vector of token ids.
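As a concrete illustration (the flag values below are hypothetical, not taken from the source), the record_name built on the TPU path would look like this:

# Hypothetical values, for illustration only.
split, bsz, tgt_len, num_core_per_host = "train", 32, 128, 8

# Same format string convert_to_tfrecords uses when use_tpu is True:
record_name = "record_info-{}.bsz-{}.tlen-{}.core-{}.json".format(
    split, bsz, tgt_len, num_core_per_host)
print(record_name)  # record_info-train.bsz-32.tlen-128.core-8.json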
def convert_to_tfrecords(self, split, save_dir, bsz, tgt_len,
                         num_core_per_host, **kwargs):
  FLAGS = kwargs.get('FLAGS')

  file_names = []
  use_tpu = FLAGS.use_tpu and not (split == "test" and num_core_per_host == 1)

  if use_tpu:
    record_name = "record_info-{}.bsz-{}.tlen-{}.core-{}.json".format(
        split, bsz, tgt_len, num_core_per_host)
  else:
    record_name = "record_info-{}.bsz-{}.tlen-{}.json".format(
        split, bsz, tgt_len)

  record_info_path = os.path.join(save_dir, record_name)

  if self.dataset in ["ptb", "wt2", "wt103", "enwik8", "tangshi", "doupo",
                      "test", "zhihu", "poetry", "zhuxian", "xuezhongqijie",
                      "longzu"]:
    data = getattr(self, split)

    # bucket the vocabulary so each TPU core gets a fixed number of
    # tokens per bin (needed by the adaptive softmax on TPU)
    bin_sizes = get_bin_sizes(
        data, bsz // num_core_per_host, tgt_len, self.cutoffs)
    file_name, num_batch = create_ordered_tfrecords(
        save_dir, split, data, bsz, tgt_len, num_core_per_host,
        self.cutoffs, bin_sizes,
        num_passes=FLAGS.num_passes if split == 'train' and use_tpu else 1,
        use_tpu=use_tpu)
    file_names.append(file_name)

  with open(record_info_path, "w") as fp:
    record_info = {
        "filenames": file_names,
        "bin_sizes": bin_sizes,
        "num_batch": num_batch
    }
    json.dump(record_info, fp)
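Reading the record-info file back later is a one-liner; a minimal sketch (the file name reuses the hypothetical values from the example above):

import json

with open("record_info-train.bsz-32.tlen-128.core-8.json") as fp:
  record_info = json.load(fp)
# record_info["filenames"] -> tfrecord files written for this split
# record_info["num_batch"] -> number of tgt_len-sized steps along the time axis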
The batched data is processed row by row and column by column, over bsz rows in total:
inputs: elements [t, t + tgt_len) of the current row.
labels: the elements of the current row starting at t + 1, i.e. inputs shifted right by one token, so each label is the token that follows the corresponding input.
Each (inputs, labels) pair is stored in a tf.train.Example, serialized, and written out with record_writer to persist the preprocessed data.
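Before the real implementation below, here is a minimal NumPy sketch of that slicing (the toy token stream, batch size, and tgt_len are made up for illustration; batchify itself essentially truncates the flat token stream and reshapes it into bsz rows, which the plain reshape here stands in for):

import numpy as np

stream = np.arange(18)                  # toy token-id stream
bsz, tgt_len = 2, 4
batched_data = stream.reshape(bsz, -1)  # stand-in for batchify

for t in range(0, batched_data.shape[1] - 1, tgt_len):
  cur_tgt_len = min(batched_data.shape[1] - 1 - t, tgt_len)
  for idx in range(bsz):
    inputs = batched_data[idx, t:t + cur_tgt_len]
    labels = batched_data[idx, t + 1:t + cur_tgt_len + 1]
    # e.g. inputs = [0 1 2 3] -> labels = [1 2 3 4]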
def create_ordered_tfrecords(save_dir, basename, data, batch_size, tgt_len,
                             num_core_per_host, cutoffs=[], bin_sizes=[],
                             num_passes=1, use_tpu=False):
  # save_dir is the directory the tfrecord file is written to
  if use_tpu:
    file_name = "{}.bsz-{}.tlen-{}.core-{}.tfrecords".format(
        basename, batch_size, tgt_len, num_core_per_host)
  else:
    file_name = "{}.bsz-{}.tlen-{}.tfrecords".format(
        basename, batch_size, tgt_len)

  save_path = os.path.join(save_dir, file_name)
  record_writer = tf.python_io.TFRecordWriter(save_path)

  # reshape the flat token stream into batch_size rows
  batched_data = batchify(data, batch_size, num_passes)

  num_batch = 0
  for t in range(0, batched_data.shape[1] - 1, tgt_len):
    # length of the current chunk; only the final chunk can be shorter than tgt_len
    cur_tgt_len = min(batched_data.shape[1] - 1 - t, tgt_len)

    # drop the remainder if use tpu
    if use_tpu and cur_tgt_len < tgt_len:
      break

    if num_batch % 500 == 0:
      print("  processing batch {}".format(num_batch))

    for idx in range(batch_size):
      inputs = batched_data[idx, t:t + cur_tgt_len]
      # labels = inputs shifted right by one token
      labels = batched_data[idx, t + 1:t + cur_tgt_len + 1]

      # features dict
      feature = {
          "inputs": _int64_feature(inputs),
          "labels": _int64_feature(labels),
      }
      if len(cutoffs) > 0 and use_tpu:
        # validate `bin_sizes` and `cutoffs`
        assert len(cutoffs) - len(bin_sizes) == 2, \
          "len(cutoffs) - len(bin_sizes) != 2"

        # mask for bin 0 (the "head" bucket of the adaptive softmax)
        left, right = cutoffs[:2]
        inp_mask = ((inputs >= left) * (inputs < right)).astype(np.float32)
        tgt_mask = ((labels >= left) * (labels < right)).astype(np.float32)

        feature["inp_mask"] = _float_feature(inp_mask)
        feature["tgt_mask"] = _float_feature(tgt_mask)

        # refresh `inp_cnts` and `tgt_cnts` for each TPU core
        if idx % (batch_size // num_core_per_host) == 0:
          inp_cnts = [0] * len(bin_sizes)
          tgt_cnts = [0] * len(bin_sizes)

        head_labels = np.copy(labels)

        inp_pos_per_bin, tgt_pos_per_bin = [], []
        for b, (left, right) in enumerate(zip(cutoffs[1:-1], cutoffs[2:])):
          inp_pos = np.where((inputs >= left) * (inputs < right))[0]
          tgt_pos = np.where((labels >= left) * (labels < right))[0]
          inp_pos_per_bin.append(inp_pos)
          tgt_pos_per_bin.append(tgt_pos)

          # in the head, tail tokens are replaced by the id of their bin
          head_labels[tgt_pos] = cutoffs[1] + b

        feature["head_labels"] = _int64_feature(head_labels)
        # permutation feature: for each tail bin, record (position, slot)
        # index pairs, keeping at most bin_sizes[b] tokens per bin
        def _add_perm_feature(feature, pos_per_bin, cnts, prefix):
          for b, pos in enumerate(pos_per_bin):
            idx_tuple = []
            for p in pos:
              if cnts[b] < bin_sizes[b]:
                idx_tuple.append([p, cnts[b]])
                cnts[b] += 1
              else:
                break

            n_tup = len(idx_tuple)
            tup = np.array(idx_tuple).reshape(n_tup * 2)

            feature["{}_cnt_{}".format(prefix, b)] = _int64_feature([n_tup])
            feature["{}_tup_{}".format(prefix, b)] = _int64_feature(tup)

        _add_perm_feature(feature, inp_pos_per_bin, inp_cnts, "inp")
        _add_perm_feature(feature, tgt_pos_per_bin, tgt_cnts, "tgt")
      example = tf.train.Example(features=tf.train.Features(feature=feature))
      record_writer.write(example.SerializeToString())

    num_batch += 1

  record_writer.close()
  print("Done writing {}. batches: {}".format(file_name, num_batch))

  return file_name, num_batch
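To sanity-check the output, the serialized examples can be read back with the matching TF 1.x parsing API. A minimal sketch, assuming the records were written with full-length chunks (e.g. the TPU path, which drops short remainders) and using a hypothetical file name and tgt_len:

import tensorflow as tf

tgt_len = 128  # must match the tlen the file was written with (assumed here)

def _parse(serialized):
  # "inputs" and "labels" exist in every example; the mask/bin features
  # are only present when use_tpu=True and cutoffs are non-empty.
  features = {
      "inputs": tf.FixedLenFeature([tgt_len], tf.int64),
      "labels": tf.FixedLenFeature([tgt_len], tf.int64),
  }
  return tf.parse_single_example(serialized, features)

dataset = tf.data.TFRecordDataset("train.bsz-32.tlen-128.tfrecords")
dataset = dataset.map(_parse)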