PixPro自监督论文解读

PixPro自监督论文解读_第1张图片

PixPro是第一个通过像素级对比学习来进行特征表示学习

上图是整个算法流程图,接下来进行详细解析

前向传播

input是输入图像,维度尺寸是(b, c, h, w)

augmentation:通过对同一张input进行随机大小、位置裁剪并缩放到统一大小224*224,并基于一定概率下进行随机水平翻转、color distortion、高斯模糊和solarization操作,最后生成两张不同视图view #1和view #2,大小都是(b, c, 224, 224)

backbone+projection:view #1和view #2分别送入两个网络分支,上下两分支中都含有相同结构的backbone+projection模块,其中backbone模块使用了Resnet,输出最后一层特征图,大小为(b, c1, 7, 7)。

projection模块是一个conv1*1+BN+Relu+conv1*1结构,先进行升维,再降维到256大小,这样就得到了两个输出大小为(b, 256, 7, 7)的特征$x$和$x^{,}$,projection模块代码如下:

class MLP2d(nn.Module):
    def __init__(self, in_dim, inner_dim=4096, out_dim=256):
        super(MLP2d, self).__init__()

        self.linear1 = conv1x1(in_dim, inner_dim)
        self.bn1 = nn.BatchNorm2d(inner_dim)
        self.relu1 = nn.ReLU(inplace=True)

        self.linear2 = conv1x1(inner_dim, out_dim)

    def forward(self, x):
        x = self.linear1(x)
        x = self.bn1(x)
        x = self.relu1(x)

        x = self.linear2(x)

        return x
        
def conv1x1(in_planes, out_planes):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=True)

PPM:是一个自注意力模块,针对(b, 256, 7, 7)的输入特征图$x$

PixPro自监督论文解读_第2张图片

首先根据cosine相似度计算出attention图,大小为(b, 49, 49),表示每个特征点与其他特征点的相似度。再对输入特征图进行特征融合,得到输出大小为(b, 256, 7, 7)的特征图$y$,PPM代码如下:

    def featprop(self, feat):
        N, C, H, W = feat.shape

        # Value transformation
        feat_value = self.value_transform(feat)  # 1*1卷积操作
        feat_value = F.normalize(feat_value, dim=1)
        feat_value = feat_value.view(N, C, -1)

        # Similarity calculation
        feat = F.normalize(feat, dim=1)

        # [N, C, H * W]
        feat = feat.view(N, C, -1)

        # [N, H * W, H * W]
        attention = torch.bmm(feat.transpose(1, 2), feat)
        attention = torch.clamp(attention, min=self.pixpro_clamp_value)
        if self.pixpro_p < 1.:
            attention = attention + 1e-6
        attention = attention ** self.pixpro_p  # pixpro_p控制注意力的范围,默认为1

        # [N, C, H * W]
        feat = torch.bmm(feat_value, attention.transpose(1, 2))

        return feat.view(N, C, H, W)

Loss:计算$x^,$和$y$之间的loss。$x^,$和$y$的空间位置示意图如下所示:

PixPro自监督论文解读_第3张图片

在数据增强augmentation过程中,可以获取裁剪图像的左上角和右下键坐标,由于输出特征图$x^,$和$y$大小为(b, 256, 7, 7),所以每个特征图中有7*7个特征点,根据插值法就可以获取输出特征图$x^,$和$y$的每个特征点的空间坐标,大小为(b, 2, 7, 7)。

首先计算出不同视图中每个特征点彼此之间的距离,可以得到大小为(b, 49, 49)的距离矩阵D,步骤如下:

  1. 特征图$x^{,}$的x坐标$X_{x^{,}}$:(b, 7, 7)->(b, 49, 1), y坐标$Y_{x^{,}}$:(b, 7, 7)->(b, 49, 1)
  2. 特征图$y$中x坐标$X_{y}$:(b, 7, 7)->(b, 1, 49), y坐标$Y_{y}$:(b, 7, 7)->(b, 1,49)
  3. 距离矩阵D=$\sqrt{(X_{x^{,}}-X_{y})^2+(Y_{x^{,}}-Y_{y})^2}/max\_bin$(max_bin是相邻特征点之间的最大距离,目的是为了"归一化")

