对Pytorch的Seq2Seq这6篇论文进行精读,今天重新开始,第一篇,《Sequence to Sequence Learning with Neural Networks》Sutskever, I., O. Vinyals and Q.V. Le, Sequence to Sequence Learning with Neural Networks. 2014.
Google发表于2014年,全文链接
摘要
虽然DNN很牛逼,但是仍然无法完成从句子到句子的映射。
这篇论文提出一个通用端到端学习方法,对序列结构做出最小假设。
- 结构,使用LSTM将输入句子映射为一个固定维度(fixed dimensionality / fixed-sized)向量。使用另一个LSTM对向量进行解码。
- 结果,使用WMT-14数据集的英-法翻译任务
- 模型在长句上没有遇到困难
最后发现,在源句中颠倒单词顺序,能够提高LSTM的成绩,因为这种操作会在源语句和目标句子之间引入许多短依赖关系,这使得优化问题更容易(似乎现在NLP的GAN有一项就是调整语序)。
1. 介绍
略过
2. 模型
输入句子
输出句子
标准RNN可以通过迭代下面的公式来计算输出序列
但是问题来了,如何应对输入和输入长度不一样的句子?并且句子具有复杂的关系
一个解决方法:使用RNN将输入句子映射到一个固定长度的向量中,然后使用另一个RNN将向量映射到目标句子。但是这样的模型在长文本中进行训练是困难的,幸好LSTM出现了。
LSTM的目标就是估算条件概率,其中输入句子,输出句子为什么是,是因为输出和输入的长度可能不一样。
在方程中,每个p分布都是使用softmax处理词汇中所有单词结果来表示,每个句子都有一个的结束符,这样就能确定句子的长度。
实际模型在三个重要方面与上述描述不同。
- 首先,使用了两个不同的LSTM:一个用于输入序列,另一个用于输出序列。
- 其次,发现深LSTM明显优于浅LSTM,因此选择了具有四层的LSTM。
- 第三,发现扭转输入句子的单词顺序是非常有价值的。
3. 模型的实现
这张图是流程图,输入德语“guten morgen”,在绿色的encoder中被编码为一个一个词,在句首和句尾增加作为标签。
- 每一个时间步,encoder的输入是当前单词和上一时间步的隐藏状态
- 每一个时间步,encoder的输出是新的隐藏状态
可以将隐藏状态当成表示句子的向量。这样公式就出来了。
这里的RNN可以是任何卷积结构(LSTM或是GRU)。
当输入句子最后一个单词传入RNN后,这时的隐藏状态就是上下文向量,在这里表示为就是示意图中中间的那个z。
有了向量,可以开始对目标句子进行解码,生成目标语言的句子。这样decoder的公式也有了。
在decoder中,我们从隐藏状态转到实际单词,每一个时间步都使用来进行预测
。
4. 模型代码
4.1 引入相关库
import torch
import torch.nn as nn
import torch.optim as optim
from torchtext.datasets import TranslationDataset, Multi30k
from torchtext.data import Field, BucketIterator
import spacy
import random
import math
import time
# 设定SEED,让之后随机数生成一致
SEED=1234
random.seed(SEED)
torch.manual_seed(SEED)
# torch.backends.cudnn.benchmark = True 在程序刚开始加这条语句可以提升一点训练速度,没什么额外开销。
torch.backends.cudnn.deterministic=True
spacy_de=spacy.load('de')
spacy_en=spacy.load('en')
def tokenize_de(text):
# 使用[::-1]将文本进行倒序排列
return [tok.text for tok in spacy_de.tokenizer(text)][::-1]
def tokenize_en(text):
return [tok.text for tok in spacy_en.tokenizer(text)]
SRC=Field(
tokenize=tokenize_de,
init_token='',
eos_token='',
lower=True
)
TRG=Field(
tokenize=tokenize_en,
init_token='',
eos_token='',
lower=True
)
train_data, valid_data, test_data=Multi30k.splits(exts=('.de','.en'),fields=(SRC,TRG))
print(f"Number of training examples: {len(train_data.examples)}")
print(f"Number of validation examples: {len(valid_data.examples)}")
print(f"Number of testing examples: {len(test_data.examples)}")
print(vars(train_data.examples[1]))
SRC.build_vocab(train_data,min_freq=2)
TRG.build_vocab(train_data,min_freq=2)
print(f"Unique tokens in source (de) vocabulary: {len(SRC.vocab)}")
print(f"Unique tokens in target (en) vocabulary: {len(TRG.vocab)}")
Number of training examples: 29000
Number of validation examples: 1014
Number of testing examples: 1000
{'src': ['.', 'antriebsradsystem', 'ein', 'bedienen', 'schutzhelmen', 'mit', 'männer', 'mehrere'], 'trg': ['several', 'men', 'in', 'hard', 'hats', 'are', 'operating', 'a', 'giant', 'pulley', 'system', '.']}
Unique tokens in source (de) vocabulary: 7855
Unique tokens in target (en) vocabulary: 5893
device=torch.device('cuda' if torch.cuda.is_available else 'cpu')
print(device)
BATCH_SIZE=128
train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
(train_data, valid_data, test_data),
batch_size = BATCH_SIZE,
device = device)
4.2 构建模型
在前面完成对数据的处理后,现在开始构建模型,我们之前也这么做了,按照教程一步步走下来,证明是可以的。
但是,如果要你再写一遍的话,你会发现依旧写不出来,问题在哪里?
我个人感觉是对模型吃的不透,包括传入数据的结构和数据的处理流程。所以,这里选择这个只有三个模块的seq2seq来研究。
模型包括三个部分,encoder、decoder和seq2seq(整合部分)。seq2seq的操作流程:
- 使用RNN(LSTM/GRU)对输入的句子(源语)进行编码,生成独立向量
- 独立向量就是上下文向量,可以把这个上下文向量作为输入句子的抽象表示
- 由第二个RNN(LSTM/GRU)对独立向量进行解码,通过一次生成一个字来学习输出目标句子
实现也会分成三个模块(encoder、decoder、seq2seq)来实现。在之前我们都会按照encoder->decoder->seq2seq的顺序来做,这样复合从具体到抽象的逻辑,但是我个人感觉搞到最后seq2seq的时候一头雾水,对输入的数据结构不了解。
这次换一下,从模型训练参数->seq2seq->encoder->decoder,我们看看搞进去的数据是什么样子。
4.2.1 模型配置
INPUT_DIM = len(SRC.vocab) # 模型输入维度,输入encoder的one-hot向量维度,就是根据源语数据集搞出来的词汇表中单词个数
# print(len(SRC.vocab))
OUTPUT_DIM = len(TRG.vocab) # 模型输出维度,输入到Decoder的one-hot向量,就是根据目标语数据集搞出来的词汇表单词个数
# print(len(TRG.vocab))
ENC_EMB_DIM = 256 # encoder的嵌入层维度,将one-hot向量转为密度向量
DEC_EMB_DIM = 256 # decoder的嵌入层维度,将one-hot向量转为密度向量
HID_DIM = 512 # 隐藏层和cell状态维度
N_LAYERS = 4 # 搞一个四层的
ENC_DROPOUT = 0.5
DEC_DROPOUT = 0.5
在搞定数据输入之后我们可以来看一下这次处理的数据样式。使用enumerate函数来枚举train_data和train_iterator。
对于一个可迭代的(iterable)/可遍历的对象(如列表、字符串),enumerate将其组成一个索引序列,利用它可以同时获得索引和值。enumerate多用于在for循环中得到计数
可以看到train_data和train_iterator中所包含的数据是两种类型(list和tensor)
在之前的处理中,我们使用Multi30k的splits对SRC和TRG进行了训练、测试、验证集划分处理,生成train_data等三个数据集。
使用BucketIterator.splits对train_data进行处理
BucketIterator是torchtext最强大的功能之一。它会自动将输入序列进行shuffle并做bucket。
这个功能强大的原因是——正如我前面提到的——我们需要填充输入序列使得长度相同才能批处理。
这里介绍一个小技巧,在anaconda中,经常会因为打印信息长度问题只保留头尾,可以使用set_printoptions方法,这个方法在pandas、numpy、torch中都有。
tensor可以使用shape来查看tensor的行数和列数,这里是28行、128列,也就是说batch_size就是列数,目标句子长度max_len就是行数。
torch.set_printoptions(threshold = 1e6)
for i , batch in enumerate(train_data):
if i <1:
print(i)
src=batch.src
trg=batch.trg
print(type(src))
print(src)
# print(trg)
else: break
for i , batch in enumerate(train_iterator):
if i <1:
print(i)
src=batch.src
trg=batch.trg
print(type(src))
print(src.shape)
print(src)
print(src.shape[0])
print(src.shape[1])
# print(trg)
else: break
0
['.', 'büsche', 'vieler', 'nähe', 'der', 'in', 'freien', 'im', 'sind', 'männer', 'weiße', 'junge', 'zwei']
0
torch.Size([29, 128])
tensor([[ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2],
[ 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
4, 4, 4, 4, 4, 4, 4, 4, 29, 4, 4, 4,
4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
4, 4, 4, 4, 4, 4, 4, 4],
[ 0, 3810, 65, 344, 762, 90, 4062, 368, 681, 0, 5343, 2391,
118, 92, 507, 2615, 63, 34, 21, 1522, 0, 1640, 715, 63,
837, 2779, 117, 72, 647, 23, 0, 215, 90, 21, 344, 54,
3034, 0, 29, 7436, 5, 48, 299, 638, 297, 1114, 235, 933,
1692, 6879, 48, 3367, 693, 122, 476, 235, 1534, 72, 286, 0,
375, 301, 81, 141, 4703, 1716, 1803, 7475, 1101, 0, 271, 60,
84, 72, 1674, 29, 0, 181, 113, 123, 1188, 27, 581, 3386,
3119, 4317, 328, 228, 186, 1089, 1921, 90, 21, 0, 11, 6836,
547, 4583, 80, 90, 5438, 916, 4559, 481, 21, 21, 932, 3389,
2248, 2044, 233, 3826, 740, 230, 0, 389, 80, 343, 420, 34,
1628, 316, 2475, 12, 129, 58, 1738, 63],
[ 82, 189, 105, 0, 197, 17, 102, 33, 248, 282, 149, 8,
20, 525, 14, 6, 126, 1576, 379, 14, 8, 6, 118, 254,
2700, 19, 2690, 34, 5, 839, 18, 19, 17, 6882, 7429, 1049,
419, 168, 629, 529, 248, 2316, 5, 14, 86, 151, 39, 1462,
11, 1956, 144, 8, 116, 24, 10, 304, 4360, 34, 5, 0,
139, 24, 15, 6, 1217, 334, 1020, 555, 75, 7, 61, 117,
144, 1488, 17, 2611, 34, 67, 14, 20, 11, 882, 12, 28,
8, 506, 638, 1947, 405, 46, 19, 17, 1957, 19, 352, 4268,
55, 2151, 22, 17, 19, 8, 139, 1900, 498, 679, 77, 6,
7691, 6, 21, 19, 0, 14, 10, 17, 1841, 178, 726, 14,
1452, 14, 423, 4963, 38, 1649, 2151, 126],
[ 6, 82, 428, 22, 11, 7, 19, 7, 5, 2527, 3704, 42,
1102, 59, 7, 12, 314, 14, 6, 7, 42, 12, 20, 562,
1995, 61, 574, 17, 7, 6937, 61, 61, 58, 0, 22, 28,
14, 2442, 27, 37, 5, 1116, 37, 21, 20, 547, 42, 6,
295, 6, 28, 12, 5, 12, 107, 110, 19, 17, 184, 14,
0, 12, 7, 7, 151, 64, 919, 8, 553, 0, 87, 176,
28, 8, 12, 1142, 17, 330, 22, 73, 13, 19, 53, 3950,
75, 2254, 14, 8, 6, 19, 608, 7, 8, 68, 44, 172,
1704, 5, 36, 7, 85, 85, 260, 7, 170, 19, 42, 21,
5, 21, 3005, 238, 5, 7, 3936, 85, 43, 6, 14, 7,
6, 12, 15, 5, 10, 170, 6, 115],
[ 12, 14, 658, 9, 293, 4525, 7, 146, 15, 19, 5272, 4891,
19, 38, 143, 86, 6, 12, 21, 1476, 74, 0, 57, 3439,
6, 138, 6, 0, 183, 11, 9, 9, 433, 10, 108, 183,
27, 1888, 200, 520, 58, 39, 300, 0, 116, 304, 347, 22,
6, 11, 1668, 293, 27, 3404, 6, 15, 15, 60, 10, 7,
10, 164, 7504, 23, 1939, 17, 7, 415, 14, 14, 24, 98,
4148, 124, 13, 248, 99, 8, 2224, 26, 1184, 534, 10, 0,
85, 14, 21, 91, 12, 7, 435, 119, 16, 10, 1252, 41,
29, 12, 14, 2072, 2110, 823, 20, 31, 627, 335, 1564, 183,
17, 159, 33, 40, 166, 38, 101, 619, 11, 12, 21, 1191,
126, 729, 22, 136, 140, 10, 12, 8],
[ 53, 11, 6, 154, 19, 139, 856, 9, 9, 728, 15, 127,
68, 13, 2853, 20, 7, 29, 0, 127, 9, 6, 80, 14,
7, 11, 21, 11, 13, 16, 48, 48, 6, 260, 10, 10,
23, 687, 4227, 0, 7181, 9, 4629, 19, 610, 3597, 1296, 108,
21, 101, 8, 24, 29, 57, 21, 7, 9, 1451, 2031, 13,
966, 5, 19, 220, 193, 9, 2680, 0, 11, 21, 12, 325,
529, 10, 4404, 20, 35, 205, 48, 1182, 6, 7725, 23, 12,
1448, 7, 23, 55, 3018, 75, 5, 231, 94, 90, 33, 101,
1590, 23, 27, 27, 23, 335, 898, 41, 0, 3590, 19, 120,
9, 40, 1462, 206, 15, 40, 86, 14, 449, 74, 23, 127,
3127, 5, 108, 15, 31, 107, 23, 37],
[ 30, 721, 7, 0, 509, 598, 10, 2428, 26, 24, 10, 11,
89, 5, 11, 16, 752, 381, 19, 0, 84, 11, 43, 22,
6062, 10, 62, 137, 5, 8, 1123, 358, 12, 0, 906, 29,
185, 5, 6, 7, 6, 195, 199, 68, 5, 33, 6, 1037,
442, 15, 1895, 11, 32, 1348, 31, 31, 285, 11, 8, 5,
83, 37, 37, 87, 205, 73, 2431, 11, 15, 279, 431, 35,
86, 213, 6, 17, 9, 39, 10, 2682, 47, 10, 185, 464,
10, 93, 307, 17, 6, 1029, 3, 37, 14, 15, 12, 35,
5, 93, 5150, 2835, 68, 13, 4833, 22, 7, 102, 85, 51,
5178, 341, 24, 11, 9, 71, 20, 59, 6, 25, 37, 12,
11, 288, 10, 12, 8, 19, 696, 13],
[ 43, 6, 13, 357, 10, 9, 178, 0, 5, 11, 115, 261,
251, 3, 31, 8, 0, 14, 1129, 25, 4351, 212, 3, 69,
167, 89, 618, 25, 3, 10, 577, 285, 845, 11, 8, 1413,
45, 3, 21, 25, 11, 6, 6, 503, 35, 565, 12, 6,
2099, 10, 733, 359, 2845, 18, 13, 54, 251, 450, 21, 3,
25, 370, 2300, 6, 22, 76, 24, 79, 9, 45, 15, 9,
20, 474, 759, 9, 896, 9, 13, 644, 15, 82, 246, 1221,
438, 16, 73, 9, 7, 616, 1, 10, 283, 7, 10, 9,
3, 3397, 705, 1140, 26, 5, 0, 36, 30, 6, 29, 20,
6, 6, 0, 13, 246, 6, 9, 38, 7, 26, 79, 53,
824, 9, 87, 488, 10, 64, 25, 5],
[ 3, 7, 5, 11, 63, 37, 33, 11, 3, 145, 17, 80,
11, 1, 13, 9, 5, 7, 30, 177, 0, 16, 1, 13,
30, 82, 5, 18, 1, 13, 10, 10, 25, 16, 58, 51,
18, 1, 15, 66, 824, 12, 7, 5, 9, 49, 54, 7,
11, 2796, 21, 1623, 5, 3, 5, 22, 11, 50, 23, 1,
66, 10, 50, 12, 221, 85, 12, 203, 3137, 18, 3, 140,
35, 14, 16, 7620, 765, 0, 19, 7, 9, 769, 6, 15,
664, 8, 196, 16, 62, 46, 1, 164, 16, 0, 458, 104,
1, 12, 12, 14, 70, 3, 5, 8, 18, 11, 618, 13,
12, 7, 208, 96, 6, 7, 3321, 4103, 29, 43, 82, 544,
1461, 731, 6, 25, 81, 53, 196, 3],
[ 1, 167, 3, 34, 254, 129, 1506, 17, 1, 35, 7, 43,
26, 1, 5, 0, 10, 13, 18, 5, 7582, 8, 1, 96,
18, 14, 3, 3, 1, 96, 0, 40, 66, 101, 83, 14,
3, 1, 9, 5, 14, 31, 16, 3, 262, 5, 600, 13,
52, 112, 9, 11, 3, 1, 3, 36, 490, 6, 287, 1,
5, 260, 10, 1051, 24, 129, 74, 6, 1105, 3, 1, 81,
9, 11, 274, 8, 12, 75, 12, 3, 13, 7, 7, 12,
52, 3, 0, 8, 13, 11, 1, 6, 8, 48, 137, 6,
1, 2016, 544, 12, 5, 1, 270, 3, 3, 13, 5, 15,
25, 7595, 38, 5, 7, 959, 5, 5, 16, 3, 6, 18,
14, 6750, 12, 43, 15, 54, 199, 1],
[ 1, 25, 1, 8, 69, 8, 3929, 9, 1, 9, 170, 3,
5, 1, 3, 14, 25, 5, 3, 3, 11, 3, 1, 5,
3, 7, 1, 1, 1, 5, 15, 46, 5, 8, 32, 7,
1, 1, 40, 3, 7, 906, 8, 1, 16, 10, 3, 5,
4016, 14, 104, 25, 1, 1, 1, 8, 5, 7, 13, 1,
3, 1975, 40, 5, 59, 38, 73, 7, 15, 1, 1, 15,
34, 263, 8, 3, 57, 12, 69, 1, 130, 488, 41, 0,
65, 1, 3, 3, 5, 13, 1, 11, 3, 1797, 54, 27,
1, 225, 3, 261, 3, 1, 26, 1, 1, 5, 3, 3,
18, 967, 65, 3, 52, 5, 670, 3, 8, 1, 7, 3,
7, 11, 141, 3, 7, 216, 39, 1],
[ 1, 5, 1, 74, 32, 15, 8, 54, 1, 90, 30, 1,
3, 1, 1, 21, 66, 3, 1, 1, 17, 1, 1, 3,
1, 13, 1, 1, 1, 3, 9, 6, 3, 95, 5, 13,
1, 1, 1844, 1, 2784, 15, 3, 1, 196, 16, 1, 3,
175, 11, 78, 66, 1, 1, 1, 3, 3, 16, 5, 1,
1, 11, 51, 3, 9, 40, 3, 13, 9, 1, 1, 7,
17, 0, 3, 1, 41, 347, 9, 1, 5, 1860, 76, 22,
103, 1, 1, 1, 3, 5, 1, 38, 1, 5, 76, 29,
1, 11, 1, 16, 1, 1, 5, 1, 1, 3, 1, 1,
3, 8, 36, 1, 45, 3, 20, 1, 3, 1, 164, 1,
60, 13, 6, 1, 461, 3, 3, 1],
[ 1, 10, 1, 54, 114, 9, 3, 22, 1, 1427, 18, 1,
1, 1, 1, 1117, 5, 1, 1, 1, 9, 1, 1, 1,
1, 5, 1, 1, 1, 1, 13, 7, 1, 250, 3, 5,
1, 1, 202, 1, 13, 7, 1, 1, 232, 8, 1, 1,
7, 14, 6, 5, 1, 1, 1, 1, 1, 8, 3, 1,
1, 13, 6, 1, 318, 586, 1, 5, 298, 1, 1, 898,
539, 5, 1, 1, 76, 65, 48, 1, 3, 105, 3, 36,
3, 1, 1, 1, 1, 10, 1, 13, 1, 3, 3, 16,
1, 26, 1, 8, 1, 1, 3, 1, 1, 1, 1, 1,
1, 3, 8, 1, 43, 1, 9, 1, 1, 1, 5, 1,
116, 5, 59, 1, 5, 1, 1, 1],
[ 1, 80, 1, 3, 5, 13, 1, 36, 1, 844, 3, 1,
1, 1, 1, 670, 3, 1, 1, 1, 1415, 1, 1, 1,
1, 3, 1, 1, 1, 1, 5, 39, 1, 82, 1, 3,
1, 1, 11, 1, 865, 13, 1, 1, 8, 3, 1, 1,
80, 9, 27, 3, 1, 1, 1, 1, 1, 3, 1, 1,
1, 5, 7, 1, 5, 7, 1, 3, 5, 1, 1, 6,
694, 61, 1, 1, 3, 18, 1791, 1, 1, 3, 1, 8,
1, 1, 1, 1, 1, 16, 1, 5, 1, 1, 1, 8,
1, 70, 1, 10, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 3, 1, 3, 1, 740, 1, 1, 1, 3, 1,
5, 3, 62, 1, 93, 1, 1, 1],
[ 1, 18, 1, 1, 3, 5, 1, 8, 1, 180, 1, 1,
1, 1, 1, 20, 1, 1, 1, 1, 4252, 1, 1, 1,
1, 1, 1, 1, 1, 1, 3, 9, 1, 6, 1, 1,
1, 1, 13, 1, 5, 5, 1, 1, 10, 1, 1, 1,
103, 332, 29, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 3, 16, 1, 38, 13, 1, 1, 3, 1, 1, 11,
10, 2045, 1, 1, 1, 3, 2040, 1, 1, 1, 1, 3,
1, 1, 1, 1, 1, 8, 1, 3, 1, 1, 1, 3,
1, 5, 1, 13, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1,
3, 1, 13, 1, 8, 1, 1, 1],
[ 1, 3, 1, 1, 1, 3, 1, 3, 1, 457, 1, 1,
1, 1, 1, 13, 1, 1, 1, 1, 2479, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 49, 1, 10, 1, 1,
1, 1, 96, 1, 3, 3, 1, 1, 13, 1, 1, 1,
76, 199, 113, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 890, 1, 3946, 5, 1, 1, 1, 1, 1, 23,
219, 734, 1, 1, 1, 1, 7, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 3, 1, 1, 1, 1, 1, 1,
1, 3, 1, 5, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 5, 1, 1, 1, 1, 1,
1, 1, 5, 1, 9, 1, 1, 1],
[ 1, 1, 1, 1, 1, 1, 1, 1, 1, 7, 1, 1,
1, 1, 1, 5, 1, 1, 1, 1, 7, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 5, 1, 260, 1, 1,
1, 1, 5, 1, 1, 1, 1, 1, 229, 1, 1, 1,
3, 868, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 8, 1, 8, 3, 1, 1, 1, 1, 1, 712,
11, 5, 1, 1, 1, 1, 812, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 166, 1, 1, 1, 1, 1,
1, 1, 3, 1, 549, 1, 1, 1],
[ 1, 1, 1, 1, 1, 1, 1, 1, 1, 13, 1, 1,
1, 1, 1, 3, 1, 1, 1, 1, 45, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 3, 1, 742, 1, 1,
1, 1, 3, 1, 1, 1, 1, 1, 234, 1, 1, 1,
1, 30, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 3, 1, 3, 1, 1, 1, 1, 1, 1, 49,
16, 3, 1, 1, 1, 1, 15, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 15, 1, 1, 1, 1, 1,
1, 1, 1, 1, 6, 1, 1, 1],
[ 1, 1, 1, 1, 1, 1, 1, 1, 1, 5, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 18, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 5, 1, 1, 1,
1, 18, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 5,
8, 1, 1, 1, 1, 1, 9, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 9, 1, 1, 1, 1, 1,
1, 1, 1, 1, 7, 1, 1, 1],
[ 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 11, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 1, 1,
1, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3,
3, 1, 1, 1, 1, 1, 13, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 32, 1, 1, 1, 1, 1,
1, 1, 1, 1, 862, 1, 1, 1],
[ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 5, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 6, 1, 1, 1, 1, 1,
1, 1, 1, 1, 53, 1, 1, 1],
[ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 8, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 3, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 47, 1, 1, 1, 1, 1,
1, 1, 1, 1, 41, 1, 1, 1],
[ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 29, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1021, 1, 1, 1],
[ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 3938, 1, 1, 1, 1, 1,
1, 1, 1, 1, 3, 1, 1, 1],
[ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 7, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1],
[ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 13, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1],
[ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 5, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1],
[ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 3, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0')
29
128
4.2.2 Seq2Seq
主要功能:
- 接收输入/源句子
- 使用Encoder生成上下文向量
-
使用Decoder生成预测输出/目标句子 再看一下整体的模型
确定encoder和decoder每一层的数目、隐藏层。
下面是实现代码
# Seq2Seq
class Seq2Seq(nn.Module):
def __init__(self, encoder,decoder,device):
super(Seq2Seq,self).__init__()
self.encoder=encoder
self.decoder=decoder
self.device=device
assert encoder.hid_dim==decoder.hid_dim, "Hidden dimensions of encoder and decoder must be equal!"
assert encoder.n_layers==decoder.n_layers, "Num_Layers of encoder and decoder must be equal!"
def forward(self, src, trg, teacher_forcing_ratio=0.5):
# src = [src sent len, batch size]
# trg = [trg sent len, batch size]
# teacher_forcing_ratio是使用教师强制的概率
# 例如。如果teacher_forcing_ratio是0.75,我们75%的时间使用groundtruth输入
batch_size=trg.shape[1]
max_len=trg.shape[0]
trg_vocab_size=self.decoder.output_dim
# 创建输出张量,存储我们所有的预测
outputs=torch.zeros(max_len,batch_size,trg_vocab_size).to(self.device)
# 输入源语到encoder,然后获取最终的隐藏和单元状态
hidden, cell=self.encoder(src)
# decoder第一个输入的是句子的最开始的token,也就是那个标记,
input=trg[0,:]
# max_len就是行数
for t in range(1, max_len):
# 将输入,先前隐藏和前一个单元状态传递给Decoder
# 接收预测,来自Decoder下一个隐藏状态和下一个单元状态
output, hidden, cell=self.decoder(input,hidden,cell)
# 将我们的预测,输出放在我们的预测张量中
outputs[t]=output
teacher_force = random.random() < teacher_forcing_ratio
top1 = output.max(1)[1]
input = (trg[t] if teacher_force else top1)
return outputs
可以看到在Seq2Seq这个模块的作用是整合encoder和decoder,将两者的数据打通,batch_size的大小就是在句子中一个词转为向量的长度"shape[1]",而max_len就是这个句子所包含的词的个数。这个在以后的模型中也会很重要,每次我都会注意一下。
4.2.3 encoder
请注意,我们只将第一层的隐藏状态作为输入传递给第二层,而不是单元状态。
下面重点来了,encoder有哪些参数其实在最开始的参数设置里面就可以看到这些。
- input_dim输入encoder的one-hot向量维度,这个和输入词汇大小一致
- emb_dim嵌入层的维度,这一层将one-hot向量转为密度向量
- hid_dim隐藏层和cell状态维度
- n_layersRNN的层数
- dropout是要使用的丢失量。这是一个防止过度拟合的正则化参数。
没什么特别的地方。
class Encoder(nn.Module):
def __init__(self, input_dim, emb_dim, hid_dim, n_layers, dropout):
super(Encoder, self).__init__()
self.input_dim=input_dim
self.emb_dim=emb_dim
self.hid_dim=hid_dim
self.n_layers=n_layers
self.dropout=dropout
self.embedding=nn.Embedding(input_dim,emb_dim)
self.rnn=nn.LSTM(emb_dim,hid_dim,n_layers,dropout=dropout)
self.dropout=nn.Dropout(dropout)
def forward(self, src):
embedded=self.dropout(self.embedding(src))
outputs, (hidden,cell)=self.rnn(embedded)
return hidden, cell
4.2.4 decoder
Decoder同样也是一个LSTM。
Decoder的初始隐藏和单元状态是我们的上下文向量,它们是来自同一层的Encoder的最终隐藏和单元状态。
接下来将隐藏状态传递给Linear层,预测目标序列下一个标记应该是什么。
Decoder的参数和Encoder类似,其中output_dim是将要输入到Decoder的one-hot向量。
- 在forward方法中,获取到了输入token、上一层的隐藏状态和单元状态。解压之后加入句子长度维度。
- 接下来与Encoder类似,传入嵌入层并使用dropout,然后将这批嵌入式令牌传递到具有先前隐藏和单元状态的RNN。这产生了一个输出(来自RNN顶层的隐藏状态),一个新的隐藏状态(每个层一个,堆叠在彼此之上)和一个新的单元状态(每层也有一个,堆叠在彼此的顶部))。
- 然后我们通过线性层传递输出(在除去句子长度维度之后)以接收我们的预测。然后我们返回预测,新的隐藏状态和新的单元状态。
class Decoder(nn.Module):
def __init__(self, output_dim, emb_dim, hid_dim, n_layers, dropout):
super(Decoder, self).__init__()
self.emb_dim=emb_dim
self.hid_dim=hid_dim
self.output_dim=output_dim
self.n_layers=n_layers
self.dropout=dropout
self.embedding=nn.Embedding(output_dim, emb_dim)
self.rnn=nn.LSTM(emb_dim, hid_dim, n_layers, dropout=dropout)
self.out=nn.Linear(hid_dim, output_dim)
self.dropout=nn.Dropout(dropout)
# 这里的hidden, cell是encoder输出的结果
def forward(self, input, hidden, cell):
input=input.unsqueeze(0)
embedded=self.dropout(self.embedding(input))
output, (hidden,cell)=self.rnn(embedded,(hidden,cell))
prediction=self.out(output.squeeze(0))
# 这里输出的prediction就是预测数据
return prediction, hidden, cell
4.3 训练模型
enc=Encoder(INPUT_DIM,ENC_EMB_DIM,HID_DIM,N_LAYERS,ENC_DROPOUT)
dec=Decoder(OUTPUT_DIM,DEC_EMB_DIM,HID_DIM,N_LAYERS,DEC_DROPOUT)
model=Seq2Seq(enc,dec,device).to(device)
model
Seq2Seq(
(encoder): Encoder(
(embedding): Embedding(7855, 256)
(rnn): LSTM(256, 512, num_layers=4, dropout=0.5)
(dropout): Dropout(p=0.5)
)
(decoder): Decoder(
(embedding): Embedding(5893, 256)
(rnn): LSTM(256, 512, num_layers=4, dropout=0.5)
(out): Linear(in_features=512, out_features=5893, bias=True)
(dropout): Dropout(p=0.5)
)
)
def init_weights(m):
for name, param in m.named_parameters():
nn.init.uniform_(param.data, -0.08, 0.08)
model.apply(init_weights)
Seq2Seq(
(encoder): Encoder(
(embedding): Embedding(7855, 256)
(rnn): LSTM(256, 512, num_layers=4, dropout=0.5)
(dropout): Dropout(p=0.5)
)
(decoder): Decoder(
(embedding): Embedding(5893, 256)
(rnn): LSTM(256, 512, num_layers=4, dropout=0.5)
(out): Linear(in_features=512, out_features=5893, bias=True)
(dropout): Dropout(p=0.5)
)
)
def count_parameters(model):
# pytorch.numel返回矩阵内所有元素个数
return sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'The model has {count_parameters(model):,} trainable parameters')
The model has 22,304,005 trainable parameters
optimizer = optim.Adam(model.parameters())
# stoi允许访问包含单词及其索引的字典。
PAD_IDX = TRG.vocab.stoi['']
criterion = nn.CrossEntropyLoss(ignore_index = PAD_IDX)
def train(model, iterator, optimizer, criterion, clip):
model.train()
epoch_loss = 0
for i, batch in enumerate(iterator):
src = batch.src
trg = batch.trg
optimizer.zero_grad()
output = model(src, trg)
#trg = [trg sent len, batch size]
#output = [trg sent len, batch size, output dim]
output = output[1:].view(-1, output.shape[-1])
trg = trg[1:].view(-1)
#trg = [(trg sent len - 1) * batch size]
#output = [(trg sent len - 1) * batch size, output dim]
loss = criterion(output, trg)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
optimizer.step()
epoch_loss += loss.item()
return epoch_loss / len(iterator)
def evaluate(model, iterator, criterion):
model.eval()
epoch_loss = 0
with torch.no_grad():
for i, batch in enumerate(iterator):
src = batch.src
trg = batch.trg
output = model(src, trg, 0) #turn off teacher forcing
#trg = [trg sent len, batch size]
#output = [trg sent len, batch size, output dim]
output = output[1:].view(-1, output.shape[-1])
trg = trg[1:].view(-1)
#trg = [(trg sent len - 1) * batch size]
#output = [(trg sent len - 1) * batch size, output dim]
loss = criterion(output, trg)
epoch_loss += loss.item()
return epoch_loss / len(iterator)
def epoch_time(start_time, end_time):
elapsed_time = end_time - start_time
elapsed_mins = int(elapsed_time / 60)
elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
return elapsed_mins, elapsed_secs
N_EPOCHS = 2
CLIP = 1
best_valid_loss = float('inf')
for epoch in range(N_EPOCHS):
start_time = time.time()
train_loss = train(model, train_iterator, optimizer, criterion, CLIP)
# valid_loss = evaluate(model, valid_iterator, criterion)
end_time = time.time()
epoch_mins, epoch_secs = epoch_time(start_time, end_time)
if valid_loss < best_valid_loss:
best_valid_loss = valid_loss
torch.save(model.state_dict(), 'tut1-model.pt')
print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
print(f'\t Val. Loss: {valid_loss:.3f} | Val. PPL: {math.exp(valid_loss):7.3f}')
>Epoch: 01 | Time: 1m 57s
Train Loss: 4.981 | Train PPL: 145.638
Val. Loss: 4.938 | Val. PPL: 139.557
Epoch: 02 | Time: 1m 57s
Train Loss: 4.690 | Train PPL: 108.829
Val. Loss: 4.950 | Val. PPL: 141.169
Epoch: 03 | Time: 1m 56s
Train Loss: 4.421 | Train PPL: 83.212
Val. Loss: 4.642 | Val. PPL: 103.731
Epoch: 04 | Time: 1m 57s
Train Loss: 4.187 | Train PPL: 65.833
Val. Loss: 4.560 | Val. PPL: 95.608
Epoch: 05 | Time: 1m 57s
Train Loss: 4.045 | Train PPL: 57.138
Val. Loss: 4.429 | Val. PPL: 83.808
Epoch: 06 | Time: 1m 56s
Train Loss: 3.939 | Train PPL: 51.373
Val. Loss: 4.400 | Val. PPL: 81.460
Epoch: 07 | Time: 1m 56s
Train Loss: 3.862 | Train PPL: 47.579
Val. Loss: 4.370 | Val. PPL: 79.046
Epoch: 08 | Time: 1m 57s
Train Loss: 3.755 | Train PPL: 42.738
Val. Loss: 4.369 | Val. PPL: 78.992
Epoch: 09 | Time: 1m 56s
Train Loss: 3.672 | Train PPL: 39.322
Val. Loss: 4.223 | Val. PPL: 68.248
Epoch: 10 | Time: 1m 57s
Train Loss: 3.622 | Train PPL: 37.402
Val. Loss: 4.201 | Val. PPL: 66.773