The Annotated Transformer 应该是我见过最贴心的‘Attention is All You Need’的复现了。看网页链接像是哈佛大学复现的,质量应该还不错,于是就照着代码按顺序ctrl + c +v了一遍。其实在github上也有代码可以直接下载,只不过是.ipynb格式的。
在调试代码的过程中,遇到了一些问题,在这里记录一下。
1 环境安装
作者没有说明每个依赖库的版本,以下是我个人的版本,可以参考。
python==3.8.8
torch==1.9.0
numpy==1.20.1
matplotlib==3.3.4
spacy==2.2.2
torchtext==0.6.0
numpy和matplotlib的版本影响应该不大;python的版本影响大不大不知道;torch的版本有点影响,这个版本会导致一个小错误,不过可以被解决;spacy和torchtext的版本影响很大!那是2018年发布的博客,torch、spacy和torchtext的版本应该比较低。博客中还提到要安装seaborn,我没有装,好像没影响。
2 遇到的问题
问题1:
在执行以下代码的时候,
python -m spacy download en
python -m spacy download de
出现网络链接的错误:
requests.exceptions.ConnectionError: HTTPSConnectionPool(host=‘raw.githubusercontent.com’, port=443): Max retries exceeded with url: /explosion/spacy-models/master/shortcuts-v2.json (Caused by NewConnectionError(’
参考这个方法,安装了en_core_web_sm
和de_core_news_sm
,先手动下载安装包(百度云盘,提取码:0cic),再用pip安装。
pip install en_core_web_sm-2.2.5.tar.gz
pip install de_core_news_sm-2.2.5.tar.gz
相应的,代码要改一下,把
import spacy
spacy_de = spacy.load('de')
spacy_en = spacy.load('en')
改成
import en_core_web_sm
import de_core_news_sm
if True:
spacy_de = de_core_news_sm.load()
spacy_en = en_core_web_sm.load()
问题2
torchtext版本太高导致from torchtext import data, datasets
的data.Field
和datasets.IWSLT
不存在,把版本降到0.6.0就存在了。
问题3
代码执行到train, val, test = datasets.IWSLT.splits(...)
时,程序会下载数据,也报网络链接的错。经过debug发现,程序先在本地找.data/iwslt/de-gn.tgz
这个文件,找不到才去下载。所以,可以先把.data/iwslt/de-gn.tgz
文件准备好就行了。那么,这是个什么文件呢?这个文件来自WIT3,得在google drive下载,下载下来的文件名是2016-01.tgz
,解压后在里面找到一个叫de-gn.tgz
的文件(得翻几层文件夹),放在.data/iwslt/
目录下就可以了,注意这是一个相对路径。
问题4
OMP: Error #15: Initializing libiomp5.dylib, but found libomp.dylib already initialized.
解决方法:
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
问题5
IndexError: invalid index of a 0-dim tensor. Use tensor.item()
in Python or tensor.item
in C++ to convert a 0-dim tensor to a number
pytorch版本导致的问题,解决方法:
把SimpleLossCompute
的最后一行的return loss.data[0] * norm
改成return loss.data.item() * norm
3 整理代码
把一些零散的代码粘贴到一个函数里:
def train_on_cpu():
"""Train the model on cpu.
"""
# For data loading.
from torchtext import data, datasets
import en_core_web_sm
import de_core_news_sm
spacy_de = de_core_news_sm.load()
spacy_en = en_core_web_sm.load()
def tokenize_de(text):
return [tok.text for tok in spacy_de.tokenizer(text)]
def tokenize_en(text):
return [tok.text for tok in spacy_en.tokenizer(text)]
BOS_WORD = ''
EOS_WORD = ''
BLANK_WORD = ""
SRC = data.Field(tokenize=tokenize_de, pad_token=BLANK_WORD)
TGT = data.Field(tokenize=tokenize_en, init_token=BOS_WORD,
eos_token=EOS_WORD, pad_token=BLANK_WORD)
MAX_LEN = 100
train, val, test = datasets.IWSLT.splits(
exts=('.de', '.en'), fields=(SRC, TGT),
filter_pred=lambda x: len(vars(x)['src']) <= MAX_LEN and
len(vars(x)['trg']) <= MAX_LEN)
MIN_FREQ = 2
SRC.build_vocab(train.src, min_freq=MIN_FREQ)
TGT.build_vocab(train.trg, min_freq=MIN_FREQ)
# Make a model and data iterators
pad_idx = TGT.vocab.stoi["" ]
model = make_model(len(SRC.vocab), len(TGT.vocab), N=6)
criterion = LabelSmoothing(size=len(TGT.vocab), padding_idx=pad_idx, smoothing=0.1)
BATCH_SIZE = 8
train_iter = MyIterator(train, batch_size=BATCH_SIZE, device=0,
repeat=False, sort_key=lambda x: (len(x.src), len(x.trg)),
batch_size_fn=batch_size_fn, train=True)
valid_iter = MyIterator(val, batch_size=BATCH_SIZE, device=0,
repeat=False, sort_key=lambda x: (len(x.src), len(x.trg)),
batch_size_fn=batch_size_fn, train=False)
# Train
model_opt = NoamOpt(model.src_embed[0].d_model, 1, 2000,
torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
for epoch in range(10):
model.train()
run_epoch((rebatch(pad_idx, b) for b in train_iter),
model,
SimpleLossCompute(model.generator, criterion, model_opt))
model.eval()
loss = run_epoch((rebatch(pad_idx, b) for b in valid_iter),
model,
SimpleLossCompute(model.generator, criterion, None))
print(loss)
for i, batch in enumerate(valid_iter):
src = batch.src.transpose(0, 1)[:1]
src_mask = (src != SRC.vocab.stoi["" ]).unsqueeze(-2)
out = greedy_decode(model, src, src_mask,
max_len=60, start_symbol=TGT.vocab.stoi[""])
print("Translation:", end="\t")
for i in range(1, out.size(1)):
sym = TGT.vocab.itos[out[0, i]]
if sym == "": break
print(sym, end=" ")
print()
print("Target:", end="\t")
for i in range(1, batch.trg.size(0)):
sym = TGT.vocab.itos[batch.trg.data[i, 0]]
if sym == "": break
print(sym, end=" ")
print()
break
在脚本末尾写个程序入口:
if __name__ == '__main__':
train_on_cpu() # I have no GPU
最后run一下,就可以看到以下舒心的画面了:
Epoch Step: 1 Loss: 9.118329 Tokens per Sec: 11.898214
Epoch Step: 51 Loss: 8.626973 Tokens per Sec: 23.280647
Epoch Step: 101 Loss: 7.953571 Tokens per Sec: 20.981972
完整的代码可以在github下载。