PyTorch notes: Transformer source code

Based on a Bilibili video and the PyTorch API reference (TORCH.NN)

  • A seq2seq model can be built from CNNs, RNNs, or a Transformer.
  • Key parts of the nn.Transformer source code:
encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout,
                                        activation, layer_norm_eps, batch_first,
                                        norm_first, **factory_kwargs)
encoder_norm = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)



decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout,
                                        activation, layer_norm_eps, batch_first,
                                        norm_first, **factory_kwargs)
decoder_norm = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm)



memory = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask)
output = self.decoder(tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask,
                      tgt_key_padding_mask=tgt_key_padding_mask,
                      memory_key_padding_mask=memory_key_padding_mask)
return output
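Putting the pieces above together, a minimal sketch of using the public nn.Transformer module end to end (toy dimensions chosen here for illustration; batch_first defaults to False, so tensors are (seq, batch, d_model)):

```python
import torch
import torch.nn as nn

d_model, nhead = 64, 4  # toy sizes; real models are larger

# nn.Transformer internally builds the encoder/decoder stacks shown above
model = nn.Transformer(d_model=d_model, nhead=nhead,
                       num_encoder_layers=2, num_decoder_layers=2,
                       dim_feedforward=128)

src = torch.rand(10, 2, d_model)  # source: (S, N, E) = (10, 2, 64)
tgt = torch.rand(7, 2, d_model)   # target: (T, N, E) = (7, 2, 64)

# forward() runs self.encoder(src, ...) then self.decoder(tgt, memory, ...)
out = model(src, tgt)
print(out.shape)  # torch.Size([7, 2, 64]) — same shape as tgt
```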
  • Note the two kinds of masks: src_key_padding_mask is the mask that hides padding tokens in the source (src_mask is a general attention mask over source positions), tgt_mask is a causal mask that hides the future tokens of the target sentence, and memory_mask is the mask for the decoder's second MHA (the cross-attention over the encoder output).
  • The Annotated Transformer walks through all of this in detail.
