[深度学习 - 实战项目] 以图搜图Resnet+LSH-特征编码/图像检索/相似度计算

参考代码来源于 http://github.com/yinhaoxs/ImageRetrieval-LSH

以图搜图

1. 写在最前面

入职新公司以后一直在搞项目,没什么时间写博客。
最近一个项目是以图搜图项目,主要用到的技术就是目标检测(yolo)+图像检索(ResNet+LSH)。
目标检测就不用多说了,成熟和现成的代码一抓一大把,主要问题就是在优化提升精度和性能上的摸索。
图像检索的技术也挺多,但是网上的资源相对较少,所以记录一下这段时间用到的一个代码。

最开始直接看到的是这个作者的ImageRetrieval-LSH代码。里面说明文档也比较少,所以记录下我看这个源码的过程。
这个代码包括flask部署和利用LSH提高检索速度都写好了,非常的完善,只要给被搜图片的目录和数据库目录就可以进行检索,模型也是训练好的,数据集用的retrieval-SfM-120K(这个数据集38GB我在官网下不下来,网络带宽不行但是下载了标签(.pkl)文件并解析了一下训练集的类型,后面我会讲一下)。
其中的特征编码模型用的是:https://github.com/filipradenovic/cnnimageretrieval-pytorch

这个以图搜图和人脸识别技术其实很像,可以说是一样。无非就是提取特征,然后进行相似度计算。所以相关的技术有ReID,Arcface,以及我在调研的时候有看到一个素描草图的图像匹配的研究。
这里有相关图像检索Image Retrieval知识资料全集

2. 源码解析

(1)跑通代码-即测试一下自己的图片 demo.py

我只用到里面的编码和检索部分。运行demo.py。直接跑通这部分,然后缺什么库函数去pip install就行了。代码这里,权重包需要科学下载,或者按下面百度云链接


2022.4.30更新,由于近期实在太忙,在公司也不方便发文件,各位找我要权重包的我一回家就忘了。。。所以我把作者的权重包也上传一份到百度云,大家自取即可。链接:https://pan.baidu.com/s/1atZ9fETMlvP45c87JnDbSw
提取码:p1cl

from utils.retrieval_feature import AntiFraudFeatureDataset
from utils.retrieval_index import EvaluteMap


if __name__ == '__main__':
    hash_size = 0
    input_dim = 2048
    num_hashtables = 1
    img_dir = 'ImageRetrieval/data' #存放所有图像库的图片
    test_img_dir = './images' # 待检索的图像
    network = './weights/gl18-tl-resnet50-gem-w-83fdc30.pth' # 模型权重
    #下面这几个好像没有用,不管他
    out_similar_dir = './output/similar'
    out_similar_file_dir = './output/similar_file'
    all_csv_file = './output/aaa.csv'

    feature_dict, lsh = AntiFraudFeatureDataset(img_dir, network).constructfeature(hash_size, input_dim, num_hashtables)
    test_feature_dict = AntiFraudFeatureDataset(test_img_dir, network).test_feature()
    EvaluteMap(out_similar_dir, out_similar_file_dir, all_csv_file).retrieval_images(test_feature_dict, lsh, 3)

(2)特征编码

代码首先对img_dir中的所有图片进行特征提取:feature_dict, lsh = AntiFraudFeatureDataset(img_dir, network).constructfeature(hash_size, input_dim, num_hashtables)

返回的feature_dict就是图片特征。(可以直接用余弦相似度进行相似计算)
但是这里还通过LSH对每张图片特征图进行0,1编号,所在这里后面用来图片检索的不是feature_dict,而是lsh,(应该是加速后面图片检索时候的速度)

进到特征编码那块代码retrieval_feature.py ,里面主要对图片进行编码的函数对象是AntiFraudFeatureDataset

