对比学习(Contrastive Learning) (1)

三篇论文

《Supervised Contrastive Learning》
《A Simple Framework for Contrastive Learning of Visual Representations》
《What Makes for Good Views for Contrastive Learning》

对比学习的思想起源于无监督学习,相比于监督学习算法,无监督学习由于没有标签的指导,训练过程学习样本的特征会更加困难。对比学习的核心思想就是通过数据增强构造原来样本的多样性,损失函数的设计用来拉进正样本与锚样本的距离,增大与负样本的距离,在这一过程中,网络更容易学到由源样本经过数据增强之后的多个样本所具有的共同特征,而这一特征对于源样本来说更可能是本质性的。

《A Simple Framework for Contrastive Learning of Visual Representations》SimCLR

论文提出了一种更简洁的对比学习算法,主要有三个贡献:

  • 使用不同组合形式的数据增强对于下游的预测任务非常重要
  • 在特征提取的encoder和对抗损失之间引入可学的多层感知机可以提高网络的学习能力
  • 在一个batch中,样本的数目越多越容易提高训练性能

这一工作也是后面诸多对比学习工作的基础。

网络框架

对比学习(Contrastive Learning) (1)_第1张图片

  1. 对于一个锚样本 x x x,使用随机的数据增强方式生成一对正样本,用 x ~ i \tilde{x}_i x~i x ~ j \tilde{x}_j x~j表示
  2. 一个特征提取网络encoder f ( ⋅ ) f(\cdot) f()用来提取 x ~ i \tilde{x}_i x~i x ~ j \tilde{x}_j x~j的特征,用来提高网络的泛化能力, h i = f ( x ~ i ) , h j = f ( x ~ j ) h_i=f(\tilde{x}_i),h_j=f(\tilde{x}_j) hi=f(x~i),hj=f(x~j)。特征提取的网络通常使用resnet。
  3. 在特征表示和对抗损失之间添加一个架构为多层感知机的投影网络,即 z i = g ( h i ) = W ( 2 ) σ ( W ( 1 ) h i ) z_i=g(h_i)=W^{(2)}\sigma(W^{(1)}h_i) zi=g(hi)=W(2)σ(W(1)hi),这也是文章的贡献之一:在投影网络的输出端进行对比损失的计算要比直接在 f ( ⋅ ) f(\cdot) f()的输出计算更有用。

对三个贡献做出解释

1. 为什么不同形式的数据增强的组合有助于学到好的特征?
对比学习的目的是学到对于一个样本最核心的特征,如果使用单一的数据增强,比如只使用随机裁剪(random cropping),那么网络在训练过程就会认为颜色信息可能也是有用的,因为没有label来指导它学到下游任务的目标,网络无法提取对于下游更核心的特征。而采用多个数据增强的组合可以让网络认识到什么信息是不相关的,比如一个颜色失真的样本和一个高斯噪声的样本,这两个样本来源于同一个样本,网络在优化过程中需要认为他们两个着某些特征上是相同的,从而认识到颜色和噪声对于要提取的信息都是不重要的。
2. 为什么在encoder后面添加一个多层感知机可以提高学习能力?
z = g ( h ) z=g(h) z=g(h)的训练目的是增加对于数据变换的不变性,根据神经网络传统的学习方式,由于投影层处于较高的网络层次,网络学到的特征就更倾向于任务相关(high-level),低层的网络学到的更倾向于细节特征,如果没有投影层来学习高级特征,全部由encoder完成的话,encoder学到的特征在不同下游任务上的泛化能力会下降。
3. 为什么batchsize越大越容易收敛?
根据损失函数可以知道,当batchsize比较大的时候,意味着分母上的负样本数量也比较多,损失函数的目的是从一堆样本中找出锚样本,或者说,找出最能够区分锚样本与负样本的表征,当负样本数目多的时候,网络更容易排除什么信息对于该样本是不相关的,所以能够加快训练。

损失函数

L s e l f = ∑ i ∈ I L i s e l f = − ∑ i ∈ I log ⁡ exp ⁡ ( z i ⋅ z j ( i ) / τ ) ∑ α ∈ A ( i ) exp ⁡ ( z i ⋅ z α / τ ) \mathcal{L}^{self}=\sum_{i\in I}\mathcal{L}^{self}_i=-\sum_{i\in I}\log \frac{\exp (z_i \cdot z_{j(i)}/\tau)}{\sum_{\alpha \in A(i)}\exp (z_i \cdot z_{\alpha}/\tau)} Lself=iILiself=iIlogαA(i)exp(zizα/τ)exp(zizj(i)/τ)
其中 I I I表示当前的一个batch,算法实现的时候,首先是从定义好的大小为batchsize的样本数目中数据增强出两个batchsize的样本来(multiviewed batch),这个batchsize就是公式中的 I I I,对于一个batch中的每个样本,计算 L i s e l f \mathcal{L}^{self}_{i} Liself,其中 z i z_i zi是当前的样本(也称锚样本), z j ( i ) z_{j(i)} zj(i)是与 z i z_i zi同源的样本(由同一个样本数据增强得到), A ( i ) A(i) A(i)包含整个batchsize中除了当前样本之外的其他样本, τ \tau τ是温度系数,实际在训练的过程中,一个batch中的每个样本都会做一次锚样本。
这样说感觉上不是很直观,通过代码会加深对公式的理解。

