import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.functional import dropout

my_dropout_p = 0.4  # assumed global dropout rate; the original defines this elsewhere

class GetSpanStartEnd(nn.Module):
    # Bilinear attention over the document, with an optional GRU pointer-network
    # update of the question summary between start and end prediction.
    def __init__(self, x_size, h_size, opt, do_indep_attn=True,
                 attn_type="Bilinear", do_ptr_update=True):
        super(GetSpanStartEnd, self).__init__()
        self.attn = BilinearSeqAttn(x_size, h_size, opt)
        self.attn2 = BilinearSeqAttn(x_size, h_size, opt) if do_indep_attn else None
        self.rnn = nn.GRUCell(x_size, h_size) if do_ptr_update else None
    # Example shapes below follow the inline annotations in the original:
    # x = doc_hiddens [10, 384, 250], h0 = question_avg_hidden [10, 125],
    # x_mask = [10, 384].
    def forward(self, x, h0, x_mask):
        """
        x      = [batch, len, x_size]
        h0     = [batch, h_size]
        x_mask = [batch, len], True/1 at padding positions
        """
        start_scores = self.attn(x, h0, x_mask)  # [batch, len], e.g. [10, 384]
        if self.rnn is not None:
            # Attend over the document with the start distribution to get a
            # summary vector [batch, x_size], e.g. [10, 250].
            ptr_net_in = torch.bmm(F.softmax(start_scores, dim=1).unsqueeze(1), x).squeeze(1)
            ptr_net_in = dropout(ptr_net_in, p=my_dropout_p, training=self.training)
            h0 = dropout(h0, p=my_dropout_p, training=self.training)
            h1 = self.rnn(ptr_net_in, h0)  # same size as h0, e.g. [10, 125]
        else:
            h1 = h0
        end_scores = self.attn(x, h1, x_mask) if self.attn2 is None \
            else self.attn2(x, h1, x_mask)
        # Both score tensors are [batch, len], e.g. [10, 384].
        return start_scores, end_scores
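
# A minimal decoding sketch (assumption: not part of the original module).
# Given the start/end scores returned above, a common way to extract a span
# is to take the (start, end) pair maximizing the joint probability subject
# to start <= end < start + max_len. `decode_span` and `max_len` are
# hypothetical names introduced here for illustration.
def decode_span(start_scores, end_scores, max_len=15):
    p_start = F.softmax(start_scores, dim=1)           # [batch, len]
    p_end = F.softmax(end_scores, dim=1)               # [batch, len]
    joint = p_start.unsqueeze(2) * p_end.unsqueeze(1)  # [batch, len, len]
    # Zero out pairs violating start <= end (triu) or spans longer than
    # max_len (tril on the offset diagonal).
    joint = joint.triu().tril(max_len - 1)
    seq_len = joint.size(1)
    best = joint.view(joint.size(0), -1).argmax(dim=1)
    starts = torch.div(best, seq_len, rounding_mode="floor")
    return starts, best % seq_len  # start and end indices per batch item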
class BilinearSeqAttn(nn.Module):
    """A bilinear attention layer over a sequence X w.r.t. y:

    * o_i = x_i' W y  for x_i in X.
    """
    def __init__(self, x_size, y_size, opt, identity=False):
        super(BilinearSeqAttn, self).__init__()
        # With identity=True, y is scored against x directly (requires
        # y_size == x_size); otherwise a learned map W projects y to x_size.
        self.linear = nn.Linear(y_size, x_size) if not identity else None
    def forward(self, x, y, x_mask):
        """
        x      = [batch, len, x_size]
        y      = [batch, y_size]
        x_mask = [batch, len], True/1 at padding positions
        """
        x = dropout(x, p=my_dropout_p, training=self.training)
        y = dropout(y, p=my_dropout_p, training=self.training)
        Wy = self.linear(y) if self.linear is not None else y  # [batch, x_size]
        xWy = x.bmm(Wy.unsqueeze(2)).squeeze(2)                # [batch, len]
        # Mask padding with -inf so it gets zero weight after any softmax;
        # masked_fill (not .data.masked_fill_) keeps autograd history intact.
        xWy = xWy.masked_fill(x_mask, -float('inf'))
        return xWy  # [batch, len]
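
# Usage sketch: a smoke test with the example shapes from the comments above
# (batch=10, len=384, x_size=250, h_size=125). `opt` is accepted but unused
# by these layers, so an empty dict stands in for it here.
if __name__ == "__main__":
    batch, seq_len, x_size, h_size = 10, 384, 250, 125
    model = GetSpanStartEnd(x_size, h_size, opt={})
    x = torch.randn(batch, seq_len, x_size)                  # doc_hiddens
    h0 = torch.randn(batch, h_size)                          # question_avg_hidden
    x_mask = torch.zeros(batch, seq_len, dtype=torch.bool)   # no padding
    start_scores, end_scores = model(x, h0, x_mask)
    print(start_scores.shape, end_scores.shape)  # torch.Size([10, 384]) twice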