首先前面一大段到net.eval(),都是加载网络模型,可以看到模型选择有很多参数,这些参数对应网络的结构设置,(后面如果用自己的数据对自己的特征编码模型进行训练的话,要根据使用的不同模型参数进行修改)
这个函数ImageProcess是遍历目录底下的全部图片,并将他们的路径保存在数组中。
然后再这个函数extract_vectors中提取图像特征。(在这个目录底下ImageRetrieval-LSH/cirtorch/networks/imageretrievalnet.py )主要也不需要怎么做修改,除非说你要修改一下图片的dataloader(这里是通过将所有图片路径保存下来做的dataset,因为每张图片的尺寸可以不一样,Resnet网络的最后通过一个全连接层输出1 * 2048特征图。)

所以这里出来的vecs是N张图片的特征编码,每个特征编码是1 * 2048。

然后对待检索图像进行特征编码是在这里test_feature_dict = AntiFraudFeatureDataset(test_img_dir, network).test_feature()这里他和上面的区别就是没有做LSH,得到的每个特征编码在后面进行检索的时候对进行LSH。(详细对LSH我也没怎么了解,因为我后面部署的时候用的是java进行相似度计算,所以这部分我没怎么了解,好像是Python加速检索速度的,具体这个代码里面用到的模块可以看这里https://github.com/kayzhu/LSHash)

def constructfeature(self, hash_size, input_dim, num_hashtables):
        multiscale = '[1]'
        print(">> Loading network:\n>>>> '{}'".format(self.network))
		
        state = torch.load(self.network)

        net_params = {}
        net_params['architecture'] = state['meta']['architecture']
        net_params['pooling'] = state['meta']['pooling']
        net_params['local_whitening'] = state['meta'].get('local_whitening', False)
        net_params['regional'] = state['meta'].get('regional', False)
        net_params['whitening'] = state['meta'].get('whitening', False)
        net_params['mean'] = state['meta']['mean']
        net_params['std'] = state['meta']['std']
        net_params['pretrained'] = False
        # network initialization
        net = init_network(net_params)
        net.load_state_dict(state['state_dict'])
        print(">>>> loaded network: ")
        print(net.meta_repr())
        # setting up the multi-scale parameters
        ms = list(eval(multiscale))
        print(">>>> Evaluating scales: {}".format(ms))
        # moving network to gpu and eval mode
        if torch.cuda.is_available():
            net.cuda()
        net.eval()


        # set up the transform 数据预处理
        normalize = transforms.Normalize(
            mean=net.meta['mean'],
            std=net.meta['std']
        )
        transform = transforms.Compose([
            transforms.ToTensor(),
            normalize
        ])

        # extract database and query vectors 对图片进行编码提取数据库图片特征
        print('>> database images...')
        images = ImageProcess(self.img_dir).process()
        vecs, img_paths = extract_vectors(net, images, 1024, transform, ms=ms)
        feature_dict = dict(zip(img_paths, list(vecs.detach().cpu().numpy().T)))
        # index 
        lsh = LSHash(hash_size=int(hash_size), input_dim=int(input_dim), num_hashtables=int(num_hashtables))
        for img_path, vec in feature_dict.items():
            lsh.index(vec.flatten(), extra_data=img_path)

        # ## 保存索引模型
        # with open(self.feature_path, "wb") as f:
        #     pickle.dump(feature_dict, f)
        # with open(self.index_path, "wb") as f:
        #     pickle.dump(lsh, f)

        print("extract feature is done")
        return feature_dict, lsh

(3)图像检索

这里图像检索这块我没怎么改动,因为只是测试一下自己训练后的模型的效果比较方便查看用的。所以我只是修改了输出的数量。
这里如果要输出多个Top,要自己多加几个,(也可以自己写个循环,我比较懒,没有写)然后后面我还显示出了得分情况,(因为后面要进行模型的对比)

    def find_similar_img_gyz(self, feature_dict, lsh, num_results):
        for q_path, q_vec in feature_dict.items():

            try:
                response = lsh.query(q_vec.flatten(), distance_func="cosine")  # , num_results=int(num_results)
                # print(response[0][1])
                # print(np.rint(100 * (1 - response[0][1])))
                query_img_path0 = response[0][0][1]
                query_img_path1 = response[1][0][1]
                query_img_path2 = response[2][0][1]
                query_img_path3 = response[3][0][1]
                query_img_path4 = response[4][0][1]
                score_img_path0 = response[0][1]
                score_img_path1 = response[1][1]
                score_img_path2 = response[2][1]
                score_img_path3 = response[3][1]
                score_img_path4 = response[4][1]

                # score0 = response[0][1]
                # score0 = np.rint(100 * (1 - score0))
                print('**********************************************')
                print('input img: {}'.format(q_path))
                print('query0 img: {}'.format(query_img_path0),
                      ' score:{}'.format(np.rint(100 * (1 - score_img_path0))))
                print('query1 img: {}'.format(query_img_path1),
                      ' score:{}'.format(np.rint(100 * (1 - score_img_path1))))
                print('query2 img: {}'.format(query_img_path2),
                      ' score:{}'.format(np.rint(100 * (1 - score_img_path2))))
                print('query3 img: {}'.format(query_img_path3),
                      ' score:{}'.format(np.rint(100 * (1 - score_img_path3))))
                print('query4 img: {}'.format(query_img_path4),
                      ' score:{}'.format(np.rint(100 * (1 - score_img_path4))))
            except:
                continue

3. 训练自己的数据集

(1)训练参数配置

特征编码模型主要来源于这里:https://github.com/filipradenovic/cnnimageretrieval-pytorch所以训练自己的数据也是根据这里面的代码。主要模块就是cirtorch/example/train.py
首先就是修改一下你要使用的模型参数,我这里用的resnet50,损失用的contrastive(因为用tripletLoss的时候结果出了点问题哈哈哈哈正常来讲应该tripletLoss会更好吧)

# network architecture and initialization options
parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet50', choices=model_names,
                    help='model architecture: ' +
                        ' | '.join(model_names) +
                        ' (default: resnet101)')