原始无监督对比学习的代码及注释

重点在数据集的加载方式,loss的设计上

数据集准备

class ContrastiveLearningDataset:
    def __init__(self, root_folder):
        self.root_folder = root_folder

    @staticmethod
    def get_simclr_pipeline_transform(size, s=1):
        """Return a set of data augmentation transformations as described in the SimCLR paper.
            定义数据增强的方式,选择训练的数据集
        """
        color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
        data_transforms = transforms.Compose([transforms.RandomResizedCrop(size=size),
                                              transforms.RandomHorizontalFlip(),
                                              transforms.RandomApply([color_jitter], p=0.8),
                                              transforms.RandomGrayscale(p=0.2),
                                              GaussianBlur(kernel_size=int(0.1 * size)),
                                              transforms.ToTensor()])
        return data_transforms

    def get_dataset(self, name, n_views):
        valid_datasets = {'cifar10': lambda: datasets.CIFAR10(self.root_folder, train=True,
                                                              transform=ContrastiveLearningViewGenerator(
                                                                  self.get_simclr_pipeline_transform(32),
                                                                  n_views),
                                                              download=True),

                          'stl10': lambda: datasets.STL10(self.root_folder, split='unlabeled',
                                                          transform=ContrastiveLearningViewGenerator(
                                                              self.get_simclr_pipeline_transform(96),
                                                              n_views),
                                                          download=True)}

        try:
            dataset_fn = valid_datasets[name]
        except KeyError:
            raise InvalidDatasetSelection()
        else:
            return dataset_fn()



class ContrastiveLearningViewGenerator(object):
    """Take two random crops of one image as the query and key.
        默认使用两个view做数据增强,即如果有一个batchsize为4 的样本[a1, b1, c1, d1]
        经过viewGenerator之后的形式为: [ a1, a2
                                      b1, b2
                                      c1, c2
                                      d1, d2]
        其中每一行表示同一个源样本产生的两个view样本。
    """

    def __init__(self, base_transform, n_views=2):
        self.base_transform = base_transform
        self.n_views = n_views

    def __call__(self, x):
        return [self.base_transform(x) for i in range(self.n_views)]

特征提取的模型

class ResNetSimCLR(nn.Module):
    '''
        选择使用resnet-18还是resnet-50作为backbone,对应论文里面的encoder ==》 Enc(.)以及投影网络Projection Network ==》 Proj(i)
        其中encoder使用resnet的非全连接层部分,投影网络使用多层感知机
    '''
    def __init__(self, base_model, out_dim):
        super(ResNetSimCLR, self).__init__()
        self.resnet_dict = {"resnet18": models.resnet18(pretrained=False, num_classes=out_dim),
                            "resnet50": models.resnet50(pretrained=False, num_classes=out_dim)}

        self.backbone = self._get_basemodel(base_model)
        dim_mlp = self.backbone.fc.in_features

        # add mlp projection head
        self.backbone.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.backbone.fc)

    def _get_basemodel(self, model_name):
        try:
            model = self.resnet_dict[model_name]
        except KeyError:
            raise InvalidBackboneError(
                "Invalid backbone architecture. Check the config file and pass one of: resnet18 or resnet50")
        else:
            return model

    def forward(self, x):
        return self.backbone(x)

