[PyTorch] Sequence-to-Sequence Decoder: A Code Walkthrough

Although I was already quite familiar with the encoder-decoder framework, I had never implemented one myself; it was, so to speak, "the most familiar stranger". Recently, for research purposes, I implemented a syntactic parsing system by following an open-source project on GitHub (pytorch-seq2seq). In this post, I walk through the decoder part of that implementation.

First, the import section:

import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

from attention import Attention
from baseRNN import BaseRNN

# alias the device namespace depending on CUDA availability
# (kept from the original project; `device` is not actually used below)
if torch.cuda.is_available():
    import torch.cuda as device
else:
    import torch as device

In the import section, we first pull in NumPy and the required torch modules. Besides the common packages, two local modules are imported: Attention and BaseRNN. BaseRNN is a thin wrapper around the RNN modules in torch.nn (nn.RNN / nn.GRU / nn.LSTM), while Attention implements the attention mechanism, a key component of seq2seq that lets each decoding step focus on the encoder information it cares about most. The attention code is quite short; a sketch follows.
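The interface the decoder expects is output, attn = self.attention(output, encoder_outputs). Below is a minimal sketch of a Luong-style dot-product attention with that interface; it is an illustration only, and the actual module in pytorch-seq2seq may differ in details such as masking:

import torch
import torch.nn as nn
import torch.nn.functional as F


class Attention(nn.Module):
    """Minimal Luong-style dot-product attention (sketch)."""

    def __init__(self, dim):
        super(Attention, self).__init__()
        # projects the concatenated [context; decoder output] back to hidden size
        self.linear_out = nn.Linear(dim * 2, dim)

    def forward(self, output, context):
        # output:  decoder RNN outputs, [batch, out_len, dim]
        # context: encoder outputs,     [batch, in_len, dim]
        # attention scores via batched dot product: [batch, out_len, in_len]
        attn = F.softmax(torch.bmm(output, context.transpose(1, 2)), dim=2)
        # context vectors: attention-weighted sums of encoder outputs
        mix = torch.bmm(attn, context)              # [batch, out_len, dim]
        combined = torch.cat((mix, output), dim=2)  # [batch, out_len, 2 * dim]
        return torch.tanh(self.linear_out(combined)), attn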


The __init__ section:

def __init__(self, vocab_size, max_len, input_size, hidden_size,
                 sos_id, eos_id,
                 n_layers=1, rnn_cell='gru', bidirectional=False,
                 input_dropout_p=0, dropout_p=0, use_attention=False):
        super(DecoderRNN, self).__init__(vocab_size, max_len, input_size, hidden_size,
                                         input_dropout_p, dropout_p,
                                         n_layers, rnn_cell)

        self.bidirectional_encoder = bidirectional
        self.rnn = self.rnn_cell(input_size, hidden_size, n_layers, batch_first=True, dropout=dropout_p)

        self.output_size = vocab_size
        self.max_length = max_len
        self.use_attention = use_attention
        self.eos_id = eos_id
        self.sos_id = sos_id

        self.init_input = None

        self.embedding = nn.Embedding(self.output_size, self.input_size)
        if use_attention:
            self.attention = Attention(self.hidden_size)

        self.fflayer = nn.Linear(self.hidden_size, self.output_size)

These are the parameters used throughout the decoding process; a construction example follows the list.

  • bidirectional: whether the encoder is bidirectional; used when turning the encoder hidden state into the decoder's initial hidden state
  • rnn: the decoder-side RNN
  • output_size: the size of the decoder's output "vocabulary"
  • max_length: the maximum decoding length
  • use_attention: whether to use attention on the decoder side to build the feature representation
  • eos_id: used to detect when decoding should terminate
  • sos_id: used as the first input on the decoder side
  • init_input: currently unused
  • embedding: the embedding table for the decoder's output vocabulary
  • fflayer: the linear layer that projects the decoder hidden state to vocabulary-sized output scores
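As a concrete illustration, a hypothetical construction might look like the following (all hyperparameter values are made up, and BaseRNN from pytorch-seq2seq is assumed to be importable):

decoder = DecoderRNN(vocab_size=100, max_len=50, input_size=32, hidden_size=64,
                     sos_id=1, eos_id=2,
                     n_layers=1, rnn_cell='gru', bidirectional=True,
                     input_dropout_p=0.1, dropout_p=0, use_attention=True)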

Next comes the operation executed at each decoding step, i.e. one forward step:

def forward_step(self, input_var, hidden, encoder_outputs, function):
        """
        Args:
            input_var: input token ids
            hidden: last hidden state
            encoder_outputs: encoder-layer output
            function: probs function, default is F.log_softmax
        Return:
            the log-softmax output, the updated hidden state, and the attention weights
        """
        batch_size = input_var.size(0)
        output_size = input_var.size(1)  # number of decoding steps in this call
        embedded = self.embedding(input_var)
        embedded = self.input_dropout(embedded)

        output, hidden = self.rnn(embedded, hidden)

        attn = None
        if self.use_attention:
            output, attn = self.attention(output, encoder_outputs)

        predicted_softmax = function(self.fflayer(output.contiguous().view(-1, self.hidden_size))).view(batch_size,
                                                                                                        output_size, -1)
        return predicted_softmax, hidden, attn

forward_step performs one forward pass of the decoder, i.e. one (batched) time step.

What it does: take the previous step's decoder hidden state and output symbols as inputs, run the decoder RNN to get the current step's decoder hidden state and decoder output, then feed the RNN output into fflayer (if attention is used, the current RNN output is first combined with a context vector computed against the encoder outputs) to obtain the current step's predicted log-softmax, hidden state, and attention weights.

The arguments in detail:

input_var: since we operate on a whole batch, this holds the inputs fed at the current step for the entire batch, i.e. the previous step's output symbols; shape: [batch, 1] (under teacher forcing, the whole shifted target sequence [batch, steps] is fed in one call).

hidden: the previous step's decoder hidden state; shape: [n_layers, batch, hidden_size].

encoder_outputs: the encoder-side outputs; shape: [batch, seq_len, hidden_size].

function: the function used to turn logits into output probabilities, usually F.log_softmax.
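To make these shapes concrete, here is a self-contained sanity check that mirrors forward_step using plain torch.nn modules (all sizes are hypothetical; attention is omitted, and the dim argument is written out for newer PyTorch versions):

import torch
import torch.nn as nn
import torch.nn.functional as F

batch, seq_len, input_size, hidden_size, vocab = 4, 7, 32, 64, 100

embedding = nn.Embedding(vocab, input_size)
rnn = nn.GRU(input_size, hidden_size, num_layers=1, batch_first=True)
fflayer = nn.Linear(hidden_size, vocab)

input_var = torch.zeros(batch, 1, dtype=torch.long)          # previous symbols, [batch, 1]
hidden = torch.zeros(1, batch, hidden_size)                  # [n_layers, batch, hidden_size]

embedded = embedding(input_var)                              # [batch, 1, input_size]
output, hidden = rnn(embedded, hidden)                       # output: [batch, 1, hidden_size]
logits = fflayer(output.contiguous().view(-1, hidden_size))  # [batch * 1, vocab]
predicted_softmax = F.log_softmax(logits, dim=1).view(batch, 1, -1)
print(predicted_softmax.size())                              # torch.Size([4, 1, 100])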

Then we come to the main forward pass:

def forward(self, inputs=None, encoder_hidden=None, encoder_outputs=None,
                function=F.log_softmax, teacher_forcing_ratio=0):
        ret_dict = dict()
        if self.use_attention:
            ret_dict[DecoderRNN.KEY_ATTN_SCORE] = list()

        inputs, batch_size, max_length = self._validate_args(inputs, encoder_hidden, encoder_outputs,
                                                             function, teacher_forcing_ratio)
        decoder_hidden = self._init_state(encoder_hidden)
        use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False

        decoder_outputs = []
        sequence_symbols = []
        lengths = np.array([max_length] * batch_size)

        def decode(step, step_output, step_attn):
            decoder_outputs.append(step_output)
            if self.use_attention:
                ret_dict[DecoderRNN.KEY_ATTN_SCORE].append(step_attn)
            symbols = decoder_outputs[-1].topk(1)[1]
            sequence_symbols.append(symbols)
            eos_batches = symbols.data.eq(self.eos_id)
            if eos_batches.dim() > 0:
                eos_batches = eos_batches.cpu().view(-1).numpy()
                update_idx = ((lengths > step) & eos_batches) != 0
                lengths[update_idx] = len(sequence_symbols)
            return symbols
        if use_teacher_forcing:
            decoder_input = inputs[:, :-1]
            decoder_output, decoder_hidden, attn = self.forward_step(decoder_input, decoder_hidden, encoder_outputs,
                                                                     function=function)
            for di in range(decoder_output.size(1)):
                step_output = decoder_output[:, di, :]
                if attn is not None:
                    step_attn = attn[:, di, :]
                else:
                    step_attn = None
                decode(di, step_output, step_attn)
        else:
            decoder_input = inputs[:, 0].unsqueeze(1)
            for di in range(max_length):
                decoder_output, decoder_hidden, step_attn = self.forward_step(decoder_input, decoder_hidden,
                                                                              encoder_outputs, function=function)
                step_output = decoder_output.squeeze(1)
                symbols = decode(di, step_output, step_attn)
                decoder_input = symbols

        ret_dict[DecoderRNN.KEY_SEQUENCE] = sequence_symbols
        ret_dict[DecoderRNN.KEY_LENGTH] = lengths.tolist()
        return decoder_outputs, decoder_hidden, ret_dict

The main forward pass simply runs forward_step repeatedly and post-processes its outputs. With teacher forcing, the whole target sequence (minus its last token) is fed to forward_step in a single call and the per-step outputs are sliced out; without it, decoding proceeds one step at a time, feeding each step's predicted symbol back in as the next input. The inner decode helper also records, per example, the step at which eos_id is first produced, which yields the decoded lengths returned in ret_dict.
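As a usage sketch, continuing the hypothetical decoder constructed earlier (the encoder tensors are random stand-ins; with bidirectional=True, the encoder hidden carries hidden_size/2 units per direction so that _cat_directions concatenates them back to hidden_size; newer PyTorch versions will emit deprecation warnings for the old-style Variable/volatile code):

batch, src_len, tgt_len = 4, 7, 10
encoder_outputs = torch.randn(batch, src_len, 64)  # [batch, seq_len, hidden_size]
encoder_hidden = torch.randn(2, batch, 32)         # [directions * n_layers, batch, hidden_size / 2]

# training-style call: feed the gold target and always teacher-force
targets = torch.full((batch, tgt_len), 3, dtype=torch.long)  # 3 is a made-up token id
targets[:, 0] = decoder.sos_id
outputs, hidden, ret = decoder(inputs=targets, encoder_hidden=encoder_hidden,
                               encoder_outputs=encoder_outputs, teacher_forcing_ratio=1.0)
# len(outputs) == tgt_len - 1; each entry holds [batch, vocab] log-probabilities

# inference-style call: no inputs; decoding starts from sos_id and stops at eos_id or max_len
outputs, hidden, ret = decoder(inputs=None, encoder_hidden=encoder_hidden,
                               encoder_outputs=encoder_outputs, teacher_forcing_ratio=0)
print(ret[DecoderRNN.KEY_LENGTH])                  # per-example decoded lengths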


Here is the complete code:

import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

from attention import Attention
from baseRNN import BaseRNN

# alias the device namespace depending on CUDA availability
# (kept from the original project; `device` is not actually used below)
if torch.cuda.is_available():
    import torch.cuda as device
else:
    import torch as device


class DecoderRNN(BaseRNN):
    KEY_ATTN_SCORE = 'attention_score'
    KEY_LENGTH = 'length'
    KEY_SEQUENCE = 'sequence'

    def __init__(self, vocab_size, max_len, input_size, hidden_size,
                 sos_id, eos_id,
                 n_layers=1, rnn_cell='gru', bidirectional=False,
                 input_dropout_p=0, dropout_p=0, use_attention=False):
        super(DecoderRNN, self).__init__(vocab_size, max_len, input_size, hidden_size,
                                         input_dropout_p, dropout_p,
                                         n_layers, rnn_cell)

        self.bidirectional_encoder = bidirectional
        self.rnn = self.rnn_cell(input_size, hidden_size, n_layers, batch_first=True, dropout=dropout_p)

        self.output_size = vocab_size
        self.max_length = max_len
        self.use_attention = use_attention
        self.eos_id = eos_id
        self.sos_id = sos_id

        self.init_input = None

        self.embedding = nn.Embedding(self.output_size, self.input_size)
        if use_attention:
            self.attention = Attention(self.hidden_size)

        self.fflayer = nn.Linear(self.hidden_size, self.output_size)

    def forward_step(self, input_var, hidden, encoder_outputs, function):
        """
        Args:
            input_var: input token ids
            hidden: last hidden state
            encoder_outputs: encoder-layer output
            function: probs function, default is F.log_softmax
        Return:
            the log-softmax output, the updated hidden state, and the attention weights
        """
        batch_size = input_var.size(0)
        output_size = input_var.size(1)  # number of decoding steps in this call
        embedded = self.embedding(input_var)
        embedded = self.input_dropout(embedded)

        output, hidden = self.rnn(embedded, hidden)

        attn = None
        if self.use_attention:
            output, attn = self.attention(output, encoder_outputs)

        predicted_softmax = function(self.fflayer(output.contiguous().view(-1, self.hidden_size))).view(batch_size,
                                                                                                        output_size, -1)
        return predicted_softmax, hidden, attn

    def forward(self, inputs=None, encoder_hidden=None, encoder_outputs=None,
                function=F.log_softmax, teacher_forcing_ratio=0):
        """
        Args:
            inputs:target_variable when training else None


        """

        ret_dict = dict()
        if self.use_attention:
            ret_dict[DecoderRNN.KEY_ATTN_SCORE] = list()

        inputs, batch_size, max_length = self._validate_args(inputs, encoder_hidden, encoder_outputs,
                                                             function, teacher_forcing_ratio)
        decoder_hidden = self._init_state(encoder_hidden)
        use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False

        decoder_outputs = []
        sequence_symbols = []
        lengths = np.array([max_length] * batch_size)

        def decode(step, step_output, step_attn):
            decoder_outputs.append(step_output)
            if self.use_attention:
                ret_dict[DecoderRNN.KEY_ATTN_SCORE].append(step_attn)
            symbols = decoder_outputs[-1].topk(1)[1]
            sequence_symbols.append(symbols)
            eos_batches = symbols.data.eq(self.eos_id)
            if eos_batches.dim() > 0:
                eos_batches = eos_batches.cpu().view(-1).numpy()
                update_idx = ((lengths > step) & eos_batches) != 0
                lengths[update_idx] = len(sequence_symbols)
            return symbols

        # Manual unrolling is used to support random teacher forcing.
        # If teacher_forcing_ratio is True or False instead of a probability, the unrolling can be done in graph
        if use_teacher_forcing:
            decoder_input = inputs[:, :-1]
            decoder_output, decoder_hidden, attn = self.forward_step(decoder_input, decoder_hidden, encoder_outputs,
                                                                     function=function)
            for di in range(decoder_output.size(1)):
                step_output = decoder_output[:, di, :]
                if attn is not None:
                    step_attn = attn[:, di, :]
                else:
                    step_attn = None
                decode(di, step_output, step_attn)
        else:
            decoder_input = inputs[:, 0].unsqueeze(1)
            for di in range(max_length):
                decoder_output, decoder_hidden, step_attn = self.forward_step(decoder_input, decoder_hidden,
                                                                              encoder_outputs, function=function)
                step_output = decoder_output.squeeze(1)
                symbols = decode(di, step_output, step_attn)
                decoder_input = symbols

        ret_dict[DecoderRNN.KEY_SEQUENCE] = sequence_symbols
        ret_dict[DecoderRNN.KEY_LENGTH] = lengths.tolist()
        return decoder_outputs, decoder_hidden, ret_dict

    def _init_state(self, encoder_hidden):
        """ Initialize the encoder hidden state. """
        if encoder_hidden is None:
            return None
        if isinstance(encoder_hidden, tuple):
            encoder_hidden = tuple([self._cat_directions(h) for h in encoder_hidden])
        else:
            encoder_hidden = self._cat_directions(encoder_hidden)
        return encoder_hidden

    def _cat_directions(self, h):
        """ If the encoder is bidirectional, do the following transformation.
            (#directions * #layers, #batch, hidden_size) -> (#layers, #batch, #directions * hidden_size)
        """
        if self.bidirectional_encoder:
            h = torch.cat([h[0:h.size(0):2], h[1:h.size(0):2]], 2)
        return h

    def _validate_args(self, inputs, encoder_hidden, encoder_outputs, function, teacher_forcing_ratio):
        """
        Args:
            inputs: decoder inputs (the target sequence during training)
            encoder_hidden: encoder hidden state
            encoder_outputs: encoder outputs; must be fed when attention is used
            function: the output probability function, e.g. F.log_softmax
            teacher_forcing_ratio: probability of feeding the decoder the gold target
        Return:
            inputs: if None, initialized with [self.sos_id] * batch_size
            batch_size: the batch size
            max_length: the maximum decoding length
        """
        if self.use_attention:
            if encoder_outputs is None:
                raise ValueError("Argument encoder_outputs cannot be None when attention is used.")

        # inference batch size
        if inputs is None and encoder_hidden is None:
            batch_size = 1
        else:
            if inputs is not None:
                batch_size = inputs.size(0)
            else:
                if self.rnn_cell is nn.LSTM:
                    batch_size = encoder_hidden[0].size(1)
                elif self.rnn_cell is nn.GRU:
                    batch_size = encoder_hidden.size(1)

        # set default input and max decoding length
        if inputs is None:
            if teacher_forcing_ratio > 0:
                raise ValueError("Teacher forcing has to be disabled (set 0) when no inputs is provided.")
            inputs = Variable(torch.LongTensor([self.sos_id] * batch_size),
                              volatile=True).view(batch_size, 1)
            if torch.cuda.is_available():
                inputs = inputs.cuda()
            max_length = self.max_length
        else:
            max_length = inputs.size(1) - 1  # minus the start of sequence symbol

        return inputs, batch_size, max_length
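The _cat_directions reshaping is easy to verify in isolation: for a bidirectional encoder, PyTorch interleaves the layer/direction states along dim 0 (forward direction at even indices, backward at odd indices), so the strided slices below pick out the two directions per layer and concatenate them on the feature dimension:

import torch

layers, directions, batch, hidden = 2, 2, 4, 8
h = torch.randn(directions * layers, batch, hidden)
h_cat = torch.cat([h[0:h.size(0):2], h[1:h.size(0):2]], 2)
print(h_cat.size())  # torch.Size([2, 4, 16]), i.e. (layers, batch, directions * hidden)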