parser.add_argument('--pool', '-p', metavar='POOL', default='gem', choices=pool_names,
                    help='pooling options: ' +
                        ' | '.join(pool_names) +
                        ' (default: gem)')
parser.add_argument('--local-whitening', '-lw', dest='local_whitening', action='store_true',
                    help='train model with learnable local whitening (linear layer) before the pooling')
parser.add_argument('--regional', '-r', dest='regional', action='store_true',
                    help='train model with regional pooling using fixed grid')
parser.add_argument('--whitening', '-w', dest='whitening', action='store_true',
                    help='train model with learnable whitening (linear layer) after the pooling')
parser.add_argument('--not-pretrained', dest='pretrained', action='store_false',
                    help='initialize model with random weights (default: pretrained on imagenet)')
parser.add_argument('--loss', '-l', metavar='LOSS', default='contrastive',
                    choices=loss_names,
                    help='training loss options: ' +
                        ' | '.join(loss_names) +
                        ' (default: contrastive)')
parser.add_argument('--loss-margin', '-lm', metavar='LM', default=0.7, type=float,
                    help='loss margin: (default: 0.7)')

(2)训练数据的准备

这里训练数据用的是retrieval-SfM-120k但是因为数据集38个GB,网速不行,在外网上下不下来,所以气急败坏的我直接看他的标签文件retrieval-SfM-120k.pkl
这个文件就是一个字典格式文件,大概分了几层如下,因为我没有准备验证集和测试集,所以训练时候测试和验证那部分我直接删去了(主要因为测试集的格式和训练集不一样,我懒得再去解析另一个数据集的格式)。

{ train : {
cids : [ ], cluster : [ ], qidxs : [ ], pidxs : [ ]
},
val : {…}
}

