I found an unofficial open-source Transformer implementation online: https://github.com/Kyubyong/transformer. The task is translation between German and English.
1. Base data:
English:
2. Preprocessing the base data:
All of the data-processing methods live in data_load.py.
load_data: reads the German and English sentences into the lists sents1 and sents2 respectively, and returns both.
def load_data(fpath1, fpath2, maxlen1, maxlen2):
    '''Loads source and target data and filters out too lengthy samples.
    fpath1: source file path. string.
    fpath2: target file path. string.
    maxlen1: source sent maximum length. scalar.
    maxlen2: target sent maximum length. scalar.
    Returns
    sents1: list of source sents
    sents2: list of target sents
    '''
    sents1, sents2 = [], []
    with open(fpath1, 'r') as f1, open(fpath2, 'r') as f2:
        for sent1, sent2 in zip(f1, f2):
            if len(sent1.split()) + 1 > maxlen1: continue  # 1: </s>
            if len(sent2.split()) + 1 > maxlen2: continue  # 1: </s>
            sents1.append(sent1.strip())
            sents2.append(sent2.strip())
    return sents1, sents2
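To see the filtering in isolation, here is a file-free sketch of the same logic (filter_pairs is a hypothetical helper, not part of the repo; the `+ 1` reserves one slot for the `</s>` token that encode appends later):

```python
def filter_pairs(sents1, sents2, maxlen1, maxlen2):
    """Keep only sentence pairs whose token count, plus one slot
    reserved for the </s> token, fits within the length limits."""
    kept1, kept2 = [], []
    for s1, s2 in zip(sents1, sents2):
        if len(s1.split()) + 1 > maxlen1:  # +1 accounts for </s>
            continue
        if len(s2.split()) + 1 > maxlen2:
            continue
        kept1.append(s1.strip())
        kept2.append(s2.strip())
    return kept1, kept2

src = ["ich bin", "das ist ein sehr langer satz"]
tgt = ["i am", "this is a very long sentence"]
print(filter_pairs(src, tgt, maxlen1=4, maxlen2=4))
# only the first pair survives: (['ich bin'], ['i am'])
```

Note that a pair is dropped if either side exceeds its limit, so the two returned lists always stay aligned.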
input_fn: builds a batched tf.data.Dataset from the two sentence lists.
Two important methods here:
tf.data.Dataset.from_generator: https://blog.csdn.net/foreseerwang/article/details/80572182
dataset.padded_batch(batch_size, shapes, paddings):https://blog.csdn.net/yeqiustu/article/details/79795639
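Conceptually, padded_batch pads every variable-length field in a batch up to the longest example in that batch, using the supplied padding value (0 for id sequences, '' for strings). A TF-free sketch of that behavior for a single `[None]`-shaped field (hypothetical helper, not part of the repo):

```python
def padded_batch(sequences, pad_value=0):
    """Pad every sequence in a batch to the length of the longest one,
    mimicking what tf.data's padded_batch does for a [None]-shaped field."""
    maxlen = max(len(s) for s in sequences)
    return [list(s) + [pad_value] * (maxlen - len(s)) for s in sequences]

print(padded_batch([[5, 71, 3], [9, 3]]))
# [[5, 71, 3], [9, 3, 0]]
```

This is why the model later needs the sequence lengths (x_seqlens, y_seqlen): they tell it which positions are real tokens and which are padding.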
import tensorflow as tf  # imported at the top of data_load.py

def input_fn(sents1, sents2, vocab_fpath, batch_size, shuffle=False):
    '''Batchify data
    sents1: list of source sents
    sents2: list of target sents
    vocab_fpath: string. vocabulary file path.
    batch_size: scalar
    shuffle: boolean
    Returns
    xs: tuple of
        x: int32 tensor. (N, T1)
        x_seqlens: int32 tensor. (N,)
        sents1: str tensor. (N,)
    ys: tuple of
        decoder_input: int32 tensor. (N, T2)
        y: int32 tensor. (N, T2)
        y_seqlen: int32 tensor. (N,)
        sents2: str tensor. (N,)
    '''
    shapes = (([None], (), ()),
              ([None], [None], (), ()))  # matches the values yielded by generator_fn
    types = ((tf.int32, tf.int32, tf.string),
             (tf.int32, tf.int32, tf.int32, tf.string))  # one dtype per entry in shapes
    paddings = ((0, 0, ''),
                (0, 0, 0, ''))  # padding value for each entry
    dataset = tf.data.Dataset.from_generator(
        generator_fn,
        output_shapes=shapes,
        output_types=types,
        args=(sents1, sents2, vocab_fpath))  # <- arguments for generator_fn. converted to np string arrays
    if shuffle:  # for training
        dataset = dataset.shuffle(128*batch_size)
    dataset = dataset.repeat()  # iterate forever
    dataset = dataset.padded_batch(batch_size, shapes, paddings).prefetch(1)
    return dataset
encode: converts the tokens of a source or target sentence into their ids via the vocabulary. Note that after processing, the source-side result begins with ordinary token ids and ends with </s>, while the target-side result begins with <s> and ends with </s>.
def encode(inp, type, dict):
    '''Converts string to number. Used for `generator_fn`.
    inp: 1d byte array.
    type: "x" (source side) or "y" (target side)
    dict: token2idx dictionary
    Returns
    list of numbers
    '''
    inp_str = inp.decode("utf-8")
    if type=="x": tokens = inp_str.split() + ["</s>"]
    else: tokens = ["<s>"] + inp_str.split() + ["</s>"]
    x = [dict.get(t, dict["<unk>"]) for t in tokens]  # look up each token's id; fall back to <unk> for out-of-vocabulary tokens
    return x
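A quick check with a toy vocabulary (the ids below are invented for illustration) shows the asymmetry between the two sides:

```python
def encode(inp, type, dict):
    '''Converts string to number (same logic as in data_load.py).'''
    inp_str = inp.decode("utf-8")
    if type == "x": tokens = inp_str.split() + ["</s>"]
    else: tokens = ["<s>"] + inp_str.split() + ["</s>"]
    return [dict.get(t, dict["<unk>"]) for t in tokens]

token2idx = {"<pad>": 0, "<unk>": 1, "<s>": 2, "</s>": 3, "hello": 4, "world": 5}
print(encode(b"hello world", "x", token2idx))  # [4, 5, 3]     -> ends with </s>
print(encode(b"hello world", "y", token2idx))  # [2, 4, 5, 3]  -> starts with <s>, ends with </s>
print(encode(b"hello mars", "x", token2idx))   # [4, 1, 3]     -> "mars" maps to <unk>
```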
generator_fn: for the source side it yields (encoded sentence, encoded length, raw sentence); for the target side it yields (decoder input, encoded sentence, encoded length, raw sentence). Note that decoder_input is the encoded sentence without its last token (e.g. <s>, 23, 56, .....), which shows that the first token fed to the decoder is <s>.
def generator_fn(sents1, sents2, vocab_fpath):
    '''Generates training / evaluation data
    sents1: list of source sents
    sents2: list of target sents
    vocab_fpath: string. vocabulary file path.
    yields
    xs: tuple of
        x: list of source token ids in a sent
        x_seqlen: int. sequence length of x
        sent1: str. raw source (=input) sentence
    labels: tuple of
        decoder_input: list of encoded decoder inputs
        y: list of target token ids in a sent
        y_seqlen: int. sequence length of y
        sent2: str. target sentence
    '''
    token2idx, _ = load_vocab(vocab_fpath)
    for sent1, sent2 in zip(sents1, sents2):
        x = encode(sent1, "x", token2idx)
        y = encode(sent2, "y", token2idx)
        decoder_input, y = y[:-1], y[1:]
        x_seqlen, y_seqlen = len(x), len(y)
        yield (x, x_seqlen, sent1), (decoder_input, y, y_seqlen, sent2)
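The `y[:-1] / y[1:]` shift can be verified on a toy encoding (ids invented; here 2 = <s> and 3 = </s>):

```python
# toy encoded target sentence: <s> 23 56 </s>
y_full = [2, 23, 56, 3]
decoder_input, y = y_full[:-1], y_full[1:]
print(decoder_input)  # [2, 23, 56] -> decoder input starts with <s>, drops </s>
print(y)              # [23, 56, 3] -> labels drop <s>, end with </s>
```

So at every time step the decoder sees tokens up to position t and is trained to predict the token at position t+1 (teacher forcing).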
get_batch: a wrapper around the methods above, called by the other modules.
from utils import calc_num_batches  # also imported at the top of data_load.py

def get_batch(fpath1, fpath2, maxlen1, maxlen2, vocab_fpath, batch_size, shuffle=False):
    '''Gets training / evaluation mini-batches
    fpath1: source file path. string.
    fpath2: target file path. string.
    maxlen1: source sent maximum length. scalar.
    maxlen2: target sent maximum length. scalar.
    vocab_fpath: string. vocabulary file path.
    batch_size: scalar
    shuffle: boolean
    Returns
    batches
    num_batches: number of mini-batches
    num_samples
    '''
    sents1, sents2 = load_data(fpath1, fpath2, maxlen1, maxlen2)
    batches = input_fn(sents1, sents2, vocab_fpath, batch_size, shuffle=shuffle)
    num_batches = calc_num_batches(len(sents1), batch_size)
    return batches, num_batches, len(sents1)
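calc_num_batches comes from the repo's utils.py; as I read it, it is simply ceiling division, so the last, possibly smaller, batch is still counted:

```python
def calc_num_batches(total_num, batch_size):
    '''Number of mini-batches, rounding up (ceiling division).'''
    return total_num // batch_size + int(total_num % batch_size != 0)

print(calc_num_batches(10, 4))  # 3 (two full batches plus one of size 2)
print(calc_num_batches(8, 4))   # 2
```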