CLIP在多模态领域蛮火的,其实之前有了解过的,好像还出了Chinese-CLIP,没具体去看,原理也比较简单。今天下午简单学了一波,小小记录一下。
CLIP论文链接、代码链接: https://github.com/openai/CLIP
论文48页,大部分篇幅是实验分析、相关工作介绍、动机、引用等。我就不细讲了,想看的直接看论文。我这里只讲讲关键的地方,CLIP的原理等。
补充概念:
- zero-shot learning: 训练集中没有某个类别的样本,但是如果我们可以学到一个牛逼的映射,这个映射好到我们即使在训练的时候没看到这个类,但是我们在遇到的时候依然能通过这个映射得到这个新类的特征。即: 对于 训练集 中 没有出现过 的 类别,模型能自动创造出相应的映射: XX -> YY。【既要马儿跑,还 不让 马儿吃草】
- few-shot learning: 训练集中,每个类别 都有样本,但都只是 少量样本(只有一个或几个) 。此时,我们可以在一个更大的数据集上或者利用knowledge graph、domain-knowledge 等方法,学到一个一般化的映射,然后再到小数据集上进行更新升级映射。【既要马儿跑,还不让马儿 多 吃草。】
- 传统的learning: 海量数据 + 反复训练 炼丹模式。【家里一座大草原,马儿马儿你随便吃。】
⭐ 数据: 4亿个网络公开的图文对。
⭐ 预训练任务: 对比学习,预测 N × N 对图文数据,将图片分类任务转换成图文匹配任务。
⭐ 输入: 一个batch有 N个图像文本对。
⭐ 输出: N × N 个相似度值。
⭐ 模型结构: 用的是双流结构,即图像和文本各自用一个encoder来编码。text encoder使用Transformer,image encoder用了2种模型,ResNet和Vision Transformer(ViT); ① 5种ResNet:ResNet-50, ResNet-101, EfficientNet-style的ResNet,包括RN50x4, RN50x16, RN50x64;② 3种ViT:ViT-B/32, ViT-B/16, ViT-L/14。
⭐ 训练的原理: 计算图像和文本模态之间的cosine similarity,使得batch内N个匹配的图文对相似度最大,不匹配的图文对相似度最小。
⭐ 对称的cross-entropy loss: 对比学习里面常用的方法,对称的算loss(即图像、文本的loss),再求和取平均。看下面的torch代码具体解释。
⭐ 预测的原理: 相当于prompt learning,也就是prompt engineering and ensembling那一块的东西。即做提示模版咯。如果单单输入一个单词,那没上下文,没法解决一次多义啊,抽出的特征肯定不咋地,所以给的提示越多越好(即上下文越丰富),且预训练的时候一般都不会只有一个单词的文本。这样可以结合多个提示模版的结果得到的嵌入取平均 (在嵌入层做平均,不是在概率层),再来做相似度。举个例子吧,假如你下游的数据集是猫科动物分类相关,你现在输入一张猪的图片,prompt模版应该这样:" A photo of a {label}, a type of animal", “A photo of a {label}, a type of cat” …总之可以构造任意多模版来集成。不过还是需要候选一下label来插入prompt模版里筛选的哟~
⭐ 一些细节:
forward计算出图像、文本的logits:
def forward(self, image, text):
image_features = self.encode_image(image)
text_features = self.encode_text(text)
# normalized features
image_features = image_features / image_features.norm(dim=1, keepdim=True)
text_features = text_features / text_features.norm(dim=1, keepdim=True)
# cosine similarity as logits
# logit_scale是可学习的超参
logit_scale = self.logit_scale.exp()
logits_per_image = logit_scale * image_features @ text_features.t()
logits_per_text = logits_per_image.t()
# shape = [global_batch_size, global_batch_size]
return logits_per_image, logits_per_text
得到图像、文本的logits后,计算和label的loss:
【⭐ 不懂的话还可看这:CLIP算法的Loss详解 和 交叉熵CrossEntropy实现】
with torch.no_grad():
for i, batch in enumerate(dataloader):
images, texts = batch
images = images.to(device=device, non_blocking=True)
texts = texts.to(device=device, non_blocking=True)
with autocast():
image_features, text_features, logit_scale = model(images, texts)
# features are accumulated in CPU tensors, otherwise GPU memory exhausted quickly
# however, system RAM is easily exceeded and compute time becomes problematic
all_image_features.append(image_features.cpu())
all_text_features.append(text_features.cpu())
logit_scale = logit_scale.mean()
logits_per_image = logit_scale * image_features @ text_features.t()
logits_per_text = logits_per_image.t()
batch_size = images.shape[0]
labels = torch.arange(batch_size, device=device).long()
total_loss = (
F.cross_entropy(logits_per_image, labels) +
F.cross_entropy(logits_per_text, labels)
) / 2