torch.einsum是个好东西,就是输入数据多于2个,就有点看不懂了。(改成了使用torch.matmul主要是为了将代码和论文公式对应上,也验证了计算的结果应该是一致的)
源码来源:https://github.com/jnhwkim/ban-vqa
以下代码位于此处,其中:
1)forward
函数用来计算Bilinear Attention Map(输入分别是视觉编码v
和问题编码q
),也就是注意力权重;
2)forward_with_weights
函数基于注意力权重w
来进行视觉编码v
和问题编码q
的融合。
计算权重和融合分别是两个独立的层中的操作,两者之间是不共享v_net
和q_net
的参数的。
其中,相关数据维度如下:
# 1 forward函数:
v_ [B, M, D]
q_ [B, L, D]
# 2 forward_with_weights函数:
v_ [B, M, D]
q_ [B, L, D]
w [B, M, L]
# low-rank bilinear pooling using einsum
def forward(self, v, q):
...
elif self.h_out <= self.c:
v_ = self.dropout(self.v_net(v))
q_ = self.q_net(q)
logits = torch.einsum('xhyk,bvk,bqk->bhvq', (self.h_mat, v_, q_)) + self.h_bias
return logits # b x h_out x v x q
...
对应论文中的公式(8)softmax函数中内容(这个\mathbbm{1}打不出来,-_-),其中h_mat
即公式中的 p \mathrm{p} p(公式中未考虑偏置bias),两个Linear层v_net
和q_net
对应了公式中的权重矩阵 U \mathrm{U} U和 V \mathrm{V} V,v
和q
即公式中的 X \mathrm{X} X和 Y \mathrm{Y} Y。
----------------------------------------------分割线---------------------------------------------
1、【公式(8)softmax内表达式与公式(9)等价,公式(8)直接求出权重矩阵,公式(9)是权重矩阵中每一个单项数据原始计算形式。个人感觉转换成公式(8)更方便使用代码实现,因为X与Y的长度如果都多于1的话,其Hadamard积计算就不是那么方便(pytorch中反正好像没有长度不等的矩阵求Hadamard积的函数实现)】
2、【X和Y可对比为现在注意力机制中常说的Query和Key,如果Query/X长度始终是1,那么也可以使用Linear层实现公式(9),而不需要转换为公式(8)进行实现】
----------------------------------------------分割线完-------------------------------------------
如果不用torch.einsum()
函数,按照公式,代码可如下:
# low-rank bilinear pooling using einsum
def forward(self, v, q):
...
elif self.h_out <= self.c:
"""
v: [B, M, D']
q: [B, L, D']
h_mat: [1, h_out, 1, D], h_out默认等于8
h_bias: [1, h_out, 1, 1]
D = self.k * D', self.k = 3
"""
v_ = self.dropout(self.v_net(v)) # [B, M, D]
q_ = self.q_net(q) # [B, L, D]
# logits = torch.einsum('xhyk,bvk,bqk->bhvq', (self.h_mat, v_, q_)) + self.h_bias
logits = torch.matmul((h_mat * v_.unsqueeze(1)), q_.unsqueeze(1).transpose(-1, -2)) + h_bias
return logits # b x h_out x M x L
...
def forward_with_weights(self, v, q, w):
v_ = self.v_net(v) # b x v x d
q_ = self.q_net(q) # b x q x d
logits = torch.einsum('bvk,bvq,bqk->bk', (v_, w, q_))
if 1 < self.k:
logits = logits.unsqueeze(1) # b x 1 x d
logits = self.p_net(logits).squeeze(1) * self.k # sum-pooling
return logits
对应论文中的公式(5)和公式(6),下标k表示在数据维度上(即维度D)的遍历,完整 f ′ ∈ R K \mathrm{f'}\in \mathbb{R}^K f′∈RK。其中v
和q
和w
分别代表 X \mathrm{X} X和 Y \mathrm{Y} Y和 A \mathcal{A} A(这里面的K本质上应该就是数据维度D)
如果不用torch.einsum()
函数,按照公式,代码可如下:
def forward_with_weights(self, v, q, w):
"""
v: [B, M, D']
q: [B, L, D']
w: [B, M, L]
D = self.k * D', self.k = 3
"""
v_ = self.v_net(v) # [B, M, D]
q_ = self.q_net(q) # [B, L, D]
# logits = torch.einsum('bvk,bvq,bqk->bk', (v_, w, q_))
# 运算过程与维度D无关,因此需要交换维度
logits = torch.matmul(torch.matmul(v_.permute(0,2,1).unsqueeze(-2), w.unsqueeze(1)), q_.permute(0, 2, 1).unsqueeze(-1)).squeeze()
if 1 < self.k:
logits = logits.unsqueeze(1) # [B, 1, D]
logits = self.p_net(logits).squeeze(1) * self.k # sum-pooling
return logits # [B, D / self.k]