import math
import torch
from torch import nn
import torch.nn.functional as F

class AdditiveAttention(nn.Module):
def __init__(self, keys_size, queries_size, num_hiddens, dropout, **kwargs):
super(AdditiveAttention, self).__init__(**kwargs)
self.W_q = nn.Linear(queries_size, num_hiddens, bias=False)
self.W_k = nn.Linear(keys_size, num_hiddens, bias=False)
self.W_v = nn.Linear(num_hiddens, 1, bias=False)
self.dropout = nn.Dropout(dropout)
def forward(self, queries, keys, values):
queries, keys = self.W_q(queries), self.W_k(keys)
'''
queries --> [batch_size, queries_length, num_hiddens]
keys --> [batch_size, keys_length, num_hiddens]'''
features = queries.unsqueeze(2) + keys.unsqueeze(1)
'''
queries.unsqueeze(2) --> [batch_size, queries_length, 1, num_hiddens]
keys.unsqueeze(1) --> [batch_size, 1, keys_length, num_hiddens]
features --> [batch_size, queries_length, keys_length, num_hiddens] '''
features = torch.tanh(features)
scores = self.W_v(features).squeeze(-1)
'''
self.W_v(features) --> [batch_size, queries_length, keys_length, 1]
scores--> [batch_size, queries_length, keys_length]'''
        self.attention_weights = F.softmax(scores, dim=-1)
'''
self.attention_weights --> [batch_size, queries_length, keys_length]'''
return torch.bmm(self.dropout(self.attention_weights), values)
'''
output --> [batch_size, queries_length, value_features_num]
'''
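To make the scoring function concrete, the sketch below (with hypothetical toy sizes) recomputes every additive score w_v^T tanh(W_q q + W_k k) with an explicit double loop and checks that it reproduces the weights computed by the broadcasted version above:
attn = AdditiveAttention(keys_size=3, queries_size=5, num_hiddens=4, dropout=0.0)
attn.eval()
q = torch.randn(1, 2, 5)   # batch of 1, two queries
k = torch.randn(1, 3, 3)   # three keys
v = torch.randn(1, 3, 6)   # three values
_ = attn(q, k, v)          # populates attn.attention_weights
# Score each (query, key) pair one at a time, without broadcasting.
scores = torch.empty(1, 2, 3)
for i in range(2):
    for j in range(3):
        scores[0, i, j] = attn.W_v(torch.tanh(attn.W_q(q[0, i]) + attn.W_k(k[0, j]))).squeeze()
print(torch.allclose(F.softmax(scores, dim=-1), attn.attention_weights, atol=1e-6))  # True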
#############
### Example test ###
#############
queries, keys = torch.normal(0, 1, (2, 2, 20)), torch.ones((2, 10, 2))
# The two value matrices in the `values` minibatch are identical
values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(
2, 1, 1)
attention = AdditiveAttention(
keys_size=2, queries_size=20, num_hiddens=8, dropout=0.1)
attention.eval()
output = attention(queries, keys, values)
'''
output (the keys are all ones, so every query attends uniformly over the 10
keys and the result is the column-wise mean of `values`):
tensor([[[18.0000, 19.0000, 20.0000, 21.0000],
         [18.0000, 19.0000, 20.0000, 21.0000]],

        [[18.0000, 19.0000, 20.0000, 21.0000],
         [18.0000, 19.0000, 20.0000, 21.0000]]])
shape : [2, 2, 4]
'''
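# Broadcasting demo: the same pattern used above in
# queries.unsqueeze(2) + keys.unsqueeze(1) -- every query gets paired with every key.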
aa = torch.arange(12).reshape(1,1,4,3)
'''output:
tensor([[[[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11]]]])'''
bb = torch.arange(6).reshape(1,2,1,3)
'''output:
tensor([[[[0, 1, 2]],
[[3, 4, 5]]]])'''
aa + bb
'''output:
tensor([[[[ 0, 2, 4],
[ 3, 5, 7],
[ 6, 8, 10],
[ 9, 11, 13]],
[[ 3, 5, 7],
[ 6, 8, 10],
[ 9, 11, 13],
[12, 14, 16]]]])'''
class DotProductAttention(nn.Module):
def __init__(self, dropout, **kwargs):
super(DotProductAttention, self).__init__(**kwargs)
self.dropout = nn.Dropout(dropout)
def forward(self, queries, keys, values):
'''
queries --> [batch_size, queries_length, queries_feature_num]
keys --> [batch_size, keys_values_length, keys_features_num]
        values --> [batch_size, keys_values_length, values_features_num]
        In dot-product attention: queries_feature_num = keys_features_num
'''
d = queries.shape[-1]
        '''Swap the last two dimensions of keys, which corresponds to the transpose in the scoring formula'''
scores = torch.bmm(queries, keys.transpose(1,2)) / math.sqrt(d)
        self.attention_weights = F.softmax(scores, dim=-1)
return torch.bmm(self.dropout(self.attention_weights), values)
queries = torch.normal(0, 1, (2, 1, 2))
attention = DotProductAttention(dropout=0.5)
attention.eval()
dot_output = attention(queries, keys, values)
print(dot_output)
'''
dot_output (again, identical keys give uniform attention weights):
tensor([[[18.0000, 19.0000, 20.0000, 21.0000]],

        [[18.0000, 19.0000, 20.0000, 21.0000]]])
'''
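As a consistency check, here is a sketch (assuming PyTorch 2.0+, which provides F.scaled_dot_product_attention) comparing the module above with the built-in scaled dot-product attention when dropout is disabled:
sdp = DotProductAttention(dropout=0.0)
sdp.eval()
q = torch.normal(0, 1, (2, 1, 2))
k = torch.normal(0, 1, (2, 10, 2))
v = torch.normal(0, 1, (2, 10, 4))
ours = sdp(q, k, v)
builtin = F.scaled_dot_product_attention(q, k, v)  # softmax(Q K^T / sqrt(d)) V
print(torch.allclose(ours, builtin, atol=1e-5))    # True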
The attention mechanism is often combined with sequence-to-sequence models. In that setting, the query is the decoder's hidden state from the previous time step, while the keys and the values are both the encoder outputs over all time steps.
The basic workflow of sequence-to-sequence with attention is: at every decoding step, use the current decoder hidden state as the query to compute a context vector over the encoder outputs, concatenate that context with the decoder input, and feed the result to the decoder RNN.
A sequence-to-sequence model consists of an encoder and a decoder; the attention mechanism is added to the decoder. We first define the encoder:
class Encoder(nn.Module):
def __init__(self, inputs_dim, num_hiddens, hiddens_layers):
super(Encoder, self).__init__()
self.rnn1 = nn.GRU(
input_size=inputs_dim, hidden_size=num_hiddens,
num_layers=hiddens_layers)
def forward(self, inputs):
        '''
        inputs: [batch_size, time_step_num, num_features]
        Since nn.GRU is constructed without batch_first=True, it expects
        [time_step_num, batch_size, num_features], hence the permute below.
        Output shapes:
            encOut: [time_step_num, batch_size, num_hiddens]
            hidSta: [num_layers, batch_size, num_hiddens]
        '''
inputs = inputs.permute(1, 0, 2)
encOut, hidSta = self.rnn1(inputs)
return encOut, hidSta
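A quick shape check for the encoder (a sketch with hypothetical sizes: batch size 4, 8 time steps, 10 input features):
enc = Encoder(inputs_dim=10, num_hiddens=20, hiddens_layers=2)
enc_out, enc_hid = enc(torch.normal(0, 1, (4, 8, 10)))
print(enc_out.shape)   # torch.Size([8, 4, 20])  [time_step_num, batch_size, num_hiddens]
print(enc_hid.shape)   # torch.Size([2, 4, 20])  [num_layers, batch_size, num_hiddens]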
class AttentionDecoder(nn.Module):
def __init__(
self, inputs_dim, num_hiddens, num_layers, outputs_dim, dropout):
super(AttentionDecoder, self).__init__()
self.attention = AdditiveAttention(
num_hiddens, num_hiddens, num_hiddens, dropout)
self.rnn = nn.GRU(
inputs_dim + num_hiddens, num_hiddens, num_layers,
dropout=dropout)
self.dense = nn.Linear(num_hiddens, outputs_dim)
def forward(self, inputs, states):
'''
inputs: [batch_size, time_step_num, features]
        states: (enc_output, enc_hidden_state) as returned by the Encoder
'''
enc_outputs, hidden_state = states
        '''Rearrange enc_outputs to [batch_size, time_step_num, enc_hidden_num]'''
enc_outputs = enc_outputs.permute(1, 0, 2)
        '''Rearrange inputs to [time_step_num, batch_size, features_num]'''
        inputs = inputs.permute(1, 0, 2)
outputs, self._attention_weights = [], []
        '''Process the inputs one time step at a time, fusing each step with the context information'''
for x in inputs:
            '''Take the last layer of the hidden state as the query and add a dimension at axis 1:
            hidden_state[-1] : [batch_size, enc_hidden_num] --> [batch_size, 1, enc_hidden_num]'''
query = hidden_state[-1].unsqueeze(dim=1)
'''context: [batch_size, query_length=1, hiddens_num]'''
context = self.attention(query, enc_outputs, enc_outputs)
x = torch.cat((context, x.unsqueeze(dim=1)), dim=-1)
            '''Update hidden_state with the decoder GRU'''
out, hidden_state = self.rnn(x.permute(1, 0, 2), hidden_state)
outputs.append(out)
self._attention_weights.append(self.attention.attention_weights)
        outputs = self.dense(torch.cat(outputs, dim=0))
        '''Permute enc_outputs back so the returned state has the same layout the decoder accepts'''
        return outputs.permute(1, 0, 2), [enc_outputs.permute(1, 0, 2), hidden_state]
##########
### Example ###
##########
encoder = Encoder(inputs_dim=10, num_hiddens=20, hiddens_layers=2)
decoder = AttentionDecoder(
inputs_dim=10, num_hiddens=20, num_layers=2, outputs_dim=8, dropout=0.1)
inputs = torch.normal(0, 1, (4, 8, 10))
state = encoder(inputs)
dec_inputs = torch.normal(0, 1, (4, 1, 10))
dec_output, state = decoder(dec_inputs, state)
print(dec_output.shape)
'''
output:
torch.Size([4, 1, 8])
'''
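After a forward pass, the per-step attention weights are stored in `_attention_weights`, one tensor per decoding step; a short check using the objects defined above:
print(len(decoder._attention_weights))            # 1 decoding step
print(decoder._attention_weights[0].shape)        # torch.Size([4, 1, 8])
print(decoder._attention_weights[0].sum(dim=-1))  # all ones: weights over the 8 encoder steps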
class EncoderDecoder(nn.Module):
"""The base class for the encoder-decoder architecture."""
def __init__(self, encoder, decoder, **kwargs):
super(EncoderDecoder, self).__init__(**kwargs)
self.encoder = encoder
self.decoder = decoder
def forward(self, enc_X, dec_X, *args):
state = self.encoder(enc_X, *args)
dec_state = self.decoder(dec_X, state)
return dec_state
net = EncoderDecoder(encoder, decoder)
output = net(inputs, dec_inputs)
print(output[0].shape)  # --> torch.Size([4, 1, 8])
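Because the decoder returns its state in the same layout it accepts, it can be unrolled step by step. A sketch with random inputs (no training involved), reusing the encoder and decoder from above:
dec_out, dec_state = decoder(torch.normal(0, 1, (4, 1, 10)), encoder(inputs))
for _ in range(2):
    dec_out, dec_state = decoder(torch.normal(0, 1, (4, 1, 10)), dec_state)
    print(dec_out.shape)  # torch.Size([4, 1, 8]) at every step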