CLIP损失函数的理解

  参考资料:

  [一个写的相当好的教程]

  [CLIP huggingface源码:CLIPModel]

  [CLIP huggingface训练例程]

  这篇文章首先展示CLIP损失函数的两种底层实现代码,然后聊一聊自己的理解。

  说实话念硕士的时候没有接触过CLIP这个东西,来实习之后发现这个多模态的模型使用非常广泛,设计理念也是看后惊为天人。加上最近有探究任务研究CLIP,BLIP这些,遂决心把这个模型弄懂。参考资料1已经把CLIP的设计思想,原理,甚至是底层实现给讲清楚了,但是当我读到训练的损失函数那一段的时候还是产生了很大的疑问:作者说有两种方式来计算损失函数,一种较为简单,一种较为复杂。较为复杂的损失函数实现如下:

    def forward(self, batch):
        # Getting Image and Text Features
        image_features = self.image_encoder(batch["image"])
        text_features = self.text_encoder(
            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
        )
        # Getting Image and Text Embeddings (with same dimension)
        image_embeddings = self.image_projection(image_features)
        text_embeddings = self.text_projection(text_features)

        # Calculating the Loss
        logits = (text_embeddings @ image_embeddings.T) / self.temperature
        images_similarity = image_embeddings @ image_embeddings.T
        texts_similarity = text_embeddings @ text_embeddings.T
        targets = F.softmax(
            (images_similarity + texts_similarity) / 2 * self.temperature, dim=-1
        )
        texts_loss = cross_entropy(logits, targets, reduction='none')
        images_loss = cross_entropy(logits.T, targets.T, reduction='none')
        loss =  (images_loss + texts_loss) / 2.0 # shape: (batch_size)
        return loss.mean()

  其中Cross_entropy也是作者自己实现的,看上去就是logsoftmax加上NLLloss:

def cross_entropy(preds, targets, reduction='none'):
    log_softmax = nn.LogSoftmax(dim=-1)
    loss = (-targets * log_softmax(preds)).sum(1)
    if reduction == "none":
        return loss
    elif reduction == "mean":
        return loss.mean()

  较为简单的损失函数的实现则是这样:nn.CrossEntropyLoss()(logits, torch.arange(batch_size))

  作者在下面进行了分析,我看完分析之后觉得... ... 作者的语气好像是在说这种较为简单的损失函数是有误的,在数据集中有同一张图片的多个相似caption的时候会明显犯错。那么,较为复杂的损失函数就是正确的了。以上是Tutorial里作者的实现,较为权威的另一种实现是huggingface团队Transformer库里的源码。由于CLIP模型的高度可定制性,huggingface团队实现了一个基类,也就是CLIPModel部分。并在需要训练的时候把loss设置为forward函数的第一个返回值,我们来看一下他们的实现:

image_embeds = vision_outputs[1]
image_embeds = self.visual_projection(image_embeds)

text_embeds = text_outputs[1]
text_embeds = self.text_projection(text_embeds)

# normalized features
image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)

# cosine similarity as logits
logit_scale = self.logit_scale.exp()
logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
logits_per_image = logits_per_text.t()

loss = None
if return_loss:
    loss = clip_loss(logits_per_text)

  其中,clip_loss的实现如下:

# contrastive loss function, adapted from
# https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html
def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
    return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))

def clip_loss(similarity: torch.Tensor) -> torch.Tensor:
    caption_loss = contrastive_loss(similarity)
    image_loss = contrastive_loss(similarity.t())
    return (caption_loss + image_loss) / 2.0

  一开始的归一化比较好理解,logit_scale是一个超参数也好理解。最难理解的就是logits_per_text和logits_per_image这两个互为转置的矩阵。写这篇文章的时候我只能说自己弄懂了7分,原论文中有这么一段话:While standard image models jointly train an image feature extractor and a linear classifier to predict some label, CLIP jointly trains an image encoder and a text encoder to predict the correct pairings of a batch of (image, text) training examples. 即CLIP是学习(image, text)图文对之间的正确匹配的。这个正确匹配有两个对称的方面:1)对于每一个caption,和它吻合的图片得到label 1,和它不吻合的图片得到label 0。(这个对应于caption_loss)2)对于每一个image,和它吻合的caption得到label 1,和它不吻合的caption得到label 0。(这个对应于image_loss)而将两个loss相加除以2,得到的损失函数就同时考虑了两个方面了。如果一个模型在这两个方面都做得好,那么大概率是能够成功学习到correct pairings of a batch of (image, text) 的。

你可能感兴趣的:(深度学习,机器学习,python,人工智能,计算机视觉)