【Paper Notes】CLIP: Learning Transferable Visual Models From Natural Language Supervision

论文链接

CLIP 的由来 Contrastive Language-Image Pre-training。


主线任务

这篇文章主要任务是解决多模态模型的预训练问题,其中更倾向于利用文本标注来预训练视觉模型。文章中比较亮眼的地方还提到了zero-shot的任务迁移。

从个人的角度总结,在文章中主要有两个贡献:

  1. 创建了一个非常大的数据集 WebImageText, 缩写是 WIT,包含400,000,000个(图像,文本)对。
  2. 提出了一种高效的训练方法,无需做整个文本的回归,将多模态训练转化为了匹配问题,计算文本和图像的匹配度。

模型架构

【Paper Notes】CLIP: Learning Transferable Visual Models From Natural Language Supervision_第1张图片

如上所示,模型包含两个部分,Text Encoder, Image Encoder。

Text Encoder 是一个Transformer模型,将文本转化为一个特征向量。图示中有多个文本(N个),所以会得到N个特征向量\bold{T}_i.

Image Encoder 是一个ViT或者CNN,其将图像编码成特征向量,其中对CNN的最后一层的pooling方式改进了以下,改成了Attention Pooling,即将 max_pooling 输出作为Query,将卷积得到的feature map 作为K,V。 图示中有N个图片,所以得到N个图像特征向量\bold{I}_i.

得到特征表示后,就可以计算匹配度(NxN的矩阵),利用匹配度训练整个网络。为了训练网络,论文实现了一种非常简洁的loss函数,如下所示:

# image_encoder - ResNet or Vision Transformer
# text_encoder - CBOW or Text Transformer
# I[n, h, w, c] - minibatch of aligned images
# T[n, l] - minibatch of aligned texts
# W_i[d_i, d_e] - learned proj of image to embed
# W_t[d_t, d_e] - learned proj of text to embed
# t - learned temperature parameter

# extract feature representations of each modality
I_f = image_encoder(I) #[n, d_i]
T_f = text_encoder(T) #[n, d_t]

# joint multimodal embedding [n, d_e]
I_e = l2_normalize(np.dot(I_f, W_i), axis=1)
T_e = l2_normalize(np.dot(T_f, W_t), axis=1)

# scaled pairwise cosine similarities [n, n]
logits = np.dot(I_e, T_e.T) * np.exp(t)

# symmetric loss function
labels = np.arange(n)
loss_i = cross_entropy_loss(logits, labels, axis=0)
loss_t = cross_entropy_loss(logits, labels, axis=1)
loss   = (loss_i + loss_t)/2

其中loss_i 和 loss_t 比较巧妙,将对比损失化解为两个维度上的分类损失/cross entropy。

这时需要注意一个细节问题,如何保证\bold{T}_i, \bold{I}_i的维度是一致的呢?文章中使用的方法很简单,直接做一层线性层,然后用L2对特征进行归一化来对齐维度和规模。


实验结果

实验结果中,测试了以下zero-shot的性能,其中zero-shot的性能比部分few-shot的性能要好,这也可能和zero-shot的实现方式有关。如上图所示,为了适应到分类任务上,可以将所属的类别变成文本插入自然语言描述中,如 一个{类别}图片,然后计算和图像最好的匹配度作为分类结果。

【Paper Notes】CLIP: Learning Transferable Visual Models From Natural Language Supervision_第2张图片

 Linear Probe 表示是在最后一层加上分类的线性层微调得到的结果。

有意思的Zero-shot在domain shift上鲁棒性也很好,但是这感觉可能和数据集比较大有关。

你可能感兴趣的:(深度学习,人工智能)