损失函数设计

    def info_nce_loss(self, features):

        # 这里的labels用来做mask,方便后面与矩阵做逐元素相乘的时候筛选正样本和负样本,以batchsize=3为例,
        # 经过数据增强后一个batch的大小实际上为6,输入的features = [6, 128]
        # 最后生成的labels:tensor([[1., 0., 0., 1., 0., 0.],
        #                        [0., 1., 0., 0., 1., 0.],
        #                        [0., 0., 1., 0., 0., 1.],
        #                        [1., 0., 0., 1., 0., 0.],
        #                        [0., 1., 0., 0., 1., 0.],
        #                        [0., 0., 1., 0., 0., 1.]])
        labels = torch.cat([torch.arange(self.args.batch_size) for i in range(self.args.n_views)], dim=0)
        labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()
        labels = labels.to(self.args.device)

        features = F.normalize(features, dim=1)

        # 计算相似度矩阵,即如果一个batch的输入样本为[ a1, a2
        #                                       b1, b2
        #                                       c1, c2]
        # 经过网络特征提取之后为:[a1 b1 c1 a2 b2 c2]
        # 相应地相似度矩阵为:[a1a1 a1b1 a1c1 a1a2 a1b2 a1c2
        #                  b1a1 b1b1 b1c1 b1a2 b1b2 b1c2
        #                  c1a1 c1b1 c1c1 c1a2 c1b2 c1c2
        #                  a2a1 a2b1 a2c1 a2a2 a2b2 a2c2
        #                  b2a1 b2b1 b2c1 b2a2 b2b2 b2c2
        #                  c2a1 c2b1 c2c1 c2a2 c2b2 c2c2]
        similarity_matrix = torch.matmul(features, features.T)
        # assert similarity_matrix.shape == (
        #     self.args.n_views * self.args.batch_size, self.args.n_views * self.args.batch_size)
        # assert similarity_matrix.shape == labels.shape

        # discard the main diagonal from both: labels and similarities matrix
        mask = torch.eye(labels.shape[0], dtype=torch.bool).to(self.args.device)

        labels = labels[~mask].view(labels.shape[0], -1)
        # 此时的labels为:
        # tensor([[0., 0., 1., 0., 0.],
        #         [0., 0., 0., 1., 0.],
        #         [0., 0., 0., 0., 1.],
        #         [1., 0., 0., 0., 0.],
        #         [0., 1., 0., 0., 0.],
        #         [0., 0., 1., 0., 0.]])
        # 相比原来的labels删除了对角线上锚样本与自己做乘积的情况,
        # 对应在原相似度矩阵的位置上只保留label为1的数,相当于只保留了正样本与锚样本的乘积,即a1a2,b1b2,c1c2...
        # mask为:tensor([[ True, False, False, False, False, False],
        #               [False,  True, False, False, False, False],
        #               [False, False,  True, False, False, False],
        #               [False, False, False,  True, False, False],
        #               [False, False, False, False,  True, False],
        #               [False, False, False, False, False,  True]])
        # 相应地,在相似度矩阵上面排除锚样本与自己相乘的情况
        similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1)
        # assert similarity_matrix.shape == labels.shape

        # select and combine multiple positives
        positives = similarity_matrix[labels.bool()].view(labels.shape[0], -1)
        # positives 保留正样本与锚样本的乘积:[a1a2
        #                                 b1b2
        #                                 c1c2
        #                                 a2a1
        #                                 b2b1
        #                                 c2c1]
        # negatives 保留锚样本与负样本的乘积:[a1b1 a1c1 a1b2 a1c2
        #                                b1a1 b1c1 b1a2 b1c2
        #                                c1a1 c1b1 c1a2 c1b2
        #                                a2b1 a2c1 a2b2 a2c2
        #                                b2a1 b2c1 b2a2 b2c2
        #                                c2a1 c2b1 c2a2 c2b2]
        # select only the negatives the negatives
        negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1)
        logits = torch.cat([positives, negatives], dim=1)
        # 将positives堆在negatives的前面,形如[a1a2 a1b1 a1c1 a1b2 a1c2
        #         #                        b1b2 b1a1 b1c1 b1a2 b1c2
        #         #                        c1c2 c1a1 c1b1 c1a2 c1b2
        #         #                        a2a1 a2b1 a2c1 a2b2 a2c2
        #         #                        b2b1 b2a1 b2c1 b2a2 b2c2
        #         #                        c2c1 c2a1 c2b1 c2a2 c2b2]
        # 最左边一列为infoloss的分子,右边为分子
        labels = torch.zeros(logits.shape[0], dtype=torch.long).to(self.args.device)
        # labels = [0, 0, 0, 0, 0, 0],这里相当于交叉熵损失函数里面样本的真实标签为0
        # 因为对比损失函数跟交叉熵损失的计算形式是一样的,所以如果类别全部为0,表示的对于logits的每一行,都使用索引为0(也就是第一个)的元素作为分子
        logits = logits / self.args.temperature
        return logits, labels

训练过程

# 损失函数与交叉熵的形式一样
self.criterion = torch.nn.CrossEntropyLoss().to(self.args.device)


    def train(self, train_loader):
        # pytorch的GradScaler和autocast使用混合精度可以节约内存空间,运行较大的batchsize
        scaler = GradScaler(enabled=self.args.fp16_precision)

        # save config file
        save_config_file(self.writer.log_dir, self.args)

        n_iter = 0
        logging.info("Start SimCLR training for {self.args.epochs} epochs.")
        logging.info("Training with gpu: {self.args.disable_cuda}.")

        for epoch_counter in range(self.args.epochs):
            for images, _ in tqdm(train_loader):
                images = torch.cat(images, dim=0)
                images = images.to(self.args.device)

                with autocast(enabled=self.args.fp16_precision):
                    # 对输入的正负样本图像提取的特征
                    features = self.model(images)
                    print(features.shape)
                    logits, labels = self.info_nce_loss(features)
                    loss = self.criterion(logits, labels)

                self.optimizer.zero_grad()

                scaler.scale(loss).backward()

                scaler.step(self.optimizer)
                scaler.update()

你可能感兴趣的:(机器学习,深度学习,算法)