LinkNet 论文复现

Paper:LinkNet: Relational Embedding for Scene Graph
GitHub: linknet-pytorch

LinkNet 论文复现_第1张图片

LinkNet的创新点主要在于以下三个模块:

  • Relational Embedding Module:关系嵌入模块
  • Global Context Encoding Module:全局上下文编码模块
  • Geometric Layout Encoding Module:几何布局编码模块

1 关系嵌入模块

LinkNet 论文复现_第2张图片
由上图可见,作者在物体分类和边分类中分别使用了2个关系嵌入模块。
注意,该模块可以进行多次叠加,作者认为通过叠加该模块可以实现“多次跳跃”的信息传递,但随着该模块数量增加,模型效果会出现先提升后衰减的现象。

该模块使用了注意力机制来对上下文信息进行编码和解码,具体计算方法如下:
LinkNet 论文复现_第3张图片

该模块的代码实现如下:

class RelationalEmbedding(nn.Module):
    """
    Module for relational embedding
    """
    def __init__(self, input_dim, output_dim, r=2):
        super(RelationalEmbedding, self).__init__()

        self.W = nn.Linear(input_dim, int(input_dim/r))
        self.U = nn.Linear(input_dim, int(input_dim/r))
        self.H = nn.Linear(input_dim, int(input_dim/r))

        self.fc0 = nn.Linear(int(input_dim/r), input_dim)
        self.fc1 = nn.Linear(input_dim, output_dim)

    def forward(self, O0):
        """
        Forward pass for relational embedding
        :param O0: [N, input_dim] object features
        :return: O1: [N, input_dim] encoded features
                 O2: [N, output_dim] decoded features
        """
        R1 = F.softmax(torch.matmul(self.W(O0), torch.t(self.U(O0))), 1)
        O1 = O0 + self.fc0(torch.matmul(R1, self.H(O0)))
        O2 = self.fc1(O1)
        return O1, O2

2 全局上下文编码模块

LinkNet 论文复现_第4张图片
该模块通过对由RPN输出的特征进行了全局平均池化(Network in Network中的Global Average Pooling),从而得到全局上下文编码特征c,然后经过softmax进行多标签分类。

其中,我使用了PyTorch的AdaptiveAvgPool2d函数来实现全局平均池化,并使用了BCEloss作为多标签分类的损失函数。

该模块的代码实现如下:

class GlobalContextEncoding(nn.Module):
    """
    Module for global context encoding
    """
    def __init__(self, num_classes, ctx_dim=512):
        super(GlobalContextEncoding, self).__init__()

        self.glb_avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            Flattener(),
        )

        self.multi_score_fc = nn.Linear(ctx_dim, num_classes)

    def forward(self, features):
        """
        Forward pass for global context encoding
        :param features: [batch_size, ctx_dim, IM_SIZE/4, IM_SIZE/4] fmap features
        :return: c: [batch_size, ctx_dim] context feature
                 M: [batch_size, num_classes] softmax of multi-label distribution
        """
        c = self.glb_avg_pool(features)
        M = F.softmax(self.multi_score_fc(c), dim=1)
        return c, M


3 几何布局编码模块

LinkNet 论文复现_第5张图片
该模块通过对主语物体以及宾语物体的相对几何关系进行了显示的编码,从而提升识别物体间关系的准确率。

其中,前两维编码了相对位置,后两维编码了相对大小,具体计算如下:
LinkNet 论文复现_第6张图片
该模块的代码实现如下:

def geo_layout_enc(self, box_priors, rel_inds):
        """
        Geometric Layout Encoding
        :param box_priors: [num_rois, 4] of (xmin, ymin, xmax, ymax)
        :param rel_inds: [num_rels, 3] of (img ind, box0 ind, box1 ind)
        :return: bos: [num_rois*(num_rois-1), 4] encoded relative geometric layout: bo|s
        """
        cxcywh = center_size(box_priors.data)  # convert to (cx, cy, w, h)
        box_s = cxcywh[rel_inds[:, 1]]
        box_o = cxcywh[rel_inds[:, 2]]

        # relative location
        rlt_loc_x = torch.div((box_o[:, 0] - box_s[:, 0]), box_s[:, 2]).view(-1, 1)
        rlt_loc_y = torch.div((box_o[:, 1] - box_s[:, 1]), box_s[:, 3]).view(-1, 1)

        # scale information
        scl_info_w = torch.log(torch.div(box_o[:, 2], box_s[:, 2])).view(-1, 1)
        scl_info_h = torch.log(torch.div(box_o[:, 3], box_s[:, 3])).view(-1, 1)

        bos = torch.cat((rlt_loc_x, rlt_loc_y, scl_info_w, scl_info_h), 1)
        return bos


以上是我对LinkNet中三个模块的具体实现,完整代码请参考linknet-pytorch。

复现的LinkNet在Visual Genome数据集上效果如下:

模式 R@20 R@50 R@100
谓词分类 58.8 65.5 67.4
场景图分类 32.6 35.5 36.1
场景图检测 13.6 20.5 25.0

我目前正在将LinkNet中的Faster R-CNN由VGG替换为ResNet,完成之后模型效果应该还会有提升。

欢迎感兴趣的朋友一起交流: [email protected]

你可能感兴趣的:(LinkNet)