使用pytorch中的resnet预训练模型进行特征提取,以及查找相似图像

想通过缩略图找原图？想找回 P 过的图像的原图？汇报时用的是压缩过的图像，现在想找原图？如何从大量图像文件中快速找到与目标图像相似的那一张？用 PyTorch 只需要几行代码就可以搞定。

模型的选取

一般使用图像分类网络进行特征提取即可：去掉网络最后的全连接分类层后，网络输出的就是图像的特征向量，而不是各类别的得分。模型的加载方法参看上一篇《使用pytorch中的resnet预训练模型进行快速图像分类》。

代码如下

提取查询图像和候选图像的特征，计算二者的余弦相似度，相似度越大则图像越相似。代码会输出相似度超过阈值的图像路径，也可以选择将这些相似图像复制到指定目录下。注意：torchvision 的 ImageFolder 要求候选图像放在数据目录的子文件夹中。

# Find gallery images similar to a query image using ResNet features.
#
# Pipeline: a pretrained ImageNet ResNet has its final fully connected layer
# replaced by Identity, so it outputs pooled feature vectors instead of class
# logits. The query image and every gallery image are embedded, and gallery
# images whose cosine similarity to the query exceeds `threshold` are
# collected in `result` (list of file paths), printed, and optionally copied
# into `target_dir`.
import os
from shutil import copyfile

import torch
import torchvision
from PIL import Image
from torchvision import transforms
from tqdm import tqdm

# load model
# Any of resnet18, resnet34, resnet50, resnet101, resnet152 works here; because
# the head is replaced by Identity below, no feature size has to be hard-coded
# (512 for resnet18/34, 2048 for resnet50/101/152).
# NOTE: on torchvision >= 0.13 `pretrained=True` is deprecated in favour of
# `weights=torchvision.models.ResNet101_Weights.DEFAULT`.
model = torchvision.models.resnet101(pretrained=True)
model.eval()

# args
path_to_query = 'target.jpg'  # the single query image
# Gallery root. ImageFolder expects the images to live in sub-directories of
# this path (e.g. /path/to/data/any_name/xxx.jpg), not directly inside it.
path_to_data = '/path/to/data/'
BATCH_SIZE = 256
target_dir = 'targetdir'  # similar images are copied here when SAVE_RESULTS is True
threshold = 0.8  # keep an image if its cosine similarity to the query is > threshold
SAVE_RESULTS = False  # set to True to copy the matching images into target_dir
# Fall back to CPU so the script also runs on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# build dataloader
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
test_data = torchvision.datasets.ImageFolder(path_to_data, preprocess)
image_names = test_data.samples  # list of (path, class_index) pairs
# shuffle must stay False: matches are mapped back to `image_names` by position.
data_loader = torch.utils.data.DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)

# build feature extractor
# Replacing the classification head with Identity makes the network return the
# pooled backbone features. (The previous approach — a new Linear layer that
# had to be initialised to the identity matrix — left the weights random
# because the eye_ initialisation was commented out.)
feature_extractor = model
feature_extractor.fc = torch.nn.Identity()
feature_extractor.to(device)

result = []

# test and query
with torch.no_grad():  # inference only: no autograd graph is built

    # load query image
    # convert('RGB') so grayscale / RGBA / palette images yield the 3 channels
    # Normalize expects (ImageFolder's default PIL loader does the same for the
    # gallery images). The context manager closes the file handle.
    with Image.open(path_to_query) as input_image:
        input_tensor = preprocess(input_image.convert('RGB'))
    input_batch = input_tensor.unsqueeze(0).to(device)  # mini-batch of one
    q_feature = feature_extractor(input_batch)

    # load gallery images
    offset = 0  # index in image_names of the first image of the current batch
    for x, _ in tqdm(data_loader, desc="Evaluating", leave=False):
        output = feature_extractor(x.to(device))

        # q_feature (1, D) broadcasts against output (batch, D)
        similarity = torch.cosine_similarity(q_feature, output, dim=1)
        # one vectorised comparison instead of a per-element GPU round trip
        matches = torch.nonzero(similarity > threshold, as_tuple=True)[0].tolist()
        result.extend(image_names[offset + i][0] for i in matches)
        offset += output.shape[0]

# optionally copy the similar images into target_dir
if SAVE_RESULTS:
    os.makedirs(target_dir, exist_ok=True)
    for r in result:
        # NOTE: files sharing a basename in different sub-directories overwrite
        # each other in target_dir.
        copyfile(r, os.path.join(target_dir, os.path.basename(r)))

# print the results
for r in result:
    print(r)

参考

pytorch-resnet 提取特征

你可能感兴趣的:(pytorch,深度学习,图像处理)