《Supervised Contrastive Learning》
《A Simple Framework for Contrastive Learning of Visual Representations》
《What Makes for Good Views for Contrastive Learning》
对比学习的思想起源于无监督学习,相比于监督学习算法,无监督学习由于没有标签的指导,训练过程学习样本的特征会更加困难。对比学习的核心思想就是通过数据增强构造原来样本的多样性,损失函数的设计用来拉进正样本与锚样本的距离,增大与负样本的距离,在这一过程中,网络更容易学到由源样本经过数据增强之后的多个样本所具有的共同特征,而这一特征对于源样本来说更可能是本质性的。
论文提出了一种更简洁的对比学习算法,主要有三个贡献:
这一工作也是后面诸多对比学习工作的基础。
1. 为什么不同形式的数据增强的组合有助于学到好的特征?
对比学习的目的是学到对于一个样本最核心的特征,如果使用单一的数据增强,比如只使用随机裁剪(random cropping),那么网络在训练过程就会认为颜色信息可能也是有用的,因为没有label来指导它学到下游任务的目标,网络无法提取对于下游更核心的特征。而采用多个数据增强的组合可以让网络认识到什么信息是不相关的,比如一个颜色失真的样本和一个高斯噪声的样本,这两个样本来源于同一个样本,网络在优化过程中需要认为他们两个着某些特征上是相同的,从而认识到颜色和噪声对于要提取的信息都是不重要的。
2. 为什么在encoder后面添加一个多层感知机可以提高学习能力?
z = g ( h ) z=g(h) z=g(h)的训练目的是增加对于数据变换的不变性,根据神经网络传统的学习方式,由于投影层处于较高的网络层次,网络学到的特征就更倾向于任务相关(high-level),低层的网络学到的更倾向于细节特征,如果没有投影层来学习高级特征,全部由encoder完成的话,encoder学到的特征在不同下游任务上的泛化能力会下降。
3. 为什么batchsize越大越容易收敛?
根据损失函数可以知道,当batchsize比较大的时候,意味着分母上的负样本数量也比较多,损失函数的目的是从一堆样本中找出锚样本,或者说,找出最能够区分锚样本与负样本的表征,当负样本数目多的时候,网络更容易排除什么信息对于该样本是不相关的,所以能够加快训练。
L s e l f = ∑ i ∈ I L i s e l f = − ∑ i ∈ I log exp ( z i ⋅ z j ( i ) / τ ) ∑ α ∈ A ( i ) exp ( z i ⋅ z α / τ ) \mathcal{L}^{self}=\sum_{i\in I}\mathcal{L}^{self}_i=-\sum_{i\in I}\log \frac{\exp (z_i \cdot z_{j(i)}/\tau)}{\sum_{\alpha \in A(i)}\exp (z_i \cdot z_{\alpha}/\tau)} Lself=i∈I∑Liself=−i∈I∑log∑α∈A(i)exp(zi⋅zα/τ)exp(zi⋅zj(i)/τ)
其中 I I I表示当前的一个batch,算法实现的时候,首先是从定义好的大小为batchsize的样本数目中数据增强出两个batchsize的样本来(multiviewed batch),这个batchsize就是公式中的 I I I,对于一个batch中的每个样本,计算 L i s e l f \mathcal{L}^{self}_{i} Liself,其中 z i z_i zi是当前的样本(也称锚样本), z j ( i ) z_{j(i)} zj(i)是与 z i z_i zi同源的样本(由同一个样本数据增强得到), A ( i ) A(i) A(i)包含整个batchsize中除了当前样本之外的其他样本, τ \tau τ是温度系数,实际在训练的过程中,一个batch中的每个样本都会做一次锚样本。
这样说感觉上不是很直观,通过代码会加深对公式的理解。
重点在数据集的加载方式,loss的设计上
数据集准备
class ContrastiveLearningDataset:
def __init__(self, root_folder):
self.root_folder = root_folder
@staticmethod
def get_simclr_pipeline_transform(size, s=1):
"""Return a set of data augmentation transformations as described in the SimCLR paper.
定义数据增强的方式,选择训练的数据集
"""
color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
data_transforms = transforms.Compose([transforms.RandomResizedCrop(size=size),
transforms.RandomHorizontalFlip(),
transforms.RandomApply([color_jitter], p=0.8),
transforms.RandomGrayscale(p=0.2),
GaussianBlur(kernel_size=int(0.1 * size)),
transforms.ToTensor()])
return data_transforms
def get_dataset(self, name, n_views):
valid_datasets = {'cifar10': lambda: datasets.CIFAR10(self.root_folder, train=True,
transform=ContrastiveLearningViewGenerator(
self.get_simclr_pipeline_transform(32),
n_views),
download=True),
'stl10': lambda: datasets.STL10(self.root_folder, split='unlabeled',
transform=ContrastiveLearningViewGenerator(
self.get_simclr_pipeline_transform(96),
n_views),
download=True)}
try:
dataset_fn = valid_datasets[name]
except KeyError:
raise InvalidDatasetSelection()
else:
return dataset_fn()
class ContrastiveLearningViewGenerator(object):
"""Take two random crops of one image as the query and key.
默认使用两个view做数据增强,即如果有一个batchsize为4 的样本[a1, b1, c1, d1]
经过viewGenerator之后的形式为: [ a1, a2
b1, b2
c1, c2
d1, d2]
其中每一行表示同一个源样本产生的两个view样本。
"""
def __init__(self, base_transform, n_views=2):
self.base_transform = base_transform
self.n_views = n_views
def __call__(self, x):
return [self.base_transform(x) for i in range(self.n_views)]
特征提取的模型
class ResNetSimCLR(nn.Module):
'''
选择使用resnet-18还是resnet-50作为backbone,对应论文里面的encoder ==》 Enc(.)以及投影网络Projection Network ==》 Proj(i)
其中encoder使用resnet的非全连接层部分,投影网络使用多层感知机
'''
def __init__(self, base_model, out_dim):
super(ResNetSimCLR, self).__init__()
self.resnet_dict = {"resnet18": models.resnet18(pretrained=False, num_classes=out_dim),
"resnet50": models.resnet50(pretrained=False, num_classes=out_dim)}
self.backbone = self._get_basemodel(base_model)
dim_mlp = self.backbone.fc.in_features
# add mlp projection head
self.backbone.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.backbone.fc)
def _get_basemodel(self, model_name):
try:
model = self.resnet_dict[model_name]
except KeyError:
raise InvalidBackboneError(
"Invalid backbone architecture. Check the config file and pass one of: resnet18 or resnet50")
else:
return model
def forward(self, x):
return self.backbone(x)
损失函数设计
def info_nce_loss(self, features):
# 这里的labels用来做mask,方便后面与矩阵做逐元素相乘的时候筛选正样本和负样本,以batchsize=3为例,
# 经过数据增强后一个batch的大小实际上为6,输入的features = [6, 128]
# 最后生成的labels:tensor([[1., 0., 0., 1., 0., 0.],
# [0., 1., 0., 0., 1., 0.],
# [0., 0., 1., 0., 0., 1.],
# [1., 0., 0., 1., 0., 0.],
# [0., 1., 0., 0., 1., 0.],
# [0., 0., 1., 0., 0., 1.]])
labels = torch.cat([torch.arange(self.args.batch_size) for i in range(self.args.n_views)], dim=0)
labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()
labels = labels.to(self.args.device)
features = F.normalize(features, dim=1)
# 计算相似度矩阵,即如果一个batch的输入样本为[ a1, a2
# b1, b2
# c1, c2]
# 经过网络特征提取之后为:[a1 b1 c1 a2 b2 c2]
# 相应地相似度矩阵为:[a1a1 a1b1 a1c1 a1a2 a1b2 a1c2
# b1a1 b1b1 b1c1 b1a2 b1b2 b1c2
# c1a1 c1b1 c1c1 c1a2 c1b2 c1c2
# a2a1 a2b1 a2c1 a2a2 a2b2 a2c2
# b2a1 b2b1 b2c1 b2a2 b2b2 b2c2
# c2a1 c2b1 c2c1 c2a2 c2b2 c2c2]
similarity_matrix = torch.matmul(features, features.T)
# assert similarity_matrix.shape == (
# self.args.n_views * self.args.batch_size, self.args.n_views * self.args.batch_size)
# assert similarity_matrix.shape == labels.shape
# discard the main diagonal from both: labels and similarities matrix
mask = torch.eye(labels.shape[0], dtype=torch.bool).to(self.args.device)
labels = labels[~mask].view(labels.shape[0], -1)
# 此时的labels为:
# tensor([[0., 0., 1., 0., 0.],
# [0., 0., 0., 1., 0.],
# [0., 0., 0., 0., 1.],
# [1., 0., 0., 0., 0.],
# [0., 1., 0., 0., 0.],
# [0., 0., 1., 0., 0.]])
# 相比原来的labels删除了对角线上锚样本与自己做乘积的情况,
# 对应在原相似度矩阵的位置上只保留label为1的数,相当于只保留了正样本与锚样本的乘积,即a1a2,b1b2,c1c2...
# mask为:tensor([[ True, False, False, False, False, False],
# [False, True, False, False, False, False],
# [False, False, True, False, False, False],
# [False, False, False, True, False, False],
# [False, False, False, False, True, False],
# [False, False, False, False, False, True]])
# 相应地,在相似度矩阵上面排除锚样本与自己相乘的情况
similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1)
# assert similarity_matrix.shape == labels.shape
# select and combine multiple positives
positives = similarity_matrix[labels.bool()].view(labels.shape[0], -1)
# positives 保留正样本与锚样本的乘积:[a1a2
# b1b2
# c1c2
# a2a1
# b2b1
# c2c1]
# negatives 保留锚样本与负样本的乘积:[a1b1 a1c1 a1b2 a1c2
# b1a1 b1c1 b1a2 b1c2
# c1a1 c1b1 c1a2 c1b2
# a2b1 a2c1 a2b2 a2c2
# b2a1 b2c1 b2a2 b2c2
# c2a1 c2b1 c2a2 c2b2]
# select only the negatives the negatives
negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1)
logits = torch.cat([positives, negatives], dim=1)
# 将positives堆在negatives的前面,形如[a1a2 a1b1 a1c1 a1b2 a1c2
# # b1b2 b1a1 b1c1 b1a2 b1c2
# # c1c2 c1a1 c1b1 c1a2 c1b2
# # a2a1 a2b1 a2c1 a2b2 a2c2
# # b2b1 b2a1 b2c1 b2a2 b2c2
# # c2c1 c2a1 c2b1 c2a2 c2b2]
# 最左边一列为infoloss的分子,右边为分子
labels = torch.zeros(logits.shape[0], dtype=torch.long).to(self.args.device)
# labels = [0, 0, 0, 0, 0, 0],这里相当于交叉熵损失函数里面样本的真实标签为0
# 因为对比损失函数跟交叉熵损失的计算形式是一样的,所以如果类别全部为0,表示的对于logits的每一行,都使用索引为0(也就是第一个)的元素作为分子
logits = logits / self.args.temperature
return logits, labels
训练过程
# 损失函数与交叉熵的形式一样
self.criterion = torch.nn.CrossEntropyLoss().to(self.args.device)
def train(self, train_loader):
# pytorch的GradScaler和autocast使用混合精度可以节约内存空间,运行较大的batchsize
scaler = GradScaler(enabled=self.args.fp16_precision)
# save config file
save_config_file(self.writer.log_dir, self.args)
n_iter = 0
logging.info("Start SimCLR training for {self.args.epochs} epochs.")
logging.info("Training with gpu: {self.args.disable_cuda}.")
for epoch_counter in range(self.args.epochs):
for images, _ in tqdm(train_loader):
images = torch.cat(images, dim=0)
images = images.to(self.args.device)
with autocast(enabled=self.args.fp16_precision):
# 对输入的正负样本图像提取的特征
features = self.model(images)
print(features.shape)
logits, labels = self.info_nce_loss(features)
loss = self.criterion(logits, labels)
self.optimizer.zero_grad()
scaler.scale(loss).backward()
scaler.step(self.optimizer)
scaler.update()