【打卡】图像检索与重复图像识别3

【打卡】图像检索与重复图像识别3

文章目录

  • 【打卡】图像检索与重复图像识别3
        • 任务3:深度全局特征:

任务3:深度全局特征:

CNN/VIT模型特征提取:介绍CNN和VIT模型在图像特征提取中的应用,包括如何利用预训练模型提取图像的全局特征。

CLIP模型特征提取:讲解CLIP模型的原理和应用,包括如何将图像和文本的特征嵌入到同一个向量空间中,以及如何利用CLIP模型进行图像检索和分类。

深度全局特征的优缺点:讨论深度全局特征和传统算法的差异,包括特征表达能力、泛化能力、计算效率等方面。

步骤1:使用CNN模型预训练模型(如ResNet18)提取图片的CNN特征,计算query与dataset最相似的图片
步骤2:使用VIT模型预训练模型提取图片特征,计算query与dataset最相似的图片
步骤3:使用CLIP模型预训练模型提取图片特征,计算query与dataset最相似的图片
步骤4:分别将每种思路的计算结果提交到实践比赛地址:https://competition.coggle.club/
代码中,CLIP 使用 OpenAI 发布的 CLIP 预训练模型(图像编码器为 ViT-B/16),ViT 使用 HuggingFace 上的 ViT-Base 模型(patch16,在 ImageNet-21k 上预训练)。

# 使用CNN模型预训练模型(如ResNet18)提取图片的CNN特征,计算query与dataset最相似的图片
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import torch
import numpy as np
from torch.nn.functional import normalize
import glob
import torchvision
import pandas as pd
import argparse
import os
import clip
from transformers import ViTImageProcessor, ViTModel
from PIL import Image

# Extract an L2-normalized global feature from a preprocessed image batch.
def get_feat(args, img: torch.Tensor, model):
    """Run `img` through the CNN backbone and return an L2-normalized feature.

    param:
        args: parsed CLI arguments; `args.model` selects the backbone family
        img: preprocessed image batch, shape [N, 3, H, W]
        model: truncated CNN backbone (classification head removed)
        return: L2-normalized feature tensor
    """
    if args.model not in ['resnet18', 'resnet50', 'resnet101', 'resnet152']:
        # Fail fast: the original fell through and raised UnboundLocalError
        # on `feat` for any unsupported model name.
        raise ValueError(f'unsupported CNN model: {args.model}')
    img = img.cuda()
    feat = model(img)
    # normalize() defaults to L2 over dim=1, so the dot product computed
    # later in get_sim() is a cosine similarity.
    return normalize(feat)


# For each query feature, find the index of the most similar dataset feature.
def get_sim(query_feat: torch.Tensor, dataset_feat: torch.Tensor):
    """
    param:
        query_feat: query image features, shape: [query_num, feat_dim]
        dataset_feat: dataset image features, shape: [dataset_num, feat_dim]
        return: best-matching dataset index per query as a numpy array
    """
    # Features are L2-normalized upstream, so the inner product of every
    # query/dataset pair is their cosine similarity.
    similarity = query_feat @ dataset_feat.t()
    # Pick the dataset image with the highest similarity for each query
    # and hand the indices back on the CPU as a numpy array.
    best_match = similarity.argmax(dim=1)
    return best_match.cpu().numpy()


# Load an image from disk.
def read_img(path):
    """Open the image at `path` and return it as an RGB PIL image.

    Forcing RGB guards against grayscale or RGBA inputs, which would
    otherwise produce 1- or 4-channel tensors and crash the 3-channel
    CNN preprocessing pipeline.
    """
    return Image.open(path).convert('RGB')


# Write the query -> dataset match results to ./submit/<model>.csv.
def save_csv(args, top_index):
    """
    param:
        args: parsed CLI arguments; `args.model` names the output file
        top_index: per-query index into the dataset file list,
                   shape [query_num] (as produced by get_sim)

    NOTE(review): both this function and the feature loops glob the same
    folders independently; they rely on glob returning files in the same
    order within one run — confirm before reordering files on disk.
    """
    dataset_path = np.array(glob.glob('./dataset/*.jpg'))
    # Portable directory creation instead of shelling out to `mkdir -p`.
    os.makedirs('./submit', exist_ok=True)
    # Gather the matched dataset file for each query.
    top_paths = dataset_path[top_index]
    # os.path.basename is portable across path separators, unlike split('/').
    top_paths = [os.path.basename(x) for x in top_paths]
    pd.DataFrame({
        'source': top_paths,
        'query': [os.path.basename(x) for x in glob.glob('./query/*.jpg')]
    }).to_csv(os.path.join('./submit/', args.model + '.csv'), index=False)
    


# Parse command-line arguments.
def get_args():
    """Return parsed CLI args; --model selects the feature extractor."""
    parser = argparse.ArgumentParser()
    # resnet101/resnet152 are accepted downstream (get_feat and __main__
    # both check for them), so expose them here too instead of rejecting
    # them at parse time.
    parser.add_argument('--model', type=str, default='resnet18',
                        choices=['resnet18', 'resnet50', 'resnet101', 'resnet152',
                                 'vit_base', 'clip_vit_base_patch16_224'],
                        help='model name')
    args = parser.parse_args()
    return args


# 定义模型
def get_model(args):
    if args.model == 'resnet18':
        # 使用resnet18模型提取图片特征
        model = models.resnet18(pretrained=True)
        # 去掉模型最后一层
        model = torch.nn.Sequential(*list(model.children())[:-1])
    if args.model == 'resnet50':
        # 使用resnet50模型提取图片特征
        model = models.resnet50(pretrained=True) 
        # 去掉模型最后一层
        model = torch.nn.Sequential(*list(model.children())[:-1])
    return model