不同视图中距离较近的特征点特征应该具有一致性consistency,所以对距离特征D根据阈值ratio进行二分来获取距离较近的特征点掩码M=(D

再计算出$x^,$和$y$的特征相似度图logit,大小为(b, 49, 49),这步与PPM中计算注意力相似度类似

最后根据特征相似图和掩码矩阵计算loss:

$loss = logit * M$

整个loss计算完整过程的代码如下:

def regression_loss(q, k, coord_q, coord_k, pos_ratio=0.5):
    """ q, k: N * C * H * W
        coord_q, coord_k: N * 4 (x_upper_left, y_upper_left, x_lower_right, y_lower_right)
    """
    N, C, H, W = q.shape
    # [bs, feat_dim, 49]
    q = q.view(N, C, -1)
    k = k.view(N, C, -1)

    # generate center_coord, width, height
    # [1, 7, 7]
    x_array = torch.arange(0., float(W), dtype=coord_q.dtype, device=coord_q.device).view(1, 1, -1).repeat(1, H, 1)
    y_array = torch.arange(0., float(H), dtype=coord_q.dtype, device=coord_q.device).view(1, -1, 1).repeat(1, 1, W)
    # [bs, 1, 1]
    q_bin_width = ((coord_q[:, 2] - coord_q[:, 0]) / W).view(-1, 1, 1)
    q_bin_height = ((coord_q[:, 3] - coord_q[:, 1]) / H).view(-1, 1, 1)
    k_bin_width = ((coord_k[:, 2] - coord_k[:, 0]) / W).view(-1, 1, 1)
    k_bin_height = ((coord_k[:, 3] - coord_k[:, 1]) / H).view(-1, 1, 1)
    # [bs, 1, 1]
    q_start_x = coord_q[:, 0].view(-1, 1, 1)
    q_start_y = coord_q[:, 1].view(-1, 1, 1)
    k_start_x = coord_k[:, 0].view(-1, 1, 1)
    k_start_y = coord_k[:, 1].view(-1, 1, 1)

    # [bs, 1, 1]
    q_bin_diag = torch.sqrt(q_bin_width ** 2 + q_bin_height ** 2)
    k_bin_diag = torch.sqrt(k_bin_width ** 2 + k_bin_height ** 2)
    max_bin_diag = torch.max(q_bin_diag, k_bin_diag)

    # [bs, 7, 7]
    center_q_x = (x_array + 0.5) * q_bin_width + q_start_x
    center_q_y = (y_array + 0.5) * q_bin_height + q_start_y
    center_k_x = (x_array + 0.5) * k_bin_width + k_start_x
    center_k_y = (y_array + 0.5) * k_bin_height + k_start_y

    # [bs, 49, 49]
    dist_center = torch.sqrt((center_q_x.view(-1, H * W, 1) - center_k_x.view(-1, 1, H * W)) ** 2
                             + (center_q_y.view(-1, H * W, 1) - center_k_y.view(-1, 1, H * W)) ** 2) / max_bin_diag
    pos_mask = (dist_center < pos_ratio).float().detach()

    # [bs, 49, 49]
    logit = torch.bmm(q.transpose(1, 2), k)

    loss = (logit * pos_mask).sum(-1).sum(-1) / (pos_mask.sum(-1).sum(-1) + 1e-6)

    return -2 * loss.mean()

反向传播

下分支网络不参与直接训练,其中所有的权重参数不具有梯度值。其参数$param\_k$更新方式是基于上分支网络参数$param\_q$动量更新。训练开始前,上下分支网络初始权重保持一致。

$$ param\_k.data = param\_k.data * momentum + param\_q.data * (1-momentum) $$

其中,momentum是动量值,整个训练过程从0.99逐渐增大到1.0

实验

优化器:LARS,weight_decay=1e-5

lr_scheduler:cosine, warmup

total_batchsize:1024

world size:8 V100 GPUs

与其他基于实例级自监督算法在下游检测分割任务上的比较结果

PixPro自监督论文解读_第4张图片

不同超参数下的实验结果

PixPro自监督论文解读_第5张图片

PixPro和ProContrast结果比较

PixPro自监督论文解读_第6张图片

结合实例级模块的结果

PixPro自监督论文解读_第7张图片

有无FPN、head、实例级模块的实验比较结果

PixPro自监督论文解读_第8张图片

你可能感兴趣的:(深度学习pytorch图像识别)