发表时间:2021
引用:Diao, Haiwen, et al. “Similarity reasoning and filtration for image-text matching.” arXiv preprint arXiv:2101.01368 (2021).
论文地址:https://openaccess.thecvf.com
代码地址:https://github.com/Paranioar/SGRAF
作者还是为了实现更为细粒度的对齐。虽然此前使用全局对齐或者局部对齐的方式已经取得了一些成效,但是作者认为,当下的模型方法还是存在三点问题:
那么,为了解决上述三点问题,作者提出了一种相似图推理和注意过滤网络。具体来说:
注意的是,这里不是说先用SGR再用SAF,也就是这两个模块没有先后顺序,作者进行实验的时候也是分模块进行的实验
整体的模型结构如上所示,可以看到,经过特征提取后,就进行全局对齐和局部对齐,然后将对齐的结果分别进入SGR和SAF模块进行相似度的计算,这两个模块是独立的。那么下面也将分成四个部分进行说明重点说明。
使用Faster RCNN提取图特征,添加一个全连接层转成d维向量,得到每个区域的表示 V = { v 1 , . . . , v k } V=\{v_{1},...,v_{k}\} V={v1,...,vk},这里和SCAN是一致的。
然后,在每个区域上执行自注意力机制,该机制采用平均特征 q v ˉ = 1 K ∑ i = 1 K v i \bar{q_{v}}=\frac{1}{K}\sum_{i=1}^{K}v_{i} qvˉ=K1∑i=1Kvi作为查询并汇总所有区域以获得全局表示 v ˉ \bar{v} vˉ。代码对应部分如下:
class VisualSA(nn.Module):
def __init__(self, embed_dim, dropout_rate, num_region):
super(VisualSA, self).__init__()
self.embedding_local = nn.Sequential(nn.Linear(embed_dim, embed_dim),
nn.BatchNorm1d(num_region),
nn.Tanh(), nn.Dropout(dropout_rate))
self.embedding_global = nn.Sequential(nn.Linear(embed_dim, embed_dim),
nn.BatchNorm1d(embed_dim),
nn.Tanh(), nn.Dropout(dropout_rate))
self.embedding_common = nn.Sequential(nn.Linear(embed_dim, 1))
self.init_weights()
self.softmax = nn.Softmax(dim=1)
# local (batch, 36, d=1024)
# global (batch, d=1024)
def forward(self, local, raw_global):
# compute embedding of local regions and raw global image
l_emb = self.embedding_local(local)
g_emb = self.embedding_global(raw_global)
# compute the normalized weights, shape: (batch_size, 36)
g_emb = g_emb.unsqueeze(1).repeat(1, l_emb.size(1), 1)
common = l_emb.mul(g_emb)
weights = self.embedding_common(common).squeeze(2)
weights = self.softmax(weights)
# compute final image, shape: (batch_size, 1024)
new_global = (weights.unsqueeze(2) * local).sum(dim=1)
new_global = l2norm(new_global, dim=-1)
# new_global (shape, d=1024)
return new_global
使用GRU提取文本特征,得到表示 T = { t 1 , . . . , t L } T=\{t_{1},...,t_{L}\} T={t1,...,tL},按照同样的方式,得到文本的全局表示。这一部分在代码中体现在class TextSA(nn.Module)
中
向量相似函数
之前说过,作者没有使用余弦距离或欧几里得距离来计算相似性标量,而是用的一个相似向量来表示不同模态之间的相似关联程度。那么在这里,作者就是用的如下函数:
s ( x , y ; W ) = W ∣ x − y ∣ 2 ∣ ∣ W ∣ x − y ∣ 2 ∣ ∣ 2 (1) s(x,y;W)=\frac{W|x-y|^{2}}{||W|x-y|^{2}||_{2}} \tag{1} s(x,y;W)=∣∣W∣x−y∣2∣∣2W∣x−y∣2(1)
其中 W ∈ R m × d W\in \mathbb{R}^{m\times d} W∈Rm×d为一个可学习的参数矩阵,从而可以获取一个m维度的相似向量。 ∣ ⋅ ∣ 2 |\cdot |^{2} ∣⋅∣2和 ∣ ∣ ⋅ ∣ ∣ 2 ||\cdot ||_{2} ∣∣⋅∣∣2分别表示逐元素平方和l2标准化。
这一部分在代码的class EncoderSimilarity(nn.Module)
中:
def l2norm(X, dim=-1, eps=1e-8):
"""L2-normalize columns of X"""
norm = torch.pow(X, 2).sum(dim=dim, keepdim=True).sqrt() + eps
X = torch.div(X, norm)
return X
self.sim_tranloc_w = nn.Linear(embed_size, sim_dim)
self.sim_tranglo_w = nn.Linear(embed_size, sim_dim)
sim_loc = torch.pow(torch.sub(Context_img, cap_i_expand), 2)
sim_loc = l2norm(self.sim_tranloc_w(sim_loc), dim=-1)
sim_glo = torch.pow(torch.sub(img_glo, cap_glo_i), 2)
sim_glo = l2norm(self.sim_tranglo_w(sim_glo), dim=-1)
全局相似表示
利用上面的公式得到 s g = s ( v ˉ , t ˉ ; W g ) s^{g}=s(\bar{v},\bar{t};W_{g}) sg=s(vˉ,tˉ;Wg),那么这里的 W g W_{g} Wg就用于学习全局相似表示
局部相似表示
和SCAN一致,注意力权重计算为: α i j = e x p ( λ c ^ i j ) ∑ i = 1 K e x p ( λ c ^ i j ) \alpha _{ij}=\frac{exp(\lambda \hat{c}_{ij})}{\sum_{i=1}^{K}exp(\lambda \hat{c}_{ij})} αij=∑i=1Kexp(λc^ij)exp(λc^ij), c i j c_{ij} cij表示区域特征 v i v_{i} vi和词特征 t j t_{j} tj之间的余弦距离,也就是相似程度,这里用的是标量, c ^ i j \hat{c}_{ij} c^ij为标准化后的结果。由此得到生成的视觉特征(第j个单词对整幅图像的贡献程度) a j v = ∑ i = 1 K α i j v i a_{j}^{v}=\sum_{i=1}^{K}\alpha_{ij}v_{i} ajv=∑i=1Kαijvi
这样,得到 a j v a_{j}^{v} ajv和 t j t_{j} tj之间(即第j个单词和整个图像之间的相关性)的局部相似表示为: s j l = s ( a j v , t j ; W t ) s_{j}^{l}=s(a_{j}^{v},t_{j};W_{t}) sjl=s(ajv,tj;Wt)
这一部分在代码中,直接用了SCAN的代码部分:
def SCAN_attention(query, context, smooth, eps=1e-8):
"""
query: (n_context, queryL, d)
context: (n_context, sourceL, d)
"""
# --> (batch, d, queryL)
queryT = torch.transpose(query, 1, 2)
# (batch, sourceL, d)(batch, d, queryL)
# --> (batch, sourceL, queryL)
attn = torch.bmm(context, queryT)
attn = nn.LeakyReLU(0.1)(attn)
attn = l2norm(attn, 2)
# --> (batch, queryL, sourceL)
attn = torch.transpose(attn, 1, 2).contiguous()
# --> (batch, queryL, sourceL
attn = F.softmax(attn*smooth, dim=2)
# --> (batch, sourceL, queryL)
attnT = torch.transpose(attn, 1, 2).contiguous()
# --> (batch, d, sourceL)
contextT = torch.transpose(context, 1, 2)
# (batch x d x sourceL)(batch x sourceL x queryL)
# --> (batch, d, queryL)
weightedContext = torch.bmm(contextT, attnT)
# --> (batch, queryL, d)
weightedContext = torch.transpose(weightedContext, 1, 2)
weightedContext = l2norm(weightedContext, dim=-1)
return weightedContext
需要注意的是,这里只有特定单词和对应图像区域之间的关联,并不像SCAN一样还存在特定区域和单词之间的关联
将所有单词的局部相似度表示和文本的全局相似度表示作为图节点 N = { s 1 l , . . , s L l , s g } N=\{s_{1}^{l},..,s_{L}^{l},s^{g}\} N={s1l,..,sLl,sg},这里的节点都是m维向量(代码中m取256),代码中如下:
# concat the global and local alignments
sim_emb = torch.cat([sim_glo.unsqueeze(1), sim_loc], 1) # batch, n_word+1, sim_d
使用以下公式计算节点 s q s_{q} sq和 s p s_{p} sp之间的边:
e ( s p , s q ; W i n , W o u t ) = e x p ( ( W i n s p ) ( W o u t s q ) ) ∑ q e x p ( ( W i n s p ) ( W o u t s q ) ) (2) e(s_{p},s_{q};W_{in},W_{out})=\frac{exp((W_{in}s_{p})(W_{out}s_{q}))}{\sum_{q}exp((W_{in}s_{p})(W_{out}s_{q}))} \tag{2} e(sp,sq;Win,Wout)=∑qexp((Winsp)(Woutsq))exp((Winsp)(Woutsq))(2)
两个 W W W分别为输入节点和输出节点的线性变化。从这里也能看出,节点之间的边是存在方向的。
这样,就可以利用构造好的节点和边,通过不断更新来进行相似图推理:
s ^ p n = ∑ q e ( s p n , s q n ; W i n n , W o u t n ) ⋅ s q n s p n + 1 = R e L U ( W r n s ^ p n ) (3) \hat{s}_{p}^{n}=\sum_{q}e(s_{p}^{n},s_{q}^{n};W_{in}^{n},W_{out}^{n})\cdot s_{q}^{n}\\ s_{p}^{n+1}=ReLU(W_{r}^{n}\hat{s}_{p}^{n}) \tag{3} s^pn=q∑e(spn,sqn;Winn,Woutn)⋅sqnspn+1=ReLU(Wrns^pn)(3)
其中 s p 0 s_{p}^{0} sp0和 s q 0 s_{q}^{0} sq0为步骤n=0时候的从 N N N中提取出来的节点(即初始节点), W r n , W i n n , W o u t n W_{r}^{n},W_{in}^{n},W_{out}^{n} Wrn,Winn,Woutn为每一步中的学习出来的参数,每一步结束后, s p n s_{p}^{n} spn都会被 s p n + 1 s_{p}^{n+1} spn+1取代
对相似度进行N步迭代推理,并以最后一步全局节点的输出作为推理的相似度表示,然后将其送入一个全连接层来推断最终的相似度分数
在代码中,N取3,具体如下:
class GraphReasoning(nn.Module):
"""
Perform the similarity graph reasoning with a full-connected graph
Args: - sim_emb: global and local alignments, shape: (batch_size, L+1, 256)
Returns; - sim_sgr: reasoned graph nodes after several steps, shape: (batch_size, L+1, 256)
"""
def __init__(self, sim_dim):
super(GraphReasoning, self).__init__()
self.graph_query_w = nn.Linear(sim_dim, sim_dim)
self.graph_key_w = nn.Linear(sim_dim, sim_dim)
self.sim_graph_w = nn.Linear(sim_dim, sim_dim)
self.relu = nn.ReLU()
self.init_weights()
# sim_emb (batch, n_word+1, sim_d)
def forward(self, sim_emb):
sim_query = self.graph_query_w(sim_emb)
sim_key = self.graph_key_w(sim_emb)
# batch, n_word+1, n_word+1
sim_edge = torch.softmax(torch.bmm(sim_query, sim_key.permute(0, 2, 1)), dim=-1)
# batch, n_word+1, sim_d
sim_sgr = torch.bmm(sim_edge, sim_emb)
# batch, n_word+1, sim_d
sim_sgr = self.relu(self.sim_graph_w(sim_sgr))
return sim_sgr
对每一个节点 s p s_{p} sp计算一个聚合权重
β p = δ ( B N ( W f s p ) ) ∑ s q ∈ N δ ( B N ( W f s p ) ) (4) \beta _{p}=\frac{\delta (BN(W_{f}s_{p}))}{\sum _{s_{q}\in N}\delta (BN(W_{f}s_{p}))} \tag{4} βp=∑sq∈Nδ(BN(Wfsp))δ(BN(Wfsp))(4)
其中 δ \delta δ 为Sigmoid函数, B N BN BN表示batch normalization, W f ∈ R m × 1 W_{f}\in \mathbb{R}^{m\times 1} Wf∈Rm×1为一个线性变换
然后,使用公式 s f = ∑ s p ∈ N β p s p s_{f}=\sum_{s_{p}\in N}\beta _{p}s_{p} sf=∑sp∈Nβpsp将所有的相似性特征聚合起来,可以看到,通过这样的方式,使得诸如"the","be"等无意义的对齐的系数会变得很小,最终等同于被过滤掉。
最终使用一个完全连接层来预测输入图像和句子之间的最终相似度,这里就是一个标量了。
相关代码如下:
class AttentionFiltration(nn.Module):
"""
Perform the similarity Attention Filtration with a gate-based attention
Args: - sim_emb: global and local alignments, shape: (batch_size, L+1, 256)
Returns; - sim_saf: aggregated alignment after attention filtration, shape: (batch_size, 256)
"""
def __init__(self, sim_dim):
super(AttentionFiltration, self).__init__()
self.attn_sim_w = nn.Linear(sim_dim, 1)
self.bn = nn.BatchNorm1d(1)
self.init_weights()
# input (batch, n_word+1, sim_d)
def forward(self, sim_emb):
sim_attn = l1norm(torch.sigmoid(self.bn(self.attn_sim_w(sim_emb).permute(0, 2, 1))), dim=-1)
sim_saf = torch.matmul(sim_attn, sim_emb)
sim_saf = l2norm(sim_saf.squeeze(1), dim=-1)
return sim_saf
L r ( v , t ) = [ γ − S r ( v , t ) + S r ( v , t − ) ] + [ γ − S r ( v , t ) + S r ( v − , t ) ] + L_{r}(v,t)=[\gamma -S_{r}(v,t)+S_{r}(v,t^{-})]_{+}[\gamma -S_{r}(v,t)+S_{r}(v^{-},t)]_{+} Lr(v,t)=[γ−Sr(v,t)+Sr(v,t−)]+[γ−Sr(v,t)+Sr(v−,t)]+
其中 ( v , t ) (v,t) (v,t)为一个(图像,文本)对, t − t^{-} t−和 v − v^{-} v−为最大负样本, γ \gamma γ为边缘参数, S r S_{r} Sr为SGR或者SAF预测的相似函数。这一部分和SCAN基本一致
单独看两个模块的相似结果:
在这里面,可以看到,对于SAF来说,诸如on、the、a之类的冠词等的权重非常低,这一类词就被过滤掉了