Similarity Reasoning and Filtration for Image-Text Matching

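    • Introduction
    • Method
      • Feature Extraction
        • Image Feature Extraction
        • Text Feature Extraction
      • Similarity Representation Learning
      • SGR (Similarity Graph Reasoning)
      • SAF (Similarity Attention Filtration)
      • Loss Function
    • Experimental Results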

Citation: Diao, Haiwen, et al. “Similarity reasoning and filtration for image-text matching.” arXiv preprint arXiv:2101.01368 (2021).



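  1. Most methods compute the similarity between local features as a scalar, i.e., each region-word relation is collapsed directly into a single number, and a single number can hardly describe how a region and a word are related.
  2. After computing the latent alignments between regions and words, some models derive the global similarity directly with max pooling or average pooling. Either way, this hinders information exchange between the global alignment and the local alignments. (For example, SCAN ends with average pooling: $S_{AVG}(I,T)=\frac{\sum_{i=1}^{k}R(v_{i},a^{t}_{i})}{k}$.)
  3. Few methods consider the interference caused by less meaningful alignments, such as the relations between "a" or "in" and the other instances in the figure below.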
    [Figure 1]
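To make the "scalar" problem in items 1 and 2 concrete, here is a minimal sketch of SCAN-style average pooling from the formula above. It assumes that R is cosine similarity and that the attended sentence vectors $a^{t}_{i}$ are already given; `s_avg` is a hypothetical helper and random tensors stand in for real features.

```python
import torch
import torch.nn.functional as F


def s_avg(regions: torch.Tensor, attended_text: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: SCAN-style average pooling of region-level scores.

    regions:       (k, d) region features v_i
    attended_text: (k, d) attended sentence vectors a_i^t, one per region
    Returns a 0-dim tensor, the mean of the k scalar scores R(v_i, a_i^t).
    """
    # assumption: R is cosine similarity, so each region yields one scalar
    r = F.cosine_similarity(regions, attended_text, dim=-1)  # (k,)
    # all alignments are weighted equally in the average
    return r.sum() / regions.size(0)


torch.manual_seed(0)
v = torch.randn(36, 1024)
a = torch.randn(36, 1024)
print(s_avg(v, a))  # a single number in [-1, 1]
print(s_avg(v, v))  # identical inputs give a value close to 1
```

Each region contributes exactly one number and every number gets the same weight, which is the bottleneck that the vector-based similarity representations and the SGR/SAF aggregation below are designed to address.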


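  1. First, capture the global alignment between the whole image and the whole sentence, as well as the local alignments between image regions and words. Vector-based similarity representations are used here to model these cross-modal associations more effectively;
  2. A Similarity Graph Reasoning (SGR) module, built on a graph convolutional neural network (GCNN), captures the relations between the local and global alignments in order to infer a more accurate image-text similarity;
  3. A Similarity Attention Filtration (SAF) module aggregates all alignments with different significance scores, reducing the interference of meaningless alignments.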



[Figure 2]




Image features are extracted with Faster R-CNN, and a fully connected layer maps each region to a d-dimensional vector, giving the region representations $V=\{v_{1},\dots,v_{K}\}$. This is the same as in SCAN.

Next, self-attention is applied over the regions: the mean feature $\bar{q}_{v}=\frac{1}{K}\sum_{i=1}^{K}v_{i}$ serves as the query, and all regions are aggregated into the global representation $\bar{v}$. The corresponding code is:

```python
import torch.nn as nn
# l2norm is the helper defined in the similarity-representation section below


class VisualSA(nn.Module):
    """Self-attention over regions, using the mean region feature as the query."""

    def __init__(self, embed_dim, dropout_rate, num_region):
        super(VisualSA, self).__init__()

        self.embedding_local = nn.Sequential(nn.Linear(embed_dim, embed_dim),
                                             nn.Tanh(), nn.Dropout(dropout_rate))
        self.embedding_global = nn.Sequential(nn.Linear(embed_dim, embed_dim),
                                              nn.Tanh(), nn.Dropout(dropout_rate))
        self.embedding_common = nn.Sequential(nn.Linear(embed_dim, 1))

        self.softmax = nn.Softmax(dim=1)

    # local: (batch, 36, d=1024) region features
    # raw_global: (batch, d=1024) mean region feature, used as the query
    def forward(self, local, raw_global):
        # compute embeddings of the local regions and the raw global image
        l_emb = self.embedding_local(local)
        g_emb = self.embedding_global(raw_global)

        # compute the normalized weights, shape: (batch_size, 36)
        g_emb = g_emb.unsqueeze(1).repeat(1, l_emb.size(1), 1)
        common = l_emb.mul(g_emb)
        weights = self.embedding_common(common).squeeze(2)
        weights = self.softmax(weights)

        # compute the final global image feature, shape: (batch_size, 1024)
        new_global = (weights.unsqueeze(2) * local).sum(dim=1)
        new_global = l2norm(new_global, dim=-1)

        # new_global: (batch, d=1024)
        return new_global
```

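The module above is attention pooling with the mean region feature as the query. The sketch below strips out the learned layers to show only the mechanism; as a simplifying assumption it uses identity embeddings instead of `embedding_local`/`embedding_global` and a plain dot product instead of `embedding_common`. `mean_query_pool` is a hypothetical name.

```python
import torch
import torch.nn.functional as F


def mean_query_pool(local: torch.Tensor) -> torch.Tensor:
    """Hypothetical, parameter-free simplification of the VisualSA pooling.

    local: (batch, K, d) region features. Returns (batch, d) unit vectors.
    """
    raw_global = local.mean(dim=1)                                  # q_v = (1/K) sum_i v_i
    scores = torch.bmm(local, raw_global.unsqueeze(2)).squeeze(2)   # (batch, K) query-region agreement
    weights = F.softmax(scores, dim=1)                              # normalized over regions
    pooled = (weights.unsqueeze(2) * local).sum(dim=1)              # weighted sum of regions
    return F.normalize(pooled, dim=-1)                              # same role as l2norm


torch.manual_seed(0)
regions = torch.randn(2, 36, 1024)
g = mean_query_pool(regions)
print(g.shape)  # torch.Size([2, 1024])
```

With identity embeddings, `embedding_common(l_emb * g_emb)` reduces to a dot product between each region and the mean feature, which is why `bmm` is used here.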

Text features are extracted with a GRU, giving the word representations $T=\{t_{1},\dots,t_{L}\}$. The global text representation $\bar{t}$ is obtained in the same way; in the code this is class TextSA(nn.Module).
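A minimal sketch of the text side, assuming a bidirectional GRU whose two directions are averaged (as in SCAN's text encoder) and hypothetical sizes (`vocab_size=1000`, `word_dim=300`, `embed_size=1024`). `ToyTextEncoder` is a hypothetical class; the mean word feature plays the same query role for TextSA that the mean region feature plays for VisualSA.

```python
import torch
import torch.nn as nn


class ToyTextEncoder(nn.Module):
    """Hypothetical GRU text encoder producing word features T = {t_1, ..., t_L}."""

    def __init__(self, vocab_size=1000, word_dim=300, embed_size=1024):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, word_dim)
        self.gru = nn.GRU(word_dim, embed_size, batch_first=True, bidirectional=True)

    def forward(self, tokens):
        # tokens: (batch, L) integer word ids
        out, _ = self.gru(self.embed(tokens))               # (batch, L, 2 * embed_size)
        half = out.size(2) // 2
        words = (out[:, :, :half] + out[:, :, half:]) / 2   # assumption: average both directions
        raw_global = words.mean(dim=1)                      # mean word feature, query for TextSA
        return words, raw_global


enc = ToyTextEncoder()
tokens = torch.randint(0, 1000, (2, 12))
words, raw_global = enc(tokens)
print(words.shape, raw_global.shape)  # torch.Size([2, 12, 1024]) torch.Size([2, 1024])
```

Real captions have different lengths; a full implementation would typically pack padded sequences (`nn.utils.rnn.pack_padded_sequence`) and average only over valid words, which this sketch omits.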



$$s(x,y;W)=\frac{W|x-y|^{2}}{\|W|x-y|^{2}\|_{2}} \tag{1}$$
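where $W\in \mathbb{R}^{m\times d}$ is a learnable parameter matrix, so the result is an m-dimensional similarity vector. $|\cdot|^{2}$ and $\|\cdot\|_{2}$ denote the element-wise square and the $\ell_2$ norm, respectively; dividing by the $\ell_2$ norm makes $s$ a unit vector.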

In the code, this is part of class EncoderSimilarity(nn.Module):

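```python
import torch
import torch.nn as nn


def l2norm(X, dim=-1, eps=1e-8):
    """L2-normalize X along dimension `dim`."""
    norm = torch.pow(X, 2).sum(dim=dim, keepdim=True).sqrt() + eps
    X = torch.div(X, norm)
    return X

# in EncoderSimilarity.__init__: W_t (local) and W_g (global) of Eq. (1)
self.sim_tranloc_w = nn.Linear(embed_size, sim_dim)
self.sim_tranglo_w = nn.Linear(embed_size, sim_dim)

# in EncoderSimilarity.forward: local similarity vectors s(a_j^v, t_j; W_t)
# Context_img: attended image features a_j^v; cap_i_expand: word features t_j
sim_loc = torch.pow(torch.sub(Context_img, cap_i_expand), 2)
sim_loc = l2norm(self.sim_tranloc_w(sim_loc), dim=-1)

# global similarity vector s(v_bar, t_bar; W_g)
sim_glo = torch.pow(torch.sub(img_glo, cap_glo_i), 2)
sim_glo = l2norm(self.sim_tranglo_w(sim_glo), dim=-1)
```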


Applying the formula above gives the global similarity representation $s^{g}=s(\bar{v},\bar{t};W_{g})$, where $W_{g}$ is learned specifically for the global similarity.
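A small sketch of Eq. (1) applied to global features. Unlike a cosine score, the result is a 256-dimensional unit vector. `vector_similarity` is a hypothetical helper and the sizes (d = 1024, m = 256) mirror the comments in the code above; note that `nn.Linear` adds a bias term that Eq. (1) does not write out, just as `sim_tranglo_w` does.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def vector_similarity(x: torch.Tensor, y: torch.Tensor, proj: nn.Linear) -> torch.Tensor:
    """Hypothetical helper for Eq. (1): s(x, y; W) = W|x - y|^2 / ||W|x - y|^2||_2.

    x, y: (..., d) features; proj: linear map W from d to m dimensions.
    Returns (..., m) unit-length similarity vectors.
    """
    diff_sq = (x - y).pow(2)                   # element-wise square |x - y|^2
    return F.normalize(proj(diff_sq), dim=-1)  # apply W, then l2-normalize


torch.manual_seed(0)
d, m = 1024, 256
W_g = nn.Linear(d, m)       # plays the role of W_g (sim_tranglo_w in the code)
v_bar = torch.randn(4, d)   # global image features (random stand-ins)
t_bar = torch.randn(4, d)   # global text features (random stand-ins)
s_g = vector_similarity(v_bar, t_bar, W_g)
print(s_g.shape)            # torch.Size([4, 256])
```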


As in SCAN, the attention weights are computed as $\alpha_{ij}=\frac{\exp(\lambda \hat{c}_{ij})}{\sum_{i'=1}^{K}\exp(\lambda \hat{c}_{i'j})}$, where $c_{ij}$ is the cosine similarity (a scalar) between the region feature $v_{i}$ and the word feature $t_{j}$, $\hat{c}_{ij}$ is its normalized value, and $\lambda$ is a smoothing factor (`smooth` in the code). This yields the attended visual feature for the j-th word, i.e., the image content that word attends to: $a_{j}^{v}=\sum_{i=1}^{K}\alpha_{ij}v_{i}$.

这样,得到 a j v a_{j}^{v} ajv t j t_{j} tj之间(即第j个单词和整个图像之间的相关性)的局部相似表示为: s j l = s ( a j v , t j ; W t ) s_{j}^{l}=s(a_{j}^{v},t_{j};W_{t}) sjl=s(ajv,tj;Wt)


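```python
import torch
import torch.nn.functional as F
# l2norm is the helper defined above


def SCAN_attention(query, context, smooth, eps=1e-8):
    """Text-to-image attention as in SCAN.

    query:   (n_context, queryL, d), e.g. word features t_j
    context: (n_context, sourceL, d), e.g. region features v_i
    smooth:  the smoothing factor lambda
    Returns the attended context for each query, shape (n_context, queryL, d).
    """
    # --> (batch, d, queryL)
    queryT = torch.transpose(query, 1, 2)

    # (batch, sourceL, d)(batch, d, queryL)
    # --> (batch, sourceL, queryL); equals cosine similarity when inputs are l2-normalized
    attn = torch.bmm(context, queryT)

    # clip negative scores (LeakyReLU here instead of SCAN's ReLU),
    # then l2-normalize over the query (word) dimension
    attn = F.leaky_relu(attn, 0.1)
    attn = l2norm(attn, 2)

    # --> (batch, queryL, sourceL)
    attn = torch.transpose(attn, 1, 2).contiguous()
    # softmax over source (region) positions --> (batch, queryL, sourceL)
    attn = F.softmax(attn * smooth, dim=2)

    # --> (batch, sourceL, queryL)
    attnT = torch.transpose(attn, 1, 2).contiguous()

    # --> (batch, d, sourceL)
    contextT = torch.transpose(context, 1, 2)
    # (batch x d x sourceL)(batch x sourceL x queryL)
    # --> (batch, d, queryL)
    weightedContext = torch.bmm(contextT, attnT)
    # --> (batch, queryL, d)
    weightedContext = torch.transpose(weightedContext, 1, 2)
    weightedContext = l2norm(weightedContext, dim=-1)

    return weightedContext
```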
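Putting the attention formula and Eq. (1) together for a single image-caption pair gives all local similarity vectors $s_{j}^{l}$ at once. This sketch follows the SCAN normalization $\hat{c}_{ij}=[c_{ij}]_{+}/\sqrt{\sum_{j}[c_{ij}]_{+}^{2}}$ with a plain ReLU, whereas the code above uses LeakyReLU(0.1) and additionally l2-normalizes the attended features. `local_similarities` is a hypothetical helper and `lam=9.0` is an assumed value.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def local_similarities(regions, words, proj, lam=9.0):
    """Hypothetical helper: s_j^l = s(a_j^v, t_j; W) for every word j.

    regions: (K, d) region features v_i; words: (L, d) word features t_j;
    proj: linear map W from d to m; lam: assumed smoothing factor.
    Returns (L, m) unit vectors.
    """
    # c_ij: cosine similarity between region i and word j -> (K, L)
    c = F.normalize(regions, dim=-1) @ F.normalize(words, dim=-1).T
    # c_hat_ij: clip negatives, then l2-normalize over words j
    c_hat = F.normalize(F.relu(c), dim=1)
    # alpha_ij: softmax over regions i, so each column sums to 1
    alpha = F.softmax(lam * c_hat, dim=0)
    # a_j^v = sum_i alpha_ij v_i -> (L, d)
    attended = alpha.T @ regions
    # Eq. (1) applied to each (a_j^v, t_j) pair -> (L, m)
    return F.normalize(proj((attended - words).pow(2)), dim=-1)


torch.manual_seed(0)
W_t = nn.Linear(1024, 256)
s_l = local_similarities(torch.randn(36, 1024), torch.randn(12, 1024), W_t)
print(s_l.shape)  # torch.Size([12, 256])
```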



The local similarity representations of all words and the global similarity representation are taken as the graph nodes $N=\{s_{1}^{l},\dots,s_{L}^{l},s^{g}\}$. Every node is an m-dimensional vector (m = 256 in the code). In the code:

```python
# concat the global and local alignments (global node first)
sim_emb = torch.cat([sim_glo.unsqueeze(1), sim_loc], 1)             # batch, n_word+1, sim_d
```

使用以下公式计算节点 s q s_{q} sq s p s_{p} sp之间的边:
$$e(s_{p},s_{q};W_{in},W_{out})=\frac{\exp\big((W_{in}s_{p})^{\top}(W_{out}s_{q})\big)}{\sum_{q'}\exp\big((W_{in}s_{p})^{\top}(W_{out}s_{q'})\big)} \tag{2}$$
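$W_{in}$ and $W_{out}$ are linear transformations applied to the input node and the output node, respectively. Because the two matrices differ and the softmax normalizes over $q$ only, $e(s_{p},s_{q})\neq e(s_{q},s_{p})$ in general, so the edges are directed.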
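A quick numerical check of Eq. (2) on random nodes: each row of the edge matrix is a distribution over target nodes, and the matrix is not symmetric. The node count (12 words plus 1 global node) is hypothetical, and `bias=False` is chosen to match Eq. (2), while `graph_query_w`/`graph_key_w` in the code do use a bias.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
m, n_nodes = 256, 13                  # hypothetical: 12 local nodes + 1 global node
nodes = torch.randn(n_nodes, m)       # row p is node s_p
W_in = nn.Linear(m, m, bias=False)    # corresponds to graph_query_w in the code
W_out = nn.Linear(m, m, bias=False)   # corresponds to graph_key_w in the code

# Eq. (2): row p holds e(s_p, s_q) for all q; softmax over q normalizes each row
logits = W_in(nodes) @ W_out(nodes).T  # (n_nodes, n_nodes) inner products
edges = torch.softmax(logits, dim=-1)

print(torch.allclose(edges.sum(dim=-1), torch.ones(n_nodes), atol=1e-6))  # True: rows sum to 1
print(torch.allclose(edges, edges.T))                                     # False: edges are directed
```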

$$\begin{aligned}\hat{s}_{p}^{n}&=\sum_{q}e(s_{p}^{n},s_{q}^{n};W_{in}^{n},W_{out}^{n})\cdot s_{q}^{n}\\ s_{p}^{n+1}&=\mathrm{ReLU}(W_{r}^{n}\hat{s}_{p}^{n})\end{aligned} \tag{3}$$

其中 s p 0 s_{p}^{0} sp0 s q 0 s_{q}^{0} sq0为步骤n=0时候的从 N N N中提取出来的节点(即初始节点), W r n , W i n n , W o u t n W_{r}^{n},W_{in}^{n},W_{out}^{n} Wrn,Winn,Woutn为每一步中的学习出来的参数,每一步结束后, s p n s_{p}^{n} spn都会被 s p n + 1 s_{p}^{n+1} spn+1取代



```python
import torch
import torch.nn as nn


class GraphReasoning(nn.Module):
    """
    Perform one step of similarity graph reasoning on a fully connected graph.
    Args: - sim_emb: global and local alignments, shape: (batch_size, L+1, 256)
    Returns: - sim_sgr: reasoned graph nodes after this step, shape: (batch_size, L+1, 256)
    """
    def __init__(self, sim_dim):
        super(GraphReasoning, self).__init__()

        self.graph_query_w = nn.Linear(sim_dim, sim_dim)  # W_in in Eq. (2)
        self.graph_key_w = nn.Linear(sim_dim, sim_dim)    # W_out in Eq. (2)
        self.sim_graph_w = nn.Linear(sim_dim, sim_dim)    # W_r in Eq. (3)
        self.relu = nn.ReLU()

    # sim_emb (batch, n_word+1, sim_d)
    def forward(self, sim_emb):
        sim_query = self.graph_query_w(sim_emb)
        sim_key = self.graph_key_w(sim_emb)
        # directed edges, Eq. (2): batch, n_word+1, n_word+1
        sim_edge = torch.softmax(torch.bmm(sim_query, sim_key.permute(0, 2, 1)), dim=-1)
        # aggregate neighbors, first line of Eq. (3): batch, n_word+1, sim_d
        sim_sgr = torch.bmm(sim_edge, sim_emb)
        # update nodes, second line of Eq. (3): batch, n_word+1, sim_d
        sim_sgr = self.relu(self.sim_graph_w(sim_sgr))
        return sim_sgr
```
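The notes do not show how several reasoning steps are chained or how a score is read out. The sketch below is one plausible arrangement under stated assumptions: 3 steps (a hypothetical count), separate $W_{in}^{n}, W_{out}^{n}, W_{r}^{n}$ per step as Eq. (3) describes, and a hypothetical readout that maps the reasoned global node (index 0, since it is concatenated first) to a score in (0, 1). `SGRStack` is a hypothetical name, not the authors' class.

```python
import torch
import torch.nn as nn


class SGRStack(nn.Module):
    """Hypothetical sketch: n reasoning steps, each with its own W_in^n, W_out^n, W_r^n."""

    def __init__(self, sim_dim=256, steps=3):
        super().__init__()
        self.w_in = nn.ModuleList([nn.Linear(sim_dim, sim_dim) for _ in range(steps)])
        self.w_out = nn.ModuleList([nn.Linear(sim_dim, sim_dim) for _ in range(steps)])
        self.w_r = nn.ModuleList([nn.Linear(sim_dim, sim_dim) for _ in range(steps)])
        self.score = nn.Linear(sim_dim, 1)  # hypothetical readout layer

    def forward(self, nodes):
        # nodes: (batch, L+1, sim_dim), global node first as in the torch.cat above
        for w_in, w_out, w_r in zip(self.w_in, self.w_out, self.w_r):
            logits = torch.bmm(w_in(nodes), w_out(nodes).transpose(1, 2))
            edges = torch.softmax(logits, dim=-1)               # Eq. (2)
            nodes = torch.relu(w_r(torch.bmm(edges, nodes)))    # Eq. (3): s^{n+1} replaces s^n
        # assumption: read out the reasoned global node and squash it to a score
        return torch.sigmoid(self.score(nodes[:, 0, :])).squeeze(-1)


torch.manual_seed(0)
model = SGRStack()
scores = model(torch.randn(4, 13, 256))
print(scores.shape)  # torch.Size([4])
```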


For each node $s_{p}$, an aggregation weight is computed:
$$\beta_{p}=\frac{\delta(\mathrm{BN}(W_{f}s_{p}))}{\sum_{s_{q}\in N}\delta(\mathrm{BN}(W_{f}s_{q}))} \tag{4}$$
where $\delta$ is the sigmoid function, BN denotes batch normalization, and $W_{f}\in \mathbb{R}^{1\times m}$ is a linear transformation that maps each node to a scalar (nn.Linear(sim_dim, 1) in the code).

All similarity representations are then aggregated as $s_{f}=\sum_{s_{p}\in N}\beta_{p}s_{p}$. In this way, meaningless alignments, such as those of "the" or "be", can receive very small coefficients and are effectively filtered out.



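```python
import torch
import torch.nn as nn
# l2norm is the helper defined above


def l1norm(X, dim, eps=1e-8):
    """L1-normalize X along dimension `dim`."""
    norm = torch.abs(X).sum(dim=dim, keepdim=True) + eps
    return torch.div(X, norm)


class AttentionFiltration(nn.Module):
    """
    Perform the similarity attention filtration with a gate-based attention.
    Args: - sim_emb: global and local alignments, shape: (batch_size, L+1, 256)
    Returns: - sim_saf: aggregated alignment after attention filtration, shape: (batch_size, 256)
    """
    def __init__(self, sim_dim):
        super(AttentionFiltration, self).__init__()

        self.attn_sim_w = nn.Linear(sim_dim, 1)  # W_f in Eq. (4)
        self.bn = nn.BatchNorm1d(1)

    # input (batch, n_word+1, sim_d)
    def forward(self, sim_emb):
        # beta_p = sigmoid(BN(W_f s_p)), L1-normalized over nodes: (batch, 1, n_word+1)
        sim_attn = l1norm(torch.sigmoid(self.bn(self.attn_sim_w(sim_emb).permute(0, 2, 1))), dim=-1)
        # s_f = sum_p beta_p s_p: (batch, 1, sim_d)
        sim_saf = torch.matmul(sim_attn, sim_emb)
        sim_saf = l2norm(sim_saf.squeeze(1), dim=-1)
        return sim_saf
```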
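An unrolled version of Eq. (4) and the aggregation step, written without the class so each symbol is visible. Sizes are hypothetical, dividing by the sum plays the role of `l1norm`, and the final l2 normalization done by the module is left out.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
m = 256
W_f = nn.Linear(m, 1)      # maps each node s_p to a scalar
bn = nn.BatchNorm1d(1)     # batch normalization over that scalar (training mode)

nodes = torch.randn(4, 13, m)                          # (batch, L+1, m), random stand-ins
gate = torch.sigmoid(bn(W_f(nodes).transpose(1, 2)))   # delta(BN(W_f s_p)): (batch, 1, L+1)
beta = gate / gate.sum(dim=-1, keepdim=True)           # Eq. (4): normalize over all nodes
s_f = torch.bmm(beta, nodes).squeeze(1)                # s_f = sum_p beta_p s_p: (batch, m)

print(beta.sum(dim=-1).squeeze())  # approximately 1 for every sample
print(s_f.shape)                   # torch.Size([4, 256])
```

Each sigmoid gate lies in (0, 1) independently of the other nodes before normalization, so a node is suppressed only by its own low gate value rather than by competition inside a softmax.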


$$L_{r}(v,t)=[\gamma-S_{r}(v,t)+S_{r}(v,t^{-})]_{+}+[\gamma-S_{r}(v,t)+S_{r}(v^{-},t)]_{+}$$

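where $(v,t)$ is a matched (image, text) pair, $t^{-}$ and $v^{-}$ are the hardest negatives in the mini-batch, $\gamma$ is the margin, $[x]_{+}=\max(x,0)$, and $S_{r}$ is the similarity predicted by SGR or SAF. This part is largely the same as in SCAN.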
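A minimal sketch of this loss in the style of the max-violation (hardest-negative) triplet loss, assuming the scores for a mini-batch of n pairs are arranged in an n×n matrix with matched pairs on the diagonal. `hardest_negative_triplet_loss` is a hypothetical name and `margin=0.2` is an assumed default.

```python
import torch


def hardest_negative_triplet_loss(scores: torch.Tensor, margin: float = 0.2) -> torch.Tensor:
    """Hypothetical sketch of the bidirectional hinge loss with hardest negatives.

    scores: (n, n) with scores[i, j] = S_r(v_i, t_j); matched pairs on the diagonal.
    """
    n = scores.size(0)
    positive = scores.diag().view(n, 1)                       # S_r(v, t)
    mask = torch.eye(n, dtype=torch.bool, device=scores.device)
    # image anchors: [gamma - S_r(v, t) + S_r(v, t^-)]_+, negatives along each row
    cost_t = (margin - positive + scores).clamp(min=0).masked_fill(mask, 0)
    # text anchors: [gamma - S_r(v, t) + S_r(v^-, t)]_+, negatives along each column
    cost_v = (margin - positive.t() + scores).clamp(min=0).masked_fill(mask, 0)
    # keep only the hardest negative for each anchor, then sum over the batch
    return cost_t.max(dim=1)[0].sum() + cost_v.max(dim=0)[0].sum()


easy = torch.tensor([[0.9, 0.1],
                     [0.3, 0.8]])
hard = torch.tensor([[0.5, 0.6],
                     [0.2, 0.4]])
print(hardest_negative_triplet_loss(easy))  # tensor(0.): every negative is beaten by the margin
print(hardest_negative_triplet_loss(hard))  # about 0.7 = 0.3 (caption side) + 0.4 (image side)
```

Only the single most violating negative per anchor contributes, which matches "hardest negatives" in the definition above; summing over all negatives instead would give the plain triplet loss.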


[Figure 3: experimental results]


[Figure 4: experimental results]

