CLIP对比图文预训练 (Contrastive Language-Image Pretraining)论文阅读笔记

任务:

video captioning, 视频描述生成,简单来说就是给定一段视频(目前以几秒到几分钟的短视频为主),计算机输出描述这段视频的文字(目前以英文为主)。往往一个视频对应多个人工标注,这也是为训练时增添了一些鲁棒性,如:
CLIP对比图文预训练 (Contrastive Language-Image Pretraining)论文阅读笔记_第1张图片

网络模型:CLIP对比图文预训练 (Contrastive Language-Image Pretraining)论文阅读笔记_第2张图片

网络分成两部分:
1)文本特征提取:文本编码器可以是 transformer;
2) 图像特征提取:可以是resnet50等;
训练阶段:
训练数据是网络社交媒体上搜集的图像文本对。在训练阶段,对于一个batch 的数据,首先通过文本编码器和图像编码器,得到文本和图像的特征,接着将所有的文本(一个视频对应好多个标注句子)和图像特征分别计算内积,就能得到一个矩阵,然后从图像的角度看,行方向就是一个分类器,从文本角度看,列方向也是一个分类器。
而由于我们已经知道一个batch中的文本和图像的匹配关系,所以目标函数就是最大化同一对图像和文本特征的内积,也就是矩阵对角线上的元素,而最小化与不相关特征的内积。对图片嵌入特征和文本嵌入特征进行矩阵相乘。那么形成的打分矩阵上,对角线上都是配对的正样本对打分,而矩阵的其他元素,则是由同个batch内的图片和不配对的文本(相反亦然)组成的负样本。

def forward(self, image, text):
        image_features = self.encode_image(image) #编码image
        text_features = self.encode_text(text) #编码text

        # norm一下特征
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # 计算内积相似度logits
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logit_scale * text_features @ image_features.t()

        # shape = [global_batch_size, global_batch_size]
        return logits_per_image, logits_per_text

测试阶段:
CLIP对比图文预训练 (Contrastive Language-Image Pretraining)论文阅读笔记_第3张图片
在测试阶段,可以直接将训练好的CLIP用于其他数据集而不需要finetune。和训练阶段类似,首先将需要分类的图像经过编码器得到特征,然后对于目标任务数据集的每一个标签,或者你自己定义的标签,都构造一段对应的文本,如上图中的 dog 会改造成 “A photo of a dog”,以此类推。然后经过编码器得到文本和图像特征,接着将文本特征与图像特征做内积,内积最大对应的标签就是图像的分类结果。这就完成了目标任务上的 zero-shot 分类。

参考文献

https://blog.csdn.net/weixin_42772394/article/details/120688085?ops_request_misc=%257B%2522request%255Fid%2522%253A%2522164843634516780265465349%2522%252C%2522scm%2522%253A%252220140713.130102334…%2522%257D&request_id=164843634516780265465349&biz_id=0&utm_medium=distribute.pc_search_result.none-task-blog-2alltop_positive~default-1-120688085.142v5control,143v6register&utm_term=CLIP&spm=1018.2226.3001.4187

你可能感兴趣的:(算法,深度学习)