Bilinear Attention Networks 代码记录

torch.einsum是个好东西,就是输入数据多于2个,就有点看不懂了。(改成了使用torch.matmul主要是为了将代码和论文公式对应上,也验证了计算的结果应该是一致的)
源码来源:https://github.com/jnhwkim/ban-vqa
以下代码位于此处,其中:
1)forward函数用来计算Bilinear Attention Map(输入分别是视觉编码v和问题编码q),也就是注意力权重;
2)forward_with_weights函数基于注意力权重w来进行视觉编码v和问题编码q的融合。
计算权重和融合分别是两个独立的层中的操作,两者之间是不共享v_netq_net的参数的。
Bilinear Attention Networks 代码记录_第1张图片
其中,相关数据维度如下:

# 1 forward函数:
v_ [B, M, D]
q_ [B, L, D]
# 2 forward_with_weights函数:
v_ [B, M, D]
q_ [B, L, D]
w  [B, M, L]

1 forward函数

# low-rank bilinear pooling using einsum
def forward(self, v, q):
	...
	elif self.h_out <= self.c:
		v_ = self.dropout(self.v_net(v))
		q_ = self.q_net(q)
		logits = torch.einsum('xhyk,bvk,bqk->bhvq', (self.h_mat, v_, q_)) + self.h_bias
		return logits # b x h_out x v x q
	...

对应论文中的公式(8)softmax函数中内容(这个\mathbbm{1}打不出来,-_-),其中h_mat即公式中的 p \mathrm{p} p(公式中未考虑偏置bias),两个Linear层v_netq_net对应了公式中的权重矩阵 U \mathrm{U} U V \mathrm{V} Vvq即公式中的 X \mathrm{X} X Y \mathrm{Y} Y
Bilinear Attention Networks 代码记录_第2张图片
----------------------------------------------分割线---------------------------------------------
1、【公式(8)softmax内表达式与公式(9)等价,公式(8)直接求出权重矩阵,公式(9)是权重矩阵中每一个单项数据原始计算形式。个人感觉转换成公式(8)更方便使用代码实现,因为X与Y的长度如果都多于1的话,其Hadamard积计算就不是那么方便(pytorch中反正好像没有长度不等的矩阵求Hadamard积的函数实现)】
2、【X和Y可对比为现在注意力机制中常说的Query和Key,如果Query/X长度始终是1,那么也可以使用Linear层实现公式(9),而不需要转换为公式(8)进行实现】
----------------------------------------------分割线完-------------------------------------------
如果不用torch.einsum()函数,按照公式,代码可如下:

		
# low-rank bilinear pooling using einsum
def forward(self, v, q):
	...
	elif self.h_out <= self.c:
		"""
		v: [B, M, D']
		q: [B, L, D']
		h_mat:  [1, h_out, 1, D], h_out默认等于8
		h_bias: [1, h_out, 1, 1]
		D = self.k * D', self.k = 3
		""" 	
		v_ = self.dropout(self.v_net(v)) # [B, M, D]
		q_ = self.q_net(q) 				 # [B, L, D]
		# logits = torch.einsum('xhyk,bvk,bqk->bhvq', (self.h_mat, v_, q_)) + self.h_bias 		 
		logits = torch.matmul((h_mat * v_.unsqueeze(1)), q_.unsqueeze(1).transpose(-1, -2)) + h_bias
		
		return logits # b x h_out x M x L
	...

2 forward_with_weights函数

def forward_with_weights(self, v, q, w):
	v_ = self.v_net(v) # b x v x d
	q_ = self.q_net(q) # b x q x d
	logits = torch.einsum('bvk,bvq,bqk->bk', (v_, w, q_))
	if 1 < self.k:
		logits = logits.unsqueeze(1) # b x 1 x d
		logits = self.p_net(logits).squeeze(1) * self.k # sum-pooling
	return logits

对应论文中的公式(5)和公式(6),下标k表示在数据维度上(即维度D)的遍历,完整 f ′ ∈ R K \mathrm{f'}\in \mathbb{R}^K fRK。其中vqw分别代表 X \mathrm{X} X Y \mathrm{Y} Y A \mathcal{A} A(这里面的K本质上应该就是数据维度D)

如果不用torch.einsum()函数,按照公式,代码可如下:

def forward_with_weights(self, v, q, w):
	"""
	v: [B, M, D']
	q: [B, L, D']
	w: [B, M, L]
	D = self.k * D', self.k = 3
	"""
	v_ = self.v_net(v) # [B, M, D]
	q_ = self.q_net(q) # [B, L, D]
	# logits = torch.einsum('bvk,bvq,bqk->bk', (v_, w, q_))
	# 运算过程与维度D无关,因此需要交换维度
	logits = torch.matmul(torch.matmul(v_.permute(0,2,1).unsqueeze(-2), w.unsqueeze(1)), q_.permute(0, 2, 1).unsqueeze(-1)).squeeze()
	if 1 < self.k:
		logits = logits.unsqueeze(1) # [B, 1, D]
		logits = self.p_net(logits).squeeze(1) * self.k # sum-pooling
	return logits  # [B, D / self.k]

你可能感兴趣的:(深度瞎搞,计算机幻觉,多模态,自然语言处理,深度学习,pytorch,神经网络)