A PyTorch implementation of a pointer network

The two modules below predict an answer span over a document: a bilinear attention scores each document position against a question summary to produce the start scores, a GRU cell then updates the question state with the attention-weighted document summary, and a second attention pass produces the end scores. The imports and the my_dropout_p constant are not part of the original snippet and are added (with an assumed dropout rate) so the code runs standalone.

import torch
import torch.nn as nn
import torch.nn.functional as F

my_dropout_p = 0.3  # assumed value; the original code references this constant without defining it

class GetSpanStartEnd(nn.Module):
    """Predicts span start/end scores with bilinear attention and an optional
    GRU-based pointer-network update of the query state.

    Note: attn_type is accepted for interface compatibility, but only bilinear
    attention is implemented here.
    """
    def __init__(self, x_size, h_size, opt, do_indep_attn=True, attn_type="Bilinear", do_ptr_update=True):
        super(GetSpanStartEnd, self).__init__()

        # One attention head for the start scores; an independent second head
        # for the end scores unless do_indep_attn is False (then it is reused).
        self.attn  = BilinearSeqAttn(x_size, h_size, opt)
        self.attn2 = BilinearSeqAttn(x_size, h_size, opt) if do_indep_attn else None

        # GRU cell that updates the query summary with the attended document
        # representation before the end-score attention pass.
        self.rnn = nn.GRUCell(x_size, h_size) if do_ptr_update else None

    def forward(self, x, h0, x_mask):
        """
        x      = [batch, len, x_size]  -- doc_hiddens,         e.g. [10, 384, 250]
        h0     = [batch, h_size]       -- question_avg_hidden, e.g. [10, 125]
        x_mask = [batch, len]          -- True at padding,     e.g. [10, 384]
        """
        # Start scores: one bilinear attention pass. [batch, len], e.g. [10, 384]
        start_scores = self.attn(x, h0, x_mask)

        if self.rnn is not None:
            # Pointer-network update: softmax the start scores (masked positions
            # carry -inf, so they get zero weight), take the attention-weighted
            # document summary, and feed it to the GRU cell to update the query
            # state. ptr_net_in: [batch, x_size], e.g. [10, 250]
            ptr_net_in = torch.bmm(F.softmax(start_scores, dim=1).unsqueeze(1), x).squeeze(1)
            ptr_net_in = F.dropout(ptr_net_in, p=my_dropout_p, training=self.training)
            h0 = F.dropout(h0, p=my_dropout_p, training=self.training)
            h1 = self.rnn(ptr_net_in, h0)  # [batch, h_size], same size as h0
        else:
            h1 = h0

        # End scores: reuse the start attention head if no independent one was built.
        end_scores = self.attn(x, h1, x_mask) if self.attn2 is None else \
                     self.attn2(x, h1, x_mask)
        # start_scores, end_scores: [batch, len], e.g. [10, 384]
        return start_scores, end_scores
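
The piece that makes the start distribution safe to use as attention weights is the -inf masking inside BilinearSeqAttn (defined below): after softmax, padded positions receive exactly zero probability, so the weighted sum only mixes real tokens. A minimal standalone sketch of that interaction (the numbers here are illustrative, not from the original post):

import torch
import torch.nn.functional as F

scores = torch.tensor([[1.0, 2.0, 0.5, -0.3]])
x_mask = torch.tensor([[False, False, True, True]])   # True marks padding
masked = scores.masked_fill(x_mask, -float('inf'))
print(F.softmax(masked, dim=1))  # tensor([[0.2689, 0.7311, 0.0000, 0.0000]])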

class BilinearSeqAttn(nn.Module):
    """A bilinear attention layer over a sequence X w.r.t. y:

    * o_i = x_i' W y  for each x_i in X.
    """
    def __init__(self, x_size, y_size, opt, identity=False):
        super(BilinearSeqAttn, self).__init__()
        # opt is accepted for interface compatibility but is not used here.
        if not identity:
            self.linear = nn.Linear(y_size, x_size)
        else:
            self.linear = None
    def forward(self, x, y, x_mask):
        """
        x      = [batch, len, x_size]
        y      = [batch, y_size]
        x_mask = [batch, len]  -- True at padding positions
        """
        x = F.dropout(x, p=my_dropout_p, training=self.training)
        y = F.dropout(y, p=my_dropout_p, training=self.training)

        # Project y into the document space, then score every position:
        # xWy_i = x_i' W y. Padded positions are set to -inf so that a
        # downstream softmax assigns them zero probability.
        Wy = self.linear(y) if self.linear is not None else y
        xWy = x.bmm(Wy.unsqueeze(2)).squeeze(2)
        xWy = xWy.masked_fill(x_mask, -float('inf'))
        return xWy  # [batch, len]
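
A quick smoke test of the full module, using the example sizes from the shape comments above (my assumptions: opt is an unused config object so None is passed, and everything past position 300 is padding):

batch, seq_len, x_size, h_size = 10, 384, 250, 125
x = torch.randn(batch, seq_len, x_size)                  # doc_hiddens
h0 = torch.randn(batch, h_size)                          # question_avg_hidden
x_mask = torch.zeros(batch, seq_len, dtype=torch.bool)   # False = real token
x_mask[:, 300:] = True                                   # mark the tail as padding

ptr = GetSpanStartEnd(x_size, h_size, opt=None)
start_scores, end_scores = ptr(x, h0, x_mask)
print(start_scores.shape, end_scores.shape)  # torch.Size([10, 384]) twice

At inference time the span is read off the two score vectors, typically by an argmax (or a joint argmax constrained to start <= end).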
