用PyTorch做一个本地图片搜索工具

前言

在比较大规模的iOS项目开发中,会遇到这样的场景,一个新需求使用的icon可能之前有,但是想找到实在是太难了。最近在学PyTorch,于是想到是否能用PyTorch做一个本地的图片搜索功能,经过一番搜索,非常可行

项目代码

ImageSearchApp: ImageSearchApp (gitee.com)使用QT做了一个简单的UI,安装好依赖,运行image_search_app.py即可。目前代码未作整理,仅供参考

image.png

上图搜索到的是VOC2007图库中的汽车,在实际iOS项目中我也有尝试,比如搜索叉号icon,在项目中找到了11个叉号图片。

Python依赖

  • PyQt5
  • pyperclip
  • torch
  • torchvision
  • PIL
  • PySide2

图片搜索原理

图片搜索主要分为以下几步

  1. 构建被搜索图片的特征库
  2. 提取输入图片的特征
  3. 将输入图片的特征和被搜索图片的特征进行比对,得出最相近的top K个结果

构建被搜索图片的特征库

抽取特征

直接基于PreTrain的模型进行特征抽取,我的代码中使用了resnet-18模型avgpool模块的特征输出,输出尺寸为512

model = models.resnet18(pretrained=True)
layer = model._modules.get('avgpool')

通过注册hook获取特征输出

image = self.normalize(self.toTensor(img)).unsqueeze(0).to(self.device)
embedding = torch.zeros(1, self.number_features, 1, 1)
def copy_data(m ,i, o): embedding.copy_(o.data)
h = self.feature_layer.register_forward_hook(copy_data)
self.model(image)
h.remove()

特征抽取的完整代码在feature_extrator.py

特征持久化

为了避免每次都得重新计算特征,我使用了h5py保存特征值,使用图片文件的路径md5作为主key,分别保存pathfeature

h5_base_key = self.md5_of_path(img_full_path)

path_data = dbfile.create_dataset(h5_base_key + '/path', (1), dtype=h5py.special_dtype(vlen=str))
path_data[:] = img_full_path
dbfile[h5_base_key + "/feature"] = feature

这块的完整代码在batch_feature_processor.py

提取输入图片的特征

输入图片的特征提取直接使用feature_extrator.py即可

特征比对

比对主要使用余弦相似度来评估图片特征向量的相似度。在二维空间,余弦相似度可以理解为两个向量的夹角,夹角为0时,相似度最高,此时余弦为1,余弦的计算公式如下

cos(angle) = dot(VecA, VecB) / (|VecA| * |VecB|)

这个公式在高维度同样适用,比如我们输出的特征向量,是512维,计算代码如下

np.inner(feature_a.T, feature_b.T) / ((np.linalg.norm(feature_a, axis=0).reshape(-1, 1)) * ((np.linalg.norm(feature_b, axis=0).reshape(-1,1)).T))

np.inner表示内积,在高维空间,使用内积计算向量的点乘。np.linalg.norm则是计算第二范数,对应到二维空间就是计算长度。转置T是为了让矩阵的Shape匹配。
通过比对余弦值的大小就可以得到最匹配的图片啦~

你可能感兴趣的:(用PyTorch做一个本地图片搜索工具)