# Build a ViT or CLIP model plus its matching image preprocessor.
def get_vit_model(args):
    """
    param:
        args: parsed CLI arguments; `args.model` picks the transformer backbone
        return: (model, preprocess) pair; `preprocess` converts a PIL image
                into the tensor format the model expects
    raises:
        ValueError: for model names this builder does not support
    """
    if args.model == 'vit_base':
        # HuggingFace ViT-Base (patch16, pretrained on ImageNet-21k).
        preprocess = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')
        model = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
    elif args.model == 'clip_vit_base_patch16_224':
        # OpenAI CLIP with a ViT-B/16 image encoder, loaded onto the GPU.
        model_name = 'ViT-B/16'
        model, preprocess = clip.load(model_name, device='cuda')
    else:
        # The original fell through and raised UnboundLocalError here.
        raise ValueError(f'unsupported transformer model: {args.model}')
    return model, preprocess

# 数据增强
def data_aug(img):
    w, h = img.size
    aug = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        # resize成224*224
        torchvision.transforms.Resize((224, 224)),
        # torchvision.transforms.RandAugment(),
        torchvision.transforms.RandomInvert(0.2),
        torchvision.transforms.RandomGrayscale(0.2),
        torchvision.transforms.RandomHorizontalFlip(p=0.5),
        torchvision.transforms.RandomVerticalFlip(p=0.5),
        torchvision.transforms.RandomAutocontrast(),
        torchvision.transforms.RandomRotation(10),
        torchvision.transforms.RandomAdjustSharpness(0.2),
        torchvision.transforms.RandomChoice([
            torchvision.transforms.Pad(10),
            torchvision.transforms.RandomResizedCrop(size=(w - 30, h - 30),
                                                     scale=(0.8, 1))
        ]),
    ])

    aug_test = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize,
 ])
    return aug(img).unsqueeze(0)


# Main entry point: extract global features for every dataset and query image
# with the selected backbone, match each query to its most similar dataset
# image, and save the results as a CSV for submission.
if __name__ == '__main__':
    # Parse CLI arguments (selects the backbone via --model).
    args = get_args()
    if args.model in ['resnet18', 'resnet50', 'resnet101', 'resnet152']:
        model = get_model(args)
    if args.model in ['vit_base', 'vit_large', 'clip_vit_base_patch16_224', 'clip_vit_large_patch16_224']:
        model, preprocess = get_vit_model(args)
    
    model = model.cuda()
    model.eval()
    # Wrap all inference in torch.no_grad() to skip gradient bookkeeping.
    with torch.no_grad():
        # Compute a feature vector for every image in the dataset folder.
        dataset_feat = [] # list of tensor
        for i, path in enumerate(glob.glob('./dataset/*.jpg')):
            if i%100 == 0:
                print('dataset: ', i)
            # if i==5:
            #     break
            img = read_img(path)
            
            if args.model == 'vit_base':
                img = preprocess(images=img, return_tensors="pt")
                img['pixel_values'] = img['pixel_values'].cuda()
                outputs = model(**img)
                last_hidden_states = outputs.last_hidden_state # torch.Size([1, 16*16+1, 768])
                # Use the [CLS] token embedding as the global image feature.
                feat = last_hidden_states[:, 0, :] # torch.Size([1, 768])
                feat /= feat.norm(dim=-1, keepdim=True)
            elif args.model == 'clip_vit_base_patch16_224':
                img = preprocess(img).unsqueeze(0).cuda()
                feat = model.encode_image(img)
                feat /= feat.norm(dim=-1, keepdim=True)
            else:
                img = data_aug(img)
                feat = get_feat(args, img, model)
            
            
            dataset_feat.append(feat)
        # Stack per-image features and flatten to [num_images, feat_dim].
        dataset_feat = torch.stack(dataset_feat, dim=0)
        dataset_feat = dataset_feat.reshape(dataset_feat.shape[0], -1)
        # dataset_feat = normalize(dataset_feat)
        # Compute a feature vector for every image in the query folder.
        query_feat = []
        for i, path in enumerate(glob.glob('./query/*.jpg')):
            if i%50 == 0:
                print('query: ', i)
            # if i==5:
            #     break
            img = read_img(path)
            if args.model == 'vit_base':
                img = preprocess(images=img, return_tensors="pt")
                img['pixel_values'] = img['pixel_values'].cuda()
                outputs = model(**img)
                last_hidden_states = outputs.last_hidden_state # torch.Size([1, 16*16+1, 768])
                # Use the [CLS] token embedding as the global image feature.
                feat = last_hidden_states[:, 0, :] # torch.Size([1, 768])
                feat /= feat.norm(dim=-1, keepdim=True)
            elif args.model == 'clip_vit_base_patch16_224':
                img = preprocess(img).unsqueeze(0).cuda()
                feat = model.encode_image(img)
                feat /= feat.norm(dim=-1, keepdim=True)
            else:
                img = data_aug(img)
                feat = get_feat(args, img, model)
            
            query_feat.append(feat)
        # Stack per-query features and flatten to [num_queries, feat_dim].
        query_feat = torch.stack(query_feat, dim=0)
        query_feat = query_feat.reshape(query_feat.shape[0], -1)
        # query_feat = normalize(query_feat)
        # Find the most similar dataset image for each query.
        top_index = get_sim(query_feat, dataset_feat)
    
        # Save the match results to ./submit/<model>.csv.
        save_csv(args, top_index)

你可能感兴趣的:(python,数据挖掘竞赛,深度学习,机器学习,python)