学习目标
我们会使用 torchtext 来创建词表(vocabulary),然后把数据读成 batch 的格式,并在此基础上训练一个基于 LSTM 的语言模型。torchtext 的具体用法请自行阅读其 README。
import torchtext
from torchtext.vocab import Vectors
import torch
import numpy as np
import random
# Whether a CUDA device is available; reused below to pick the device and seed the GPU RNG.
EXIST_CUDA = torch.cuda.is_available()


def _seed_everything(seed):
    """Seed Python's, NumPy's and PyTorch's RNGs so that runs are reproducible."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if EXIST_CUDA:
        torch.cuda.manual_seed(seed)


# Fixing the random seed to a constant makes results repeatable across runs.
_seed_everything(1)

BATCH_SIZE = 32  # batch size: number of sequences (columns) per batch
EMBEDDING_SIZE = 100  # each word is embedded into a 100-dimensional vector
MAX_VOCAB_SIZE = 50000  # cap on the number of most-frequent words kept in the vocabulary
# Field describing how raw text is tokenized; every token is lower-cased.
TEXT = torchtext.data.Field(lower=True)

# Build the train / validation / test datasets from the text8 split files.
DATA_DIR = "datas/3/text8/"
_SPLIT_FILES = {
    "train": "text8.train.txt",
    "validation": "text8.dev.txt",
    "test": "text8.test.txt",
}
train, val, test = torchtext.datasets.LanguageModelingDataset.splits(
    path=DATA_DIR,
    text_field=TEXT,
    **_SPLIT_FILES,
)

# Vocabulary of at most MAX_VOCAB_SIZE of the most frequent training words.
TEXT.build_vocab(train, max_size=MAX_VOCAB_SIZE)

device = torch.device("cuda" if EXIST_CUDA else "cpu")

# BPTT iterators: each batch has BATCH_SIZE columns, cut into chunks of bptt_len=50
# tokens; batch.target is batch.text shifted by one word.
train_iter, val_iter, test_iter = torchtext.data.BPTTIterator.splits(
    (train, val, test),
    bptt_len=50,
    batch_size=BATCH_SIZE,
    device=device,
    shuffle=True,
    repeat=False,
)
模型的输入是一串文字,模型的输出也是一串文字,它们之间相差一个位置(目标序列就是输入序列向后平移一个单词),因为语言模型的目标是根据之前的单词预测下一个单词。
import torch.nn as nn
class RNNModel(nn.Module):
    """LSTM language model: embedding -> LSTM -> linear projection onto the vocabulary."""

    def __init__(self, vocab_size, embed_size, hidden_size):
        """
        Args:
            vocab_size: number of words in the vocabulary (size of the output logits).
            embed_size: dimensionality of the word embeddings.
            hidden_size: number of LSTM hidden units.
        """
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.hidden_size = hidden_size

    def forward(self, text, hidden):
        """Run one forward pass.

        Args:
            text: LongTensor of word indices, shape (seq_len, batch_size).
            hidden: (h, c) LSTM state as returned by ``init_hidden`` or a previous call.

        Returns:
            (logits, hidden): logits of shape (seq_len, batch_size, vocab_size) and the
            updated (h, c) state.
        """
        emb = self.embed(text)  # (seq_len, batch_size, embed_size)
        output, hidden = self.lstm(emb, hidden)  # output: (seq_len, batch_size, hidden_size)
        # nn.Linear acts on the last dimension, so no flatten/unflatten is needed.
        out_vocab = self.linear(output)  # (seq_len, batch_size, vocab_size)
        return out_vocab, hidden

    def init_hidden(self, bsz, requires_grad=True):
        """Return an all-zero (h, c) LSTM state for a batch of ``bsz`` sequences.

        The tensors are created with the same dtype/device as the model parameters.
        ``requires_grad`` is honoured (it used to be ignored and always set to True).
        """
        weight = next(self.parameters())
        shape = (self.lstm.num_layers, bsz, self.hidden_size)
        return (weight.new_zeros(shape, requires_grad=requires_grad),
                weight.new_zeros(shape, requires_grad=requires_grad))
# Build the model and move it onto the selected device (a no-op on CPU).
model = RNNModel(vocab_size=len(TEXT.vocab), embed_size=EMBEDDING_SIZE, hidden_size=100)
model = model.to(device)
loss_fn = nn.CrossEntropyLoss()
learning_rate = 0.001
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
NUM_EPOCHS = 1  # number of full passes over the training data
VOCAB_SIZE = len(TEXT.vocab)
GRAD_CLIP = 5.0  # max global gradient norm passed to clip_grad_norm_
# Multiplies the learning rate by 0.5 each time scheduler.step() is called.
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.5)
# Validation-loss history. Seeded with +inf (not a fixed number such as 10) so the
# first evaluation always counts as an improvement: an untrained model's loss is
# about ln(vocab size), which exceeds 10 for a 50k-word vocabulary.
val_losses = [float("inf")]
# Evaluation
def evaluate(model, data):
    """Return the token-weighted average cross-entropy of ``model`` over ``data``.

    Args:
        model: module exposing ``init_hidden(bsz, requires_grad=...)`` and
            ``forward(text, hidden) -> (logits, hidden)``.
        data: iterable of batches with ``.text`` and ``.target`` tensors of shape
            (seq_len, batch_size); the batch size is assumed to be BATCH_SIZE.

    Returns:
        Mean per-token loss as a Python float (its exp is the perplexity).

    Raises:
        ValueError: if ``data`` yields no batches.

    The model's previous train/eval mode is restored before returning.
    """
    was_training = model.training
    model.eval()
    total_loss = 0.
    total_count = 0
    try:
        with torch.no_grad():
            # The hidden state is carried across consecutive batches, as in training.
            hidden = model.init_hidden(BATCH_SIZE, requires_grad=False)
            for batch in data:
                inputs, target = batch.text, batch.target
                hidden = repackage_hidden(hidden)
                output, hidden = model(inputs, hidden)
                loss = loss_fn(output.reshape(-1, output.size(-1)), target.reshape(-1))
                # loss_fn averages over tokens; re-weight by the token count so the
                # overall mean stays exact even when batches differ in length.
                n_tokens = inputs.numel()
                total_loss += loss.item() * n_tokens
                total_count += n_tokens
    finally:
        model.train(was_training)
    if total_count == 0:
        raise ValueError("evaluate() received no batches from `data`")
    return total_loss / total_count
def repackage_hidden(h):
    """Cut a hidden state off from the computation graph that produced it.

    ``h`` is either a tensor or an (arbitrarily nested) sequence of tensors; the same
    structure is returned, with tensors detached and sequences turned into tuples.
    Calling this between batches keeps backpropagation from reaching into earlier
    batches' history.
    """
    if not isinstance(h, torch.Tensor):
        return tuple(map(repackage_hidden, h))
    return h.detach()
import os

EVAL_EVERY = 1900  # evaluate on the validation set every this many batches
CHECKPOINT_PATH = "model/3/03lstm.pth"
# torch.save does not create missing directories, so make sure the target exists.
os.makedirs(os.path.dirname(CHECKPOINT_PATH), exist_ok=True)

for epoch in range(NUM_EPOCHS):
    model.train()
    hidden = model.init_hidden(BATCH_SIZE)
    for i, batch in enumerate(train_iter):
        data, target = batch.text, batch.target
        # Detach the carried-over state so backprop stops at this batch's start.
        hidden = repackage_hidden(hidden)
        output, hidden = model(data, hidden)
        loss = loss_fn(output.view(-1, VOCAB_SIZE), target.view(-1))
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)  # guard against exploding gradients
        optimizer.step()
        if i % 100 == 0:
            print(i, " loss:", loss.item())
        # Periodically check the validation loss; keep the parameters if it is the best so far.
        if i % EVAL_EVERY == 0:
            val_loss = evaluate(model, val_iter)
            if val_loss < min(val_losses):
                print("best model saved to 03lstm.pth")
                torch.save(model.state_dict(), CHECKPOINT_PATH)
            else:
                # Validation loss did not improve: decay the learning rate.
                print("learning_rate decay")
                scheduler.step()
            val_losses.append(val_loss)
0 loss: 5.700478553771973
best model saved to 03lstm.pth
100 loss: 5.474410533905029
200 loss: 5.539557456970215
300 loss: 5.759274482727051
400 loss: 5.686248779296875
500 loss: 5.632628917694092
600 loss: 5.481709003448486
700 loss: 5.584092617034912
800 loss: 5.7943501472473145
900 loss: 5.541199207305908
1000 loss: 5.437957763671875
1100 loss: 5.763401031494141
1200 loss: 5.438232898712158
1300 loss: 5.728765487670898
1400 loss: 5.6812005043029785
1500 loss: 5.331437587738037
1600 loss: 5.531680107116699
1700 loss: 5.482674598693848
1800 loss: 5.578347206115723
1900 loss: 5.531027317047119
learning_rate decay
2000 loss: 5.6362833976745605
2100 loss: 5.604646682739258
2200 loss: 5.438443183898926
2300 loss: 5.304264068603516
2400 loss: 5.690061092376709
2500 loss: 5.453220367431641
2600 loss: 5.441572189331055
2700 loss: 5.776185512542725
2800 loss: 5.629850387573242
2900 loss: 5.619969367980957
3000 loss: 5.5757222175598145
3100 loss: 5.772238731384277
3200 loss: 5.692197322845459
3300 loss: 5.51469612121582
3400 loss: 5.358908176422119
3500 loss: 5.429351806640625
3600 loss: 5.5990190505981445
3700 loss: 5.883382797241211
3800 loss: 5.582748889923096
learning_rate decay
3900 loss: 5.5894575119018555
4000 loss: 5.436612606048584
4100 loss: 5.603799819946289
4200 loss: 5.246464729309082
4300 loss: 5.7568840980529785
4400 loss: 5.332048416137695
4500 loss: 5.250970840454102
4600 loss: 5.414524555206299
4700 loss: 5.852789878845215
4800 loss: 5.710803031921387
4900 loss: 5.4412336349487305
5000 loss: 5.87037467956543
5100 loss: 5.393296718597412
5200 loss: 5.630399703979492
5300 loss: 5.1652703285217285
5400 loss: 5.573890209197998
5500 loss: 5.438013076782227
5600 loss: 5.229452610015869
5700 loss: 5.355339527130127
best model saved to 03lstm.pth
5800 loss: 5.6232757568359375
5900 loss: 5.606210708618164
6000 loss: 5.606449604034424
6100 loss: 5.649041652679443
6200 loss: 5.638283729553223
6300 loss: 5.740434169769287
6400 loss: 5.819083213806152
6500 loss: 5.349177837371826
6600 loss: 5.7113494873046875
6700 loss: 5.720933437347412
6800 loss: 5.368650913238525
6900 loss: 5.252537250518799
7000 loss: 5.532567977905273
7100 loss: 5.527868270874023
7200 loss: 5.364249229431152
7300 loss: 5.634284496307373
7400 loss: 5.607549667358398
7500 loss: 5.378734111785889
7600 loss: 5.748443126678467
best model saved to 03lstm.pth
7700 loss: 5.56899356842041
7800 loss: 5.3647565841674805
7900 loss: 5.424122333526611
8000 loss: 5.5352325439453125
8100 loss: 5.26278018951416
8200 loss: 5.719631195068359
8300 loss: 5.376105308532715
8400 loss: 5.5696845054626465
8500 loss: 5.4810261726379395
8600 loss: 5.4345703125
8700 loss: 5.505951404571533
8800 loss: 5.745686054229736
8900 loss: 5.7545599937438965
9000 loss: 5.610304355621338
9100 loss: 5.596979141235352
9200 loss: 5.378000259399414
9300 loss: 5.5428948402404785
9400 loss: 5.66567325592041
9500 loss: 5.3651909828186035
best model saved to 03lstm.pth
# Rebuild the architecture and load the best checkpoint written during training.
best_model = RNNModel(vocab_size=len(TEXT.vocab), embed_size=EMBEDDING_SIZE, hidden_size=100)
best_model = best_model.to(device)  # a no-op when device is the CPU
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
best_model.load_state_dict(torch.load("model/3/03lstm.pth", map_location=device))
使用训练好的模型生成句子
# Sample 100 words from the trained model, starting from a random vocabulary index.
best_model.eval()
hidden = best_model.init_hidden(1, requires_grad=False)
next_input = torch.randint(VOCAB_SIZE, (1, 1), dtype=torch.long, device=device)
words = []
# Inference only: no_grad avoids building an autograd graph across all 100 steps.
with torch.no_grad():
    for _ in range(100):
        output, hidden = best_model(next_input, hidden)
        # softmax (instead of exp on raw logits) gives the same sampling distribution
        # without overflowing to inf for large logits.
        word_weights = torch.softmax(output.squeeze(), dim=-1).cpu()
        word_idx = torch.multinomial(word_weights, 1).item()
        next_input.fill_(word_idx)  # feed the sampled word back in as the next input
        words.append(TEXT.vocab.itos[word_idx])
print(" ".join(words))
unfair trees have some perfect secret use only the boy period per years the density of the pyruvate of steam bass operators often odd this article superior s point can transform its superpower state of parent and the responsibility for each other to who is identical to the royal society against adventurers and is an thrust to form decree of fundamental musicians are cross climate to poland than to be blinded meters and hence for ftp breeding to be defined is presenting strong consequences of the confession include java problem solomon minuet razor algorithm or opponents dispute by mina
使用训练好的模型在测试数据上计算perplexity
# Perplexity = exp(mean per-token cross-entropy) on the held-out test split.
test_loss = evaluate(best_model, test_iter)
perplexity = np.exp(test_loss)
print("perplexity: ", perplexity)
perplexity: 261.8622638740215