【计算机视觉】使用 notebook 展示如何下载和运行 CLIP models,计算图片和文本相似度,实现 zero-shot 图片分类

文章目录

  • 一、CLIP 模型
  • 二、准备
  • 三、加载模型
  • 四、查看图片处理器
  • 五、文本分词
  • 六、输入图片和文本,并可视化
  • 七、将图片和文字 encode 生成特征
  • 八、计算 cosine 相似度
  • 九、零样本进行图片分类
  • 十、编写函数进行图片分类
  • 十一、测试自己的函数
  • 十二、编写函数对多图片进行分类

项目地址:

https://github.com/biluko/Paper_Codes_for_fun/tree/master/CLIP

【计算机视觉】使用 notebook 展示如何下载和运行 CLIP models,计算图片和文本相似度,实现 zero-shot 图片分类_第1张图片

一、CLIP 模型

CLIP(Contrastive Language-Image Pretraining)是由OpenAI开发的一个深度学习模型,用于处理图像和文本之间的联合表示。它的目标是将图像和文本嵌入到一个共享的向量空间中,使得相似的图像和文本在这个空间中距离较近,而不相似的图像和文本距离较远。

CLIP模型的特点在于它可以通过对图像和文本之间进行对比学习,来学习到一个通用的特征表示。在训练过程中,CLIP通过最大化相似图像和文本的相似性,并最小化不相似图像和文本的相似性来调整模型参数。这种对比学习的方法使得CLIP能够在多个任务上进行迁移学习,如图像分类、文本分类、图像生成等。

CLIP模型的应用非常广泛。通过将图像和文本映射到共享的向量空间,CLIP可以实现图像和文本之间的多模态检索和匹配。例如,通过将一张图片和一个描述该图片内容的文本查询进行编码,可以计算它们在向量空间中的距离,并找到与之相似的图片或文本。这为图像搜索、商品推荐、智能问答等应用提供了新的可能性。

CLIP模型的优势在于它不需要大量标注的训练数据,而是通过对比学习来学习通用的特征表示。这使得CLIP在跨领域和跨语言的应用上具有良好的泛化能力。此外,CLIP还能够理解和生成自然语言描述的图像,以及生成图像描述的文本,具备了一定的语义理解和生成能力。

总之,CLIP是一个强大的深度学习模型,能够将图像和文本嵌入到共享的向量空间中,并实现多模态的检索和匹配。它在图像和文本处理、多模态应用以及迁移学习等方面有着广泛的应用前景。

二、准备

包括下载 CLIP 依赖和将设置改为 GPU:

! pip install ftfy regex tqdm
! pip install git+https://github.com/openai/CLIP.git

【计算机视觉】使用 notebook 展示如何下载和运行 CLIP models,计算图片和文本相似度,实现 zero-shot 图片分类_第2张图片

import numpy as np
import torch
from pkg_resources import packaging

print("Torch version:", torch.__version__)

在这里插入图片描述

三、加载模型

展示可选择的不同图片特征提取器:

import clip
clip.available_models()

【计算机视觉】使用 notebook 展示如何下载和运行 CLIP models,计算图片和文本相似度,实现 zero-shot 图片分类_第3张图片
加载模型和图片处理器:

model, preprocess = clip.load("ViT-B/32")
model.cuda().eval()
input_resolution = model.visual.input_resolution
context_length = model.context_length
vocab_size = model.vocab_size

print("模型参数:", f"{np.sum([int(np.prod(p.shape)) for p in model.parameters()]):,}")
print("输入图片尺寸:", input_resolution)
print("文本长度:", context_length)
print("词表大小:", vocab_size)

【计算机视觉】使用 notebook 展示如何下载和运行 CLIP models,计算图片和文本相似度,实现 zero-shot 图片分类_第4张图片

四、查看图片处理器

这里调整图片大小 224 × 224 224 \times 224 224×224,中心裁剪,然后使用均值和标准差进行归一化,最后输出tensor向量:

preprocess

【计算机视觉】使用 notebook 展示如何下载和运行 CLIP models,计算图片和文本相似度,实现 zero-shot 图片分类_第5张图片

五、文本分词

clip.tokenize("Hello World!")

【计算机视觉】使用 notebook 展示如何下载和运行 CLIP models,计算图片和文本相似度,实现 zero-shot 图片分类_第6张图片

六、输入图片和文本,并可视化

import os 
import skimage
import IPython.display
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
from collections import OrderedDict
import torch

%matplotlib inline
%config InlineBackend.figure_format='retina'