① cids:主要用来存放所有图片的路径,所以不管你图片存放在哪,只要有图片路径即可。数组长度就是总的图片数量。
② cluster:这个是存放该图片的类别,数组的长度和cids一样,类别一一对应cids的图片(retrieval-SfM-120k 是有713个建筑物所以是713类,依据自己的数据集而定,我的数据集每对图片都是一个类,所以有几千个类别)。
③ qidxs&pidxs:这个qidx是存放查询的query图片,对应位置的pidx是存放和他匹配的positive图片,这样就形成了一对正样例。然后这两个数组存放的是前面cids的索引index,对应的是cids[qidxs[1]] -> cids[pidxs[1]]。

所以自己准备一个自己数据集的pkl文件,就可以训练了。


更新一下制作标签的代码(在这一节最后,自己写的,所以有的粗糙。)
把你准备训练的数据放在一个目录下,例如:dirs = 'cirtorch/ImageRetrieval_dataset/train'如下图。train下每个文件夹都是相似图片的集合。(我自己每个相似图只有两张,所以每个qp文件夹下只有两张,这份代码应该可以支持多张相似图,几个月前写的了,有点健忘。)
然后得到pkl文件夹就可以修改训练代码了。
[深度学习 - 实战项目] 以图搜图Resnet+LSH-特征编码/图像检索/相似度计算_第1张图片
[深度学习 - 实战项目] 以图搜图Resnet+LSH-特征编码/图像检索/相似度计算_第2张图片
修改cirtorch/datasets/traindataset.py 里面的信息,如下图,对应的pkl文件和图片文件夹准备好。
[深度学习 - 实战项目] 以图搜图Resnet+LSH-特征编码/图像检索/相似度计算_第3张图片
接着就可以通过cirtorch/examples/train.py 进行训练了。具体训练的配置信息,依据你自己的训练任务来。

import os
import pickle
import numpy as np

# 将similar_pics 转换成pkl标签文档,供ImageRetrieval训练数据
if __name__ == '__main__':
    cids = []
    clusters = []
    qidxs = []
    pidxs = []
    class_num = 0
    dirs = 'cirtorch/ImageRetrieval_dataset/train'
    for dir in os.listdir(dirs):
        # test1 = os.listdir(dirs)
        # print(dir)
        # test = os.listdir('/'.join([dirs,dir]))
        one_dir = '/'.join([dirs,dir])

        for qpimg in os.listdir(one_dir):
            # qpdir = '/'.join([one_dir,path])
            save_cid_path = '/'.join([dir,qpimg])
            cids.append(save_cid_path)
            clusters.append(class_num)
        class_num += 1

    print(cids)
    print(len(cids))

    print(clusters)
    print(len(clusters))

    for i in range(len(clusters)-1):
        if clusters[i]==clusters[i+1]:
            qidxs.append(i)
            pidxs.append(i+1)
    qidxs = np.array(qidxs,'uint16')
    pidxs = np.array(pidxs, 'uint16')
    print({'qidxs':qidxs})
    print({'pidxs':pidxs})

    data = {'train':{'cids':cids,'clusters':clusters,'qidxs':qidxs,'pidxs':pidxs}}
    print(data)
    with open("test.pkl", "wb") as f:
        pickle.dump(data, f)

4. 顺便写一下我的部署,我使用的tornado

用tornado主要是再服务器上启动服务,供后端人员提取特征编码。所以就不需要检索那部分了。

这里因为他们要求传入的是base64编码的图片格式,所以输入输出我自己写了一下。
然后项目是先做yoloV5目标检测,然后在提取检测后图案的特征。
Server类:(初始化模型和运行)

from torchvision import transforms
from cirtorch.networks.imageretrievalnet import init_network
from models.experimental import attempt_load

from utils.general import (
    check_img_size, non_max_suppression, scale_coords)
from utils.torch_utils import select_device
from utils.datasets import letterbox

import torch

import numpy as np
from PIL import Image

