参考代码来源于 http://github.com/yinhaoxs/ImageRetrieval-LSH
入职新公司以后一直在搞项目,没什么时间写博客。
最近一个项目是以图搜图项目,主要用到的技术就是目标检测(yolo)+图像检索(ResNet+LSH)。
目标检测就不用多说了,成熟和现成的代码一抓一大把,主要问题就是在优化提升精度和性能上的摸索。
图像检索的技术也挺多,但是网上的资源相对较少,所以记录一下这段时间用到的一个代码。
最开始直接看到的是这个作者的ImageRetrieval-LSH代码。里面说明文档也比较少,所以记录下我看这个源码的过程。
这个代码包括flask部署和利用LSH提高检索速度都写好了,非常的完善,只要给被搜图片的目录和数据库目录就可以进行检索,模型也是训练好的,数据集用的retrieval-SfM-120K(这个数据集38GB我在官网下不下来,网络带宽不行但是下载了标签(.pkl)文件并解析了一下训练集的类型,后面我会讲一下)。
其中的特征编码模型用的是:https://github.com/filipradenovic/cnnimageretrieval-pytorch
这个以图搜图和人脸识别技术其实很像,可以说是一样。无非就是提取特征,然后进行相似度计算。所以相关的技术有ReID,Arcface,以及我在调研的时候有看到一个素描草图的图像匹配的研究。
这里有相关图像检索Image Retrieval知识资料全集
我只用到里面的编码和检索部分。运行demo.py。直接跑通这部分,然后缺什么库函数去pip install就行了。代码这里,权重包需要科学下载,或者按下面百度云链接
2022.4.30更新,由于近期实在太忙,在公司也不方便发文件,各位找我要权重包的我一回家就忘了。。。所以我把作者的权重包也上传一份到百度云,大家自取即可。链接:https://pan.baidu.com/s/1atZ9fETMlvP45c87JnDbSw
提取码:p1cl
from utils.retrieval_feature import AntiFraudFeatureDataset
from utils.retrieval_index import EvaluteMap
if __name__ == '__main__':
hash_size = 0
input_dim = 2048
num_hashtables = 1
img_dir = 'ImageRetrieval/data' #存放所有图像库的图片
test_img_dir = './images' # 待检索的图像
network = './weights/gl18-tl-resnet50-gem-w-83fdc30.pth' # 模型权重
#下面这几个好像没有用,不管他
out_similar_dir = './output/similar'
out_similar_file_dir = './output/similar_file'
all_csv_file = './output/aaa.csv'
feature_dict, lsh = AntiFraudFeatureDataset(img_dir, network).constructfeature(hash_size, input_dim, num_hashtables)
test_feature_dict = AntiFraudFeatureDataset(test_img_dir, network).test_feature()
EvaluteMap(out_similar_dir, out_similar_file_dir, all_csv_file).retrieval_images(test_feature_dict, lsh, 3)
代码首先对img_dir中的所有图片进行特征提取:feature_dict, lsh = AntiFraudFeatureDataset(img_dir, network).constructfeature(hash_size, input_dim, num_hashtables)
返回的feature_dict就是图片特征。(可以直接用余弦相似度进行相似计算)
但是这里还通过LSH对每张图片特征图进行0,1编号,所在这里后面用来图片检索的不是feature_dict,而是lsh,(应该是加速后面图片检索时候的速度)
进到特征编码那块代码retrieval_feature.py
,里面主要对图片进行编码的函数对象是AntiFraudFeatureDataset
首先前面一大段到net.eval()
,都是加载网络模型,可以看到模型选择有很多参数,这些参数对应网络的结构设置,(后面如果用自己的数据对自己的特征编码模型进行训练的话,要根据使用的不同模型参数进行修改)
这个函数ImageProcess
是遍历目录底下的全部图片,并将他们的路径保存在数组中。
然后再这个函数extract_vectors
中提取图像特征。(在这个目录底下ImageRetrieval-LSH/cirtorch/networks/imageretrievalnet.py
)主要也不需要怎么做修改,除非说你要修改一下图片的dataloader(这里是通过将所有图片路径保存下来做的dataset,因为每张图片的尺寸可以不一样,Resnet网络的最后通过一个全连接层输出1 * 2048特征图。)
所以这里出来的vecs
是N张图片的特征编码,每个特征编码是1 * 2048。
然后对待检索图像进行特征编码是在这里test_feature_dict = AntiFraudFeatureDataset(test_img_dir, network).test_feature()
这里他和上面的区别就是没有做LSH,得到的每个特征编码在后面进行检索的时候对进行LSH。(详细对LSH我也没怎么了解,因为我后面部署的时候用的是java进行相似度计算,所以这部分我没怎么了解,好像是Python加速检索速度的,具体这个代码里面用到的模块可以看这里https://github.com/kayzhu/LSHash)
def constructfeature(self, hash_size, input_dim, num_hashtables):
multiscale = '[1]'
print(">> Loading network:\n>>>> '{}'".format(self.network))
state = torch.load(self.network)
net_params = {}
net_params['architecture'] = state['meta']['architecture']
net_params['pooling'] = state['meta']['pooling']
net_params['local_whitening'] = state['meta'].get('local_whitening', False)
net_params['regional'] = state['meta'].get('regional', False)
net_params['whitening'] = state['meta'].get('whitening', False)
net_params['mean'] = state['meta']['mean']
net_params['std'] = state['meta']['std']
net_params['pretrained'] = False
# network initialization
net = init_network(net_params)
net.load_state_dict(state['state_dict'])
print(">>>> loaded network: ")
print(net.meta_repr())
# setting up the multi-scale parameters
ms = list(eval(multiscale))
print(">>>> Evaluating scales: {}".format(ms))
# moving network to gpu and eval mode
if torch.cuda.is_available():
net.cuda()
net.eval()
# set up the transform 数据预处理
normalize = transforms.Normalize(
mean=net.meta['mean'],
std=net.meta['std']
)
transform = transforms.Compose([
transforms.ToTensor(),
normalize
])
# extract database and query vectors 对图片进行编码提取数据库图片特征
print('>> database images...')
images = ImageProcess(self.img_dir).process()
vecs, img_paths = extract_vectors(net, images, 1024, transform, ms=ms)
feature_dict = dict(zip(img_paths, list(vecs.detach().cpu().numpy().T)))
# index
lsh = LSHash(hash_size=int(hash_size), input_dim=int(input_dim), num_hashtables=int(num_hashtables))
for img_path, vec in feature_dict.items():
lsh.index(vec.flatten(), extra_data=img_path)
# ## 保存索引模型
# with open(self.feature_path, "wb") as f:
# pickle.dump(feature_dict, f)
# with open(self.index_path, "wb") as f:
# pickle.dump(lsh, f)
print("extract feature is done")
return feature_dict, lsh
这里图像检索这块我没怎么改动,因为只是测试一下自己训练后的模型的效果比较方便查看用的。所以我只是修改了输出的数量。
这里如果要输出多个Top,要自己多加几个,(也可以自己写个循环,我比较懒,没有写)然后后面我还显示出了得分情况,(因为后面要进行模型的对比)
def find_similar_img_gyz(self, feature_dict, lsh, num_results):
for q_path, q_vec in feature_dict.items():
try:
response = lsh.query(q_vec.flatten(), distance_func="cosine") # , num_results=int(num_results)
# print(response[0][1])
# print(np.rint(100 * (1 - response[0][1])))
query_img_path0 = response[0][0][1]
query_img_path1 = response[1][0][1]
query_img_path2 = response[2][0][1]
query_img_path3 = response[3][0][1]
query_img_path4 = response[4][0][1]
score_img_path0 = response[0][1]
score_img_path1 = response[1][1]
score_img_path2 = response[2][1]
score_img_path3 = response[3][1]
score_img_path4 = response[4][1]
# score0 = response[0][1]
# score0 = np.rint(100 * (1 - score0))
print('**********************************************')
print('input img: {}'.format(q_path))
print('query0 img: {}'.format(query_img_path0),
' score:{}'.format(np.rint(100 * (1 - score_img_path0))))
print('query1 img: {}'.format(query_img_path1),
' score:{}'.format(np.rint(100 * (1 - score_img_path1))))
print('query2 img: {}'.format(query_img_path2),
' score:{}'.format(np.rint(100 * (1 - score_img_path2))))
print('query3 img: {}'.format(query_img_path3),
' score:{}'.format(np.rint(100 * (1 - score_img_path3))))
print('query4 img: {}'.format(query_img_path4),
' score:{}'.format(np.rint(100 * (1 - score_img_path4))))
except:
continue
特征编码模型主要来源于这里:https://github.com/filipradenovic/cnnimageretrieval-pytorch所以训练自己的数据也是根据这里面的代码。主要模块就是cirtorch/example/train.py
首先就是修改一下你要使用的模型参数,我这里用的resnet50,损失用的contrastive(因为用tripletLoss的时候结果出了点问题哈哈哈哈正常来讲应该tripletLoss会更好吧)
# network architecture and initialization options
parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet50', choices=model_names,
help='model architecture: ' +
' | '.join(model_names) +
' (default: resnet101)')
parser.add_argument('--pool', '-p', metavar='POOL', default='gem', choices=pool_names,
help='pooling options: ' +
' | '.join(pool_names) +
' (default: gem)')
parser.add_argument('--local-whitening', '-lw', dest='local_whitening', action='store_true',
help='train model with learnable local whitening (linear layer) before the pooling')
parser.add_argument('--regional', '-r', dest='regional', action='store_true',
help='train model with regional pooling using fixed grid')
parser.add_argument('--whitening', '-w', dest='whitening', action='store_true',
help='train model with learnable whitening (linear layer) after the pooling')
parser.add_argument('--not-pretrained', dest='pretrained', action='store_false',
help='initialize model with random weights (default: pretrained on imagenet)')
parser.add_argument('--loss', '-l', metavar='LOSS', default='contrastive',
choices=loss_names,
help='training loss options: ' +
' | '.join(loss_names) +
' (default: contrastive)')
parser.add_argument('--loss-margin', '-lm', metavar='LM', default=0.7, type=float,
help='loss margin: (default: 0.7)')
这里训练数据用的是retrieval-SfM-120k
但是因为数据集38个GB,网速不行,在外网上下不下来,所以气急败坏的我直接看他的标签文件retrieval-SfM-120k.pkl
。
这个文件就是一个字典格式文件,大概分了几层如下,因为我没有准备验证集和测试集,所以训练时候测试和验证那部分我直接删去了(主要因为测试集的格式和训练集不一样,我懒得再去解析另一个数据集的格式)。
{ train : {
cids : [ ], cluster : [ ], qidxs : [ ], pidxs : [ ]
},
val : {…}
}
① cids:主要用来存放所有图片的路径,所以不管你图片存放在哪,只要有图片路径即可。数组长度就是总的图片数量。
② cluster:这个是存放该图片的类别,数组的长度和cids一样,类别一一对应cids的图片(retrieval-SfM-120k 是有713个建筑物所以是713类,依据自己的数据集而定,我的数据集每对图片都是一个类,所以有几千个类别)。
③ qidxs&pidxs:这个qidx是存放查询的query图片,对应位置的pidx是存放和他匹配的positive图片,这样就形成了一对正样例。然后这两个数组存放的是前面cids的索引index,对应的是cids[qidxs[1]] -> cids[pidxs[1]]。
所以自己准备一个自己数据集的pkl文件,就可以训练了。
更新一下制作标签的代码(在这一节最后,自己写的,所以有的粗糙。)
把你准备训练的数据放在一个目录下,例如:dirs = 'cirtorch/ImageRetrieval_dataset/train'
如下图。train下每个文件夹都是相似图片的集合。(我自己每个相似图只有两张,所以每个qp文件夹下只有两张,这份代码应该可以支持多张相似图,几个月前写的了,有点健忘。)
然后得到pkl文件夹就可以修改训练代码了。
修改cirtorch/datasets/traindataset.py
里面的信息,如下图,对应的pkl文件和图片文件夹准备好。
接着就可以通过cirtorch/examples/train.py
进行训练了。具体训练的配置信息,依据你自己的训练任务来。
import os
import pickle
import numpy as np
# 将similar_pics 转换成pkl标签文档,供ImageRetrieval训练数据
if __name__ == '__main__':
cids = []
clusters = []
qidxs = []
pidxs = []
class_num = 0
dirs = 'cirtorch/ImageRetrieval_dataset/train'
for dir in os.listdir(dirs):
# test1 = os.listdir(dirs)
# print(dir)
# test = os.listdir('/'.join([dirs,dir]))
one_dir = '/'.join([dirs,dir])
for qpimg in os.listdir(one_dir):
# qpdir = '/'.join([one_dir,path])
save_cid_path = '/'.join([dir,qpimg])
cids.append(save_cid_path)
clusters.append(class_num)
class_num += 1
print(cids)
print(len(cids))
print(clusters)
print(len(clusters))
for i in range(len(clusters)-1):
if clusters[i]==clusters[i+1]:
qidxs.append(i)
pidxs.append(i+1)
qidxs = np.array(qidxs,'uint16')
pidxs = np.array(pidxs, 'uint16')
print({'qidxs':qidxs})
print({'pidxs':pidxs})
data = {'train':{'cids':cids,'clusters':clusters,'qidxs':qidxs,'pidxs':pidxs}}
print(data)
with open("test.pkl", "wb") as f:
pickle.dump(data, f)
用tornado主要是再服务器上启动服务,供后端人员提取特征编码。所以就不需要检索那部分了。
这里因为他们要求传入的是base64编码的图片格式,所以输入输出我自己写了一下。
然后项目是先做yoloV5目标检测,然后在提取检测后图案的特征。
Server类:(初始化模型和运行)
from torchvision import transforms
from cirtorch.networks.imageretrievalnet import init_network
from models.experimental import attempt_load
from utils.general import (
check_img_size, non_max_suppression, scale_coords)
from utils.torch_utils import select_device
from utils.datasets import letterbox
import torch
import numpy as np
from PIL import Image
import cv2
import base64
class Server():
def __init__(self):
self.weights, self.imgsz = \
'weights/yolov5l.pt', 640
# Initialize
self.device = select_device('4')
self.half = self.device.type != 'cpu' # half precision only supported on CUDA
# Load model
self.model = attempt_load(self.weights, map_location=self.device) # load FP32 model
imgsz = check_img_size(self.imgsz, s=self.model.stride.max()) # check img_size
if self.half:
self.model.half() # to FP16
img = torch.zeros((1, 3, imgsz, imgsz), device=self.device) # init img
_ = self.model(img.half() if self.half else img) if self.device.type != 'cpu' else None # run once
# feanet
network = 'weights/model_best_adam_epoch404.pth'
multiscale = '[1]'
print(">> Loading network:\n>>>> '{}'".format(network))
state = torch.load(network)
net_params = {}
net_params['architecture'] = state['meta']['architecture']
net_params['pooling'] = state['meta']['pooling']
net_params['local_whitening'] = state['meta'].get('local_whitening', False)
net_params['regional'] = state['meta'].get('regional', False)
net_params['whitening'] = state['meta'].get('whitening', False)
net_params['mean'] = state['meta']['mean']
net_params['std'] = state['meta']['std']
net_params['pretrained'] = False
# network initialization
self.fea_net = init_network(net_params)
self.fea_net.load_state_dict(state['state_dict'])
print(">>>> loaded network: ")
print(self.fea_net.meta_repr())
# setting up the multi-scale parameters
ms = list(eval(multiscale))
if torch.cuda.is_available():
self.fea_net.to(self.device)
# set up the transform
self.normalize = transforms.Normalize(
mean=self.fea_net.meta['mean'],
std=self.fea_net.meta['std']
)
self.transform = transforms.Compose([
transforms.ToTensor(),
self.normalize
])
print('initialize')
def run(self, base64_str, bounding_box):
model = self.model # 加载模型
img_b64decode = base64.b64decode(base64_str) # base64解码
img_array = np.frombuffer(img_b64decode, np.uint8) # 转换np序列
image = cv2.imdecode(img_array, cv2.COLOR_BGR2RGB)
if bounding_box is not None:
image = image[int(bounding_box[1]):int(bounding_box[5]), int(bounding_box[0]):int(bounding_box[4])]
im0 = image
# 数据预处理
# Padded resize
img = letterbox(image, new_shape=self.imgsz)[0]
# Convert
img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
img = np.ascontiguousarray(img)
# Run inference
img = torch.from_numpy(img).to(self.device)
img = img.half() if self.half else img.float() # uint8 to fp16/32
img /= 255.0 # 0 - 255 to 0.0 - 1.0
if img.ndimension() == 3:
img = img.unsqueeze(0)
# Inference 推理/侦测
pred = model(img, augment=False)[0]
# Apply NMS
pred = non_max_suppression(pred, 0.6, 0.5)
# Process detections
data = []
for i, det in enumerate(pred): # detections per image
# print(pred)
if det is not None and len(det):
# Rescale boxes from img_size to im0 size
det[:, :4] = scale_coords(img.shape[2:], det[:, :4], image.shape).round()
# Write results
count = 0
for *xyxy, conf, cls in reversed(det):
# print('xy', xyxy[0], xyxy[1], 'xy2', xyxy[2], xyxy[3])
count += 1
crop_img = im0[int(xyxy[1]):int(xyxy[3]), int(xyxy[0]):int(xyxy[2])]
score = conf
cv_image = Image.fromarray(cv2.cvtColor(crop_img, cv2.COLOR_BGR2RGB))
fea_net = self.fea_net.eval()
cv_image = self.transform(cv_image)
if torch.cuda.is_available():
cv_image = cv_image.to(self.device)
else:
cv_image = cv_image
feature = fea_net(cv_image.unsqueeze(0)).cpu().data.squeeze()
data.append({'features': feature.tolist(), 'score': score.tolist()})
return data
启动服务:
import json
import base64
import time
import numpy as np
import tornado.web
import tornado.ioloop
# 调用图搜接口类运行初始化函数。
from tusou_model_server import Server
class NpEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, np.int) or isinstance(obj, np.int64):
return int(obj)
elif isinstance(obj, np.float):
return float(obj)
elif isinstance(obj, np.ndarray):
return obj.tolist()
else:
return super(NpEncoder, self).default(obj)
class MainHandler(tornado.web.RequestHandler):
def prepare(self):
if self.request.body:
try:
print(len(self.request.body))
json.loads(self.request.body.decode("utf-8"), strict=False)
except ValueError:
message = "Unable to parse JSON"
self.send_error(400, message=message)
print('prepare')
self.response = dict()
def get(self, *args, **kwargs):
self.write("Not implement Get Function")
def set_default_headers(self):
print("setting headers!!!")
self.set_header("Access-Control-Allow-Origin", "*")
self.set_header("Access-Control-Allow-Headers", "*")
self.set_header('Access-Control-Allow-Methods', 'POST,OPTIONS')
# executor = ThreadPoolExecutor(8)
# @run_on_executor
def post(self, *args, **kwargs):
t2 = time.time()
params = self.request.body.decode('utf-8')
params = json.loads(params, strict=False)
base64_str = params["img_base64"] # 加载参数
try:
bounding_box = params["bounding_box"] # 加载参数
# bounding_box = list(map(int, bounding_box))
except:
bounding_box = None
# 传入参数,我这里包括可裁剪和不裁剪
data = server.run(base64_str,bounding_box)
output = {"stateCode":"0", "stateInfo":"成功", "data":data}
test = json.dumps(output, cls=NpEncoder, ensure_ascii=False)
self.write(test)
def options(self, *args, **kwargs):
self.finish()
def make_app():
return tornado.web.Application([
(r"/tusou_getfeature", MainHandler),
])
def main(port):
app = make_app()
app.listen(port)
tornado.ioloop.IOLoop.current().start()
from tornado import options
options.define("port", default=*, type=int, help="服务器监听端口号")
options.define("process_num", default=1, type=int, help="启动进程数")
if __name__ == '__main__':
server = Server()
options.parse_command_line()
port = options.options.port
print("start port at: %s",port)
main(port)
测试的时候:
import json
import requests
import base64
import time
if __name__ == '__main__':
detect_url = '.../tusou_getfeature' # tusou_getfeature
# 传入裁剪框坐标x1y1 x2y2 x3y3 x4y4 左上角开始顺时针
image_path = 'test_temp/test3.jpg'
# with open(image_path, 'rb') as f:
# image = f.read()
# image_base64 = str(base64.b64encode(image), encoding='utf-8')
# data_obj = {'img_base64': image_base64, 'bounding_box': []}
# 不需要裁剪
image_path = 'crop_test.jpg'
with open(image_path, 'rb') as f:
image = f.read()
image_base64 = str(base64.b64encode(image), encoding='utf-8')
data_obj = {'img_base64': image_base64}
# test
t0 = time.time()
r = requests.post(detect_url, json.dumps(data_obj))
t1 = time.time()
print('time',t1-t0)
print(r)
content = r.json()
print(len(content['data']))