Pytorch学习记录-attention的可视化

Pytorch学习记录-torchtext和Pytorch的实例4

0. PyTorch Seq2Seq项目介绍

在完成基本的torchtext之后,找到了这个教程,《基于Pytorch和torchtext来理解和实现seq2seq模型》。
这个项目主要包括了6个子项目

  1. 使用神经网络训练Seq2Seq
  2. 使用RNN encoder-decoder训练短语表示用于统计机器翻译
  3. 使用共同学习完成NMT的堆砌和翻译
  4. 打包填充序列、掩码和推理
  5. 卷积Seq2Seq
  6. Transformer

4. 打包填充序列、掩码和推理

教程基于之前的模型增加了打包填充序列、掩码。

  • 打包填充序列被用于让RNN在Encoder部分略过填充的token。
  • 掩码能够明确强制模型忽略某些值,例如对填充元素的注意。这两种技术都常用于NLP。

这个教程同样也会关注模型的推理,给定句子,查看翻译结果。找出究竟注意力机制关注哪些词。

4.1 引入库和数据预处理

4.2 构建模型

4.3 训练模型

4.4 推断

现在可以使用训练的模型生成翻译了。这里要注意的是attention的可视化实现。
这个模型是一个比较简陋的,因为只进行了10轮,同时隐藏维度很小,在原论文中使用的是1000的隐藏维度并且训练了4天……
在这里,translate_sentence方法要做以下事情:

  • 确定模型是在eval模式,这样可以进行推断
  • 对输入句子进行分词
  • 对分词结果全部小写,同时加入句子的开始和结束标签
  • 使用我们的词汇表通过将它们转换为索引来对我们的标记进行数字化
  • 获取句子的长度并转为tensor
  • 将数字化的句子放入tensor中,加入批维度
  • 放入模型中,确保trg是空集,teacher force参数为0
  • 使用argmax获取预测结果得分最高的值
  • 将结果的index转为string
  • 作为我们输出中的第一个元素,我们模型中的注意力张量都是零,我们在返回之前修剪它们
def translate_sentence(model,sentence):
    model.eval()
    tokenized=tokenize_de(sentence)
    tokenized=['']+[t.lower() for t in tokenized]+['']
    numericalized=[SRC.vocab.stoi[t] for t in tokenized]
    sentence_length=torch.LongTensor([len(numericalized)]).to(device)
    tensor = torch.LongTensor(numericalized).unsqueeze(1).to(device) 
    translation_tensor_logits, attention = model(tensor, sentence_length, None, 0) 
    translation_tensor = torch.argmax(translation_tensor_logits.squeeze(1), 1)
    translation = [TRG.vocab.itos[t] for t in translation_tensor]
    translation, attention = translation[1:], attention[1:]
    return translation, attention
def display_attention(candidate,translation,attention):
    fig=plt.figure(figsize=(10,10))
    ax=fig.add_subplot(111)
    attention=attention.squeeze(1).cpu().detach().numpy()
    cax=ax.matshow(attention,cmap='bone')
    ax.tick_params(labelsize=15)
    ax.set_xticklabels(['']+['']+[t.lower() for t in tokenize_de(candidate)]+[''])
    ax.set_yticklabels(['']+translation)

    ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))

    plt.show()
    plt.close()

现在,我们将从我们的数据集中获取一些翻译,看看我们的模型有多好。注意,我们将在这里挑选示例,以便为我们提供一些有趣的内容,但您可以随意更改example_idx值以查看不同的示例。
首先,我们将从数据集中获取源和目标。

example_idx=40
src=' '.join(vars(train_data.examples[example_idx])['src'])
trg=' '.join(vars(train_data.examples[example_idx])['trg'])
print(f'src={src}')
print(f'trg={trg}')
src=zwei kleinkinder im freien auf dem gras .
trg=two young toddlers outside on the grass .

然后我们将使用translate_sentence函数来获得预测的翻译和注意力。我们通过在x轴上具有源句子和在y轴上具有预测的平移来图形化地示出。两个单词交叉处的正方形越浅,模型在翻译目标单词时对该源词的关注就越大。
下面是100%正确翻译模型的示例。请注意,当将zwei正确地翻译成两个时,它似乎根本没有注意到zwei。然而,当将männerstehen翻译成男人时,它已经成功地引起了注意。

translation, attention = translate_sentence(model, src)

print(f'predicted trg = {translation}')

display_attention(src, translation, attention)
predicted trg = ['a', '', 'worker', 'is', '', 'a', '', '.']
Pytorch学习记录-attention的可视化_第1张图片
image.png

我们又换了一个


Pytorch学习记录-attention的可视化_第2张图片
image.png

你可能感兴趣的:(Pytorch学习记录-attention的可视化)