# images in skimage to use and their textual descriptions
descriptions = {
    "page": "a page of text about segmentation",
    "chelsea": "a facial photo of a tabby cat",
    "astronaut": "a portrait of an astronaut with the American flag",
    "rocket": "a rocket standing on a launchpad",
    "motorcycle_right": "a red motorcycle standing in a garage",
    "camera": "a person looking at a camera on a tripod",
    "horse": "a black-and-white silhouette of a horse", 
    "coffee": "a cup of coffee on a saucer"
}
original_images=[]
images=[]
texts=[]
plt.figure(figsize=(16,5))

for filename in [filename for filename in os.listdir(skimage.data_dir) if filename.endswith(".png") or filename.endswith(".jpg")]:
    name = os.path.splitext(filename)[0]
    if name not in descriptions:
        continue

    image = Image.open(os.path.join(skimage.data_dir, filename)).convert("RGB")
  
    plt.subplot(2, 4, len(images) + 1)
    plt.imshow(image)
    plt.title(f"{filename}\n{descriptions[name]}")
    plt.xticks([])
    plt.yticks([])

    original_images.append(image)
    images.append(preprocess(image))
    texts.append(descriptions[name])

plt.tight_layout()

七、将图片和文字 encode 生成特征

image_input = torch.tensor(np.stack(images)).cuda()

print(image_input.shape)
text_tokens = clip.tokenize(['This is '+ desc for desc in texts]).cuda()

with torch.no_grad():
  image_features = model.encode_image(image_input).float()
  text_features = model.encode_text(text_tokens).float()
print(image_features.shape)
print(text_features.shape)

【计算机视觉】使用 notebook 展示如何下载和运行 CLIP models,计算图片和文本相似度,实现 zero-shot 图片分类_第7张图片

八、计算 cosine 相似度

image_features /= image_features.norm(dim = -1,keepdim = True)
text_features /= text_features.norm(dim = -1,keepdim = True)

similarity = text_features.cpu().numpy() @ image_features.cpu().numpy().T
count = len(descriptions)
plt.figure(figsize = (20, 14))
plt.imshow(similarity, vmin = 0.1, vmax = 0.3)
# plt.colorbar()
plt.yticks(range(count), texts, fontsize = 18)
plt.xticks([])
for i, image in enumerate(original_images):
    plt.imshow(image, extent = (i - 0.5, i + 0.5, -1.6, -0.6), origin = "lower")
for x in range(similarity.shape[1]):
    for y in range(similarity.shape[0]):
        plt.text(x, y, f"{similarity[y, x]:.2f}", ha = "center", va = "center", size = 12)

for side in ["left", "top", "right", "bottom"]:
  plt.gca().spines[side].set_visible(False)

plt.xlim([-0.5, count - 0.5])
plt.ylim([count + 0.5, -2])

plt.title("Cosine similarity between text and image features", size = 20)

【计算机视觉】使用 notebook 展示如何下载和运行 CLIP models,计算图片和文本相似度,实现 zero-shot 图片分类_第8张图片

九、零样本进行图片分类

数据集CIFAR100,就是使用相似度计算得分,然后softmax一下:

from torchvision.datasets import CIFAR100

cifar100 = CIFAR100(os.path.expanduser("~/.cache"), transform = preprocess, download = True)

在这里插入图片描述
加上prompt 提示模板进行分类:

text_descriptions = [f"This is a photo of a {label}" for label in cifar100.classes]
text_tokens = clip.tokenize(text_descriptions).cuda()

计算相似度得分:

with torch.no_grad():
    text_features = model.encode_text(text_tokens).float()
    text_features /= text_features.norm(dim = -1, keepdim = True)

text_probs = (100.0 * image_features @ text_features.T).softmax(dim = -1)
top_probs, top_labels = text_probs.cpu().topk(5, dim = -1)

可视化结果:

plt.figure(figsize=(16, 16))

for i, image in enumerate(original_images):
    plt.subplot(4, 4, 2 * i + 1)
    plt.imshow(image)
    plt.axis("off")

    plt.subplot(4, 4, 2 * i + 2)
    y = np.arange(top_probs.shape[-1])
    plt.grid()
    plt.barh(y, top_probs[i])
    plt.gca().invert_yaxis()
    plt.gca().set_axisbelow(True)
    plt.yticks(y, [cifar100.classes[index] for index in top_labels[i].numpy()])
    plt.xlabel("probability")

plt.subplots_adjust(wspace = 0.5)
plt.show()

十、编写函数进行图片分类

输入图片和供选择标签进行分类:

def show_result(image, probs, labels, label_name):
  plt.figure()
  plt.subplot(1, 2, 1)
  plt.imshow(image)
  plt.axis("off")

  plt.subplot(1, 2, 2)
  y = np.arange(probs.shape[-1])
  plt.grid()
  plt.barh(y, probs[0])
  plt.gca().invert_yaxis()
  plt.gca().set_axisbelow(True)
  plt.yticks(y, [label_name[index] for index in labels[0].numpy()])
  plt.xlabel("probability")

  plt.subplots_adjust(wspace = 0.5)
  plt.show()


def clip_classifier(image_path, choice_label, top_k = 5):
  # top_k小于choice_label数
  if top_k > len(choice_label):
    raise Exception('top_k大于候选标签数')

  # 读取图片
  image = Image.open(image_path).convert("RGB")
  # 输入特征
  text_descriptions = [f"This is a photo of a {label}" for label in choice_label]
  text_tokens = clip.tokenize(text_descriptions).cuda()
  image_input = preprocess(image)
  image_input = image_input.clone().detach().cuda()

  with torch.no_grad():
    image_features = model.encode_image(image_input.unsqueeze(0)).float()
    text_features = model.encode_text(text_tokens).float()

    image_features /= image_features.norm(dim = -1, keepdim = True)
    text_features /= text_features.norm(dim = -1, keepdim = True)

  #相似度得分
  text_probs = (100.0 * image_features @ text_features.T).softmax(dim = -1)
  top_probs, top_labels = text_probs.cpu().topk(5, dim = -1)
  show_result(image, top_probs, top_labels, choice_label)

十一、测试自己的函数

clip_classifier('R.jpg',['Luffy','pig','boy','girl','one piece','bleach','black','man','cartoon','red','detector'])

【计算机视觉】使用 notebook 展示如何下载和运行 CLIP models,计算图片和文本相似度,实现 zero-shot 图片分类_第9张图片

clip_classifier('Holmes.jpg',['Holmes','pig','boy','girl','one piece','bleach','black','man','cartoon','red','detector'])

【计算机视觉】使用 notebook 展示如何下载和运行 CLIP models,计算图片和文本相似度,实现 zero-shot 图片分类_第10张图片

十二、编写函数对多图片进行分类

def clip_classifier_m(image_dir, choice_label, top_k = 5):
  # image_dir不为文件夹
  if not os.path.isdir(image_dir):
    raise Exception(image_dir + ' 应该为一个图片文件夹')

  # top_k小于choice_label数
  if top_k > len(choice_label):
    raise Exception('top_k大于候选标签数')


  #读取图片
  original_images = []
  images = []

  for filename in [filename for filename in os.listdir(image_dir) if filename.endswith(".png") or filename.endswith(".jpg")]:
    image = Image.open(os.path.join(image_dir, filename)).convert("RGB")

    original_images.append(image)
    images.append(preprocess(image))

  # 输入特征
  text_descriptions = [f"This is a photo of a {label}" for label in choice_label]
  text_tokens = clip.tokenize(text_descriptions).cuda()
  image_input = torch.tensor(np.stack(images)).cuda()
  with torch.no_grad():
    image_features = model.encode_image(image_input).float()
    text_features = model.encode_text(text_tokens).float()

    image_features /= image_features.norm(dim = -1, keepdim = True)
    text_features /= text_features.norm(dim = -1, keepdim = True)

  # 相似度得分
  text_probs = (100.0 * image_features @ text_features.T).softmax(dim = -1)
  top_probs, top_labels = text_probs.cpu().topk(5, dim = -1)
  show_result_m(original_images, top_probs, top_labels, choice_label)


def show_result_m(images, probs, labels, label_name):
  length = len(images)
  num_row = length // 2

  plt.figure(figsize = (16, 16))

  for i, image in enumerate(images):
    plt.subplot(num_row, 4, 2 * i + 1)
    plt.imshow(image)
    plt.axis("off")

    plt.subplot(num_row, 4, 2 * i + 2)
    y = np.arange(probs.shape[-1])
    plt.grid()
    plt.barh(y, probs[i])
    plt.gca().invert_yaxis()
    plt.gca().set_axisbelow(True)
    plt.yticks(y, [label_name[index] for index in labels[i].numpy()])
    plt.xlabel("probability")

  plt.subplots_adjust(wspace = 1)
  plt.show()
clip_classifier_m('img',['Luffy','pig','boy','girl','one piece','bleach','black','man','cartoon','red','Holmes'])

【计算机视觉】使用 notebook 展示如何下载和运行 CLIP models,计算图片和文本相似度,实现 zero-shot 图片分类_第11张图片

你可能感兴趣的:(计算机视觉,计算机视觉,分类,深度学习,CLIP,Github)