(Unfinished)
Transformer paper and framework analysis: notes on the machine-translation Transformer framework | Attention Is All You Need
The code in this post is from GitHub: kyubyong/transformer/tf1.2_legacy
The author has since released Transformer code updated for newer versions of TensorFlow; these notes are based on the old code.
Written as personal study notes.
File | Description |
---|---|
hyperparams.py | hyperparameter settings |
prepro.py | builds the vocabularies |
data_load.py | formats the data and generates batches |
modules.py | the network model |
train.py | training |
eval.py | evaluation |
Code 1: hyperparams.py, the file that defines the hyperparameters
# -*- coding: utf-8 -*-
#/usr/bin/python2
'''
June 2017 by kyubyong park.
[email protected].
https://www.github.com/kyubyong/transformer
'''
class Hyperparams: # hyperparameters
    '''Hyperparameters'''
    # data: training set and test set
    source_train = 'corpora/train.tags.de-en.de'
    target_train = 'corpora/train.tags.de-en.en'
    source_test = 'corpora/IWSLT16.TED.tst2014.de-en.de.xml'
    target_test = 'corpora/IWSLT16.TED.tst2014.de-en.en.xml'

    # training
    # batch_size is a key parameter to tune.
    # Mini-batch gradient descent splits the data into batches and updates the parameters
    # once per batch; the examples in a batch jointly determine the direction of each
    # gradient step, so the descent wanders less and is less random than updating on
    # single examples.
    batch_size = 32 # alias = N. In real MT training the batch size is usually set somewhere between 4000 and 8000; it depends on the setup.
    lr = 0.0001 # learning rate. In paper, learning rate is adjusted to the global step.
    # In practice the learning rate is usually scheduled dynamically, going from large to
    # small so that training can home in on a good optimum.
    logdir = 'logdir' # log directory

    # model
    maxlen = 10 # alias = T. Maximum sentence length in words; in real training this limit is usually lifted.
    # Feel free to increase this if you are ambitious.
    # min_cnt is another parameter to tune.
    min_cnt = 20 # words that occur fewer than min_cnt times are encoded as <UNK>.
    # key parameters to tune below
    hidden_units = 512 # alias = C. Hidden units.
    num_blocks = 6 # number of encoder/decoder blocks
    num_epochs = 20 # 20 passes over all the training samples
    num_heads = 8 # number of heads h in multi-head attention
    dropout_rate = 0.1 # residual dropout regularization; can be increased depending on the task
    sinusoid = False # If True, use sinusoid. If False, positional embedding. Sinusoidal encoding is not used here.
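The `lr` comment above refers to the schedule used in the paper, where the learning rate follows the global step: it warms up linearly and then decays with the inverse square root of the step. As a reference, here is a minimal sketch of that formula in plain Python; this helper is not part of the tf1.2_legacy code analyzed here, `d_model` corresponds to `hidden_units`, and `warmup_steps = 4000` is the paper's default.

def paper_lr(step, d_model=512, warmup_steps=4000):
    """Learning rate from 'Attention Is All You Need': linear warm-up over the
    first warmup_steps steps, then decay proportional to 1/sqrt(step)."""
    step = max(step, 1)  # guard against step 0
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

# Roughly: paper_lr(100) ~ 1.7e-5, it peaks at paper_lr(4000) ~ 7.0e-4,
# and it decays to paper_lr(40000) ~ 2.2e-4.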
Code 2: prepro.py, which builds the vocabularies
# -*- coding: utf-8 -*-
#/usr/bin/python2
'''
June 2017 by kyubyong park.
[email protected].
https://www.github.com/kyubyong/transformer
'''
from __future__ import print_function # this machine runs Python 2; use the Python 3 print() function
from hyperparams import Hyperparams as hp # hyperparameters
import tensorflow as tf
import numpy as np
import codecs # read/write files with codecs.open() to avoid errors from inconsistent encodings
import os
import regex # regular expressions
from collections import Counter # counter

def make_vocab(fpath, fname): # build a vocabulary
    '''Constructs vocabulary.
    Args:
      fpath: A string. Input file path. (the training set)
      fname: A string. Output file name. (the vocabulary)

    Writes vocabulary line by line to `preprocessed/fname`
    '''
    text = codecs.open(fpath, 'r', 'utf-8').read() # read with unicode decoding
    text = regex.sub("[^\s\p{Latin}']", "", text) # regex: keep only Latin-script words (plus whitespace and apostrophes)
    words = text.split()
    word2cnt = Counter(words) # counter; yields a dict: key = word, value = count
    if not os.path.exists('preprocessed'): os.mkdir('preprocessed') # output directory
    # with statement: no explicit close() needed, and exceptions are handled safely
    # str.format() works like the % placeholders in %-style formatting
    with codecs.open('preprocessed/{}'.format(fname), 'w', 'utf-8') as fout:
        # write the four special tokens first:
        # <PAD>  padding, id 0
        # <UNK>  unknown / low-frequency words, id 1
        # <S>    start-of-sentence marker, id 2
        # </S>   end-of-sentence marker, id 3
        fout.write("{}\t1000000000\n{}\t1000000000\n{}\t1000000000\n{}\t1000000000\n".format("<PAD>", "<UNK>", "<S>", "</S>"))
        # collections.Counter.most_common(N) sorts the dict by count in descending order and keeps only the top N words
        for word, cnt in word2cnt.most_common(len(word2cnt)):
            fout.write(u"{}\t{}\n".format(word, cnt)) # write as: word \t count \n

if __name__ == '__main__':
    make_vocab(hp.source_train, "de.vocab.tsv")
    make_vocab(hp.target_train, "en.vocab.tsv") # two vocabularies
    print("Done")
Code 3: data_load.py, which formats the data and generates batches
# -*- coding: utf-8 -*-
#/usr/bin/python2
'''
June 2017 by kyubyong park.
[email protected].
https://www.github.com/kyubyong/transformer
'''
from __future__ import print_function
from hyperparams import Hyperparams as hp
import tensorflow as tf
import numpy as np
import codecs
import regex
# Load each vocabulary into dict form, dropping low-frequency words
def load_de_vocab(): # German
    # splitlines() splits the file on line breaks (\n, \r, \r\n); line.split()[0] takes the word
    # from each "word \t count" line; the word is kept only if its count is at least the
    # min_cnt set in hyperparams.py
    vocab = [line.split()[0] for line in codecs.open('preprocessed/de.vocab.tsv', 'r', 'utf-8').read().splitlines() if int(line.split()[1])>=hp.min_cnt]
    word2idx = {word: idx for idx, word in enumerate(vocab)} # store the words in a dict, giving each word a number
    idx2word = {idx: word for idx, word in enumerate(vocab)}
    return word2idx, idx2word

def load_en_vocab(): # English
    vocab = [line.split()[0] for line in codecs.open('preprocessed/en.vocab.tsv', 'r', 'utf-8').read().splitlines() if int(line.split()[1])>=hp.min_cnt]
    word2idx = {word: idx for idx, word in enumerate(vocab)}
    idx2word = {idx: word for idx, word in enumerate(vocab)}
    return word2idx, idx2word
'''
Example:
word2idx = {'<PAD>': 0, '<UNK>': 1, '<S>': 2, '</S>': 3, '有': 4,
            '的': 5, '`': 6, '-': 7, '卦': 8, '八': 9, ..., '爬': 1642, 'U': 1643}
idx2word = {0: '<PAD>', 1: '<UNK>', 2: '<S>', 3: '</S>', 4: '有',
            5: '的', 6: '`', 7: '-', 8: '卦', 9: '八', ..., 1642: '爬', 1643: 'U'}
'''
# Data processing
def create_data(source_sents, target_sents): # source_sents: list of source-language sentences; target_sents: target-language sentences
    de2idx, idx2de = load_de_vocab()
    en2idx, idx2en = load_en_vocab()

    # Index
    x_list, y_list, Sources, Targets = [], [], [], []
    for source_sent, target_sent in zip(source_sents, target_sents): # zip() walks both sentence lists in parallel
        # x, y: one new "ID sentence" each
        x = [de2idx.get(word, 1) for word in (source_sent + u" </S>").split()] # 1: OOV, </S>: End of Text
        y = [en2idx.get(word, 1) for word in (target_sent + u" </S>").split()]
        # The end-of-text marker </S> is appended to every sentence; each word that exists in
        # word2idx is replaced by its id, and any word that does not is replaced by id 1 (<UNK>),
        # producing a new "ID sentence"
        if max(len(x), len(y)) <= hp.maxlen: # the maximum sentence length set in hyperparams.py
            x_list.append(np.array(x)) # source-language ID sentence
            y_list.append(np.array(y)) # target-language ID sentence
            Sources.append(source_sent) # source-language sentence
            Targets.append(target_sent) # target-language sentence
        # sentences longer than the threshold are discarded

    # Pad: fill with 0, the id of the special token <PAD>
    X = np.zeros([len(x_list), hp.maxlen], np.int32) # 2-D zero matrix: number of sentences x maximum sentence length
    Y = np.zeros([len(y_list), hp.maxlen], np.int32)
    for i, (x, y) in enumerate(zip(x_list, y_list)):
        # make every ID sentence the same length
        X[i] = np.lib.pad(x, [0, hp.maxlen-len(x)], 'constant', constant_values=(0, 0)) # pad each ID sentence: 0 zeros on the left, hp.maxlen-len(x) zeros on the right; 0 is the id of the special token <PAD>
        Y[i] = np.lib.pad(y, [0, hp.maxlen-len(y)], 'constant', constant_values=(0, 0))

    return X, Y, Sources, Targets # X and Y have shape [len(x_list), hp.maxlen]; Sources and Targets are lists of the len(x_list) kept sentences
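# (Illustration added for these notes, not part of the original file.) With hp.maxlen = 10,
# the padding above turns a short ID sentence into one fixed-length row of X, e.g.:
#   x = np.array([5, 27, 3])   # a two-word sentence followed by </S> (id 3)
#   np.lib.pad(x, [0, 10 - len(x)], 'constant', constant_values=(0, 0))
#   # -> array([ 5, 27,  3,  0,  0,  0,  0,  0,  0,  0])   (0 is the id of <PAD>)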
# Load the training set, preprocess it, and return fixed-length ID sentences
def load_train_data():
    de_sents = [regex.sub("[^\s\p{Latin}']", "", line) for line in codecs.open(hp.source_train, 'r', 'utf-8').read().split("\n") if line and line[0] != "<"] # skip lines starting with "<" (XML tag lines)
    en_sents = [regex.sub("[^\s\p{Latin}']", "", line) for line in codecs.open(hp.target_train, 'r', 'utf-8').read().split("\n") if line and line[0] != "<"]

    X, Y, Sources, Targets = create_data(de_sents, en_sents)
    return X, Y

# Load the test set, preprocess it, and return fixed-length ID sentences
def load_test_data():
    def _refine(line):
        line = regex.sub("<[^>]+>", "", line) # remove every non-empty <...> tag
        line = regex.sub("[^\s\p{Latin}']", "", line)
        return line.strip() # strip the given characters (whitespace by default) from both ends of the string
    # Read the source- and target-language texts, split them into lines, and process them; only lines whose first 4 characters are "<seg" (the XML segment lines) are kept