import cv2
import base64

class Server():
    def __init__(self):
        self.weights, self.imgsz = \
            'weights/yolov5l.pt', 640

        # Initialize
        self.device = select_device('4')

        self.half = self.device.type != 'cpu'  # half precision only supported on CUDA

        # Load model
        self.model = attempt_load(self.weights, map_location=self.device)  # load FP32 model
        imgsz = check_img_size(self.imgsz, s=self.model.stride.max())  # check img_size
        if self.half:
            self.model.half()  # to FP16
        img = torch.zeros((1, 3, imgsz, imgsz), device=self.device)  # init img
        _ = self.model(img.half() if self.half else img) if self.device.type != 'cpu' else None  # run once

        # feanet
        network = 'weights/model_best_adam_epoch404.pth'

        multiscale = '[1]'
        print(">> Loading network:\n>>>> '{}'".format(network))
        state = torch.load(network)

        net_params = {}
        net_params['architecture'] = state['meta']['architecture']
        net_params['pooling'] = state['meta']['pooling']
        net_params['local_whitening'] = state['meta'].get('local_whitening', False)
        net_params['regional'] = state['meta'].get('regional', False)
        net_params['whitening'] = state['meta'].get('whitening', False)
        net_params['mean'] = state['meta']['mean']
        net_params['std'] = state['meta']['std']
        net_params['pretrained'] = False
        # network initialization
        self.fea_net = init_network(net_params)
        self.fea_net.load_state_dict(state['state_dict'])
        print(">>>> loaded network: ")
        print(self.fea_net.meta_repr())
        # setting up the multi-scale parameters
        ms = list(eval(multiscale))
        if torch.cuda.is_available():
            self.fea_net.to(self.device)

        # set up the transform
        self.normalize = transforms.Normalize(
            mean=self.fea_net.meta['mean'],
            std=self.fea_net.meta['std']
        )
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            self.normalize
        ])
        print('initialize')
    def run(self, base64_str, bounding_box):
        model = self.model  # 加载模型

        img_b64decode = base64.b64decode(base64_str)  # base64解码
        img_array = np.frombuffer(img_b64decode, np.uint8)  # 转换np序列
        image = cv2.imdecode(img_array, cv2.COLOR_BGR2RGB)

        if bounding_box is not None:
            image = image[int(bounding_box[1]):int(bounding_box[5]), int(bounding_box[0]):int(bounding_box[4])]

        im0 = image
        # 数据预处理
        # Padded resize
        img = letterbox(image, new_shape=self.imgsz)[0]
        # Convert
        img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x416x416
        img = np.ascontiguousarray(img)
        # Run inference
        img = torch.from_numpy(img).to(self.device)
        img = img.half() if self.half else img.float()  # uint8 to fp16/32
        img /= 255.0  # 0 - 255 to 0.0 - 1.0
        if img.ndimension() == 3:
            img = img.unsqueeze(0)

        # Inference 推理/侦测
        pred = model(img, augment=False)[0]

        # Apply NMS
        pred = non_max_suppression(pred, 0.6, 0.5)
        # Process detections
        data = []
        for i, det in enumerate(pred):  # detections per image
            # print(pred)
            if det is not None and len(det):
                # Rescale boxes from img_size to im0 size
                det[:, :4] = scale_coords(img.shape[2:], det[:, :4], image.shape).round()

                # Write results
                count = 0
                for *xyxy, conf, cls in reversed(det):
                    # print('xy', xyxy[0], xyxy[1], 'xy2', xyxy[2], xyxy[3])
                    count += 1
                    crop_img = im0[int(xyxy[1]):int(xyxy[3]), int(xyxy[0]):int(xyxy[2])]

                    score = conf

                    cv_image = Image.fromarray(cv2.cvtColor(crop_img, cv2.COLOR_BGR2RGB))

                    fea_net = self.fea_net.eval()

                    cv_image = self.transform(cv_image)

                    if torch.cuda.is_available():
                        cv_image = cv_image.to(self.device)
                    else:
                        cv_image = cv_image

                    feature = fea_net(cv_image.unsqueeze(0)).cpu().data.squeeze()

                    data.append({'features': feature.tolist(), 'score': score.tolist()})

        return data

启动服务:

import json
import base64
import time

import numpy as np
import tornado.web
import tornado.ioloop
# 调用图搜接口类运行初始化函数。
from tusou_model_server import Server


class NpEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, np.int) or isinstance(obj, np.int64):
            return int(obj)
        elif isinstance(obj, np.float):
            return float(obj)
        elif isinstance(obj, np.ndarray):
            return obj.tolist()
        else:
            return super(NpEncoder, self).default(obj)

class MainHandler(tornado.web.RequestHandler):

    def prepare(self):
        if self.request.body:
            try:
                print(len(self.request.body))
                json.loads(self.request.body.decode("utf-8"), strict=False)
            except ValueError:
                message = "Unable to parse JSON"
                self.send_error(400, message=message)
        print('prepare')
        self.response = dict()

    def get(self, *args, **kwargs):
        self.write("Not implement Get Function")

    def set_default_headers(self):
        print("setting headers!!!")
        self.set_header("Access-Control-Allow-Origin", "*")
        self.set_header("Access-Control-Allow-Headers", "*")
        self.set_header('Access-Control-Allow-Methods', 'POST,OPTIONS')

    # executor = ThreadPoolExecutor(8)
    # @run_on_executor
    def post(self, *args, **kwargs):
        t2 = time.time()
        params = self.request.body.decode('utf-8')
        params = json.loads(params, strict=False)

        base64_str = params["img_base64"] # 加载参数
        try:
            bounding_box = params["bounding_box"]  # 加载参数
            # bounding_box = list(map(int, bounding_box))
        except:
            bounding_box = None
        # 传入参数,我这里包括可裁剪和不裁剪
        data = server.run(base64_str,bounding_box)
        output = {"stateCode":"0", "stateInfo":"成功", "data":data}
        test = json.dumps(output, cls=NpEncoder, ensure_ascii=False)

        self.write(test)


    def options(self, *args, **kwargs):
        self.finish()


def make_app():
    return tornado.web.Application([
        (r"/tusou_getfeature", MainHandler),
    ])


def main(port):
    app = make_app()
    app.listen(port)
    tornado.ioloop.IOLoop.current().start()

from tornado import  options
options.define("port", default=*, type=int, help="服务器监听端口号")
options.define("process_num", default=1, type=int, help="启动进程数")
if __name__ == '__main__':

    server = Server()
    options.parse_command_line()
    port = options.options.port
    print("start port at: %s",port)
    main(port)

测试的时候:

import json
import requests
import base64
import time
if __name__ == '__main__':
    detect_url = '.../tusou_getfeature'  # tusou_getfeature

    # 传入裁剪框坐标x1y1 x2y2 x3y3 x4y4 左上角开始顺时针
    image_path = 'test_temp/test3.jpg'
    # with open(image_path, 'rb') as f:
    #     image = f.read()
    #     image_base64 = str(base64.b64encode(image), encoding='utf-8')
    # data_obj = {'img_base64': image_base64, 'bounding_box': []}

    # 不需要裁剪
    image_path = 'crop_test.jpg'
    with open(image_path, 'rb') as f:
        image = f.read()
        image_base64 = str(base64.b64encode(image), encoding='utf-8')
    data_obj = {'img_base64': image_base64}


    # test
    t0 = time.time()
    r = requests.post(detect_url, json.dumps(data_obj))
    t1 = time.time()
    print('time',t1-t0)
    print(r)
    content = r.json()
    print(len(content['data']))

你可能感兴趣的:(深度学习,1024程序员节,深度学习,人工智能,计算机视觉)