基于ViT+milvus的以图搜图服务

以图搜图服务简介

服务流程介绍:图片经过ViT模型提取特征后,将特征向量连同对应的唯一id和身份标签一起保存到milvus库中,用于相似图片搜索;搜索时使用相似图片进行查询,返回搜索到的图片的身份标签和置信度。服务包括图片数据插入和图片相似搜索两部分。

ViT(Vision Transformer)模型

使用huggingface的ViT模型权重。

https://huggingface.co/tttarun/vision_transformer/tree/main    #权重文件地址

基于transformers中vit模型函数编写以图搜图推理函数,返回图片特征。

import torch
from transformers import ViTModel, ViTFeatureExtractor

class full_image_vit_feature_extract:
    """Extract one feature vector per image with a pretrained ViT model.

    The last-layer embedding of the [CLS] token (sequence position 0) is used
    as the feature of the whole image.
    """

    def __init__(self, device: int, model_path: str):
        """Load the ViT model and its feature extractor.

        Args:
            device: CUDA device index. When CUDA is not available the model
                is placed on the CPU instead of failing at load time.
            model_path: Location passed to ``from_pretrained`` for both the
                model and the feature extractor.
        """
        self.device = f"cuda:{device}" if torch.cuda.is_available() else "cpu"
        self.model = ViTModel.from_pretrained(model_path)
        self.feature_extractor = ViTFeatureExtractor.from_pretrained(model_path)
        self.model.to(self.device)
        # This class is used for inference only.
        self.model.eval()

    def inference(self, image_list):
        """Return the [CLS] embeddings of ``image_list`` as a NumPy array.

        Args:
            image_list: Images in a form the feature extractor accepts.
                NOTE(review): the service passes arrays decoded by
                ``cv2.imdecode`` (BGR channel order); confirm whether the
                extractor expects RGB.

        Returns:
            A 2-D ``numpy.ndarray`` with one row per input image.
        """
        inputs = self.feature_extractor(images=image_list, return_tensors="pt")
        # Move every input tensor to the device the model lives on.
        inputs = {name: tensor.to(self.device) for name, tensor in inputs.items()}
        with torch.no_grad():
            outputs = self.model(**inputs)
        # Position 0 of the sequence axis is the [CLS] token.
        return outputs.last_hidden_state[:, 0, :].cpu().numpy()

Milvus向量库

  • Milvus是一款云原生向量数据库,专为存储和检索高维向量数据设计。它具备高可用、高性能、易扩展的特点,特别适用于处理海量向量数据的实时召回。

  • Milvus广泛应用于各种AI领域,包括:
    1.图像检索:在图像识别和搜索中,Milvus可以帮助快速找到相似的图片。
    2.自然语言处理(NLP):用于文本相似度搜索和语义分析。
    3.推荐系统:通过分析用户行为和偏好,提供个性化的推荐。
    4.异常检测:在金融、网络安全等领域,检测异常行为或数据。

  • Milvus建库、数据入库和搜索流程

    1.定义MilvusRepository类。

from pymilvus import MilvusClient

class MilvusRepository(object):
    """Small wrapper around ``pymilvus.MilvusClient`` bound to one collection."""

    def __init__(self, uri=milvus_uri, token=milvus_token, db_name=milvus_dbname, collection_name=None):
        """Connect to Milvus.

        The default values ``milvus_uri``, ``milvus_token`` and
        ``milvus_dbname`` are module-level names; they must be defined before
        this class statement is executed.

        Args:
            uri: Milvus server address.
            token: Credential passed to ``MilvusClient``.
            db_name: Database to use.
            collection_name: Collection used by the other methods; may also be
                set later through ``set_collection``.

        Raises:
            Exception: whatever ``MilvusClient`` raises when connecting.
        """
        try:
            self.client = MilvusClient(uri=uri, token=token, db_name=db_name)
        except Exception:
            self.client = None
            # Bare ``raise`` keeps the original exception and traceback intact
            # (``raise ex from ex`` made the exception its own cause).
            raise
        self.collection_name = collection_name
        self.metric_type = "COSINE"  # L2
        self.index_type = "IVF_FLAT"

    def set_collection(self, collection_name, vector_dim, replica_number=1):
        """Select ``collection_name``, creating it when missing, then load it.

        NOTE(review): ``exist_collection``, ``create_collection`` and
        ``load_collection`` are not defined in the visible part of this class;
        confirm they exist before relying on this method.
        """
        self.collection_name = collection_name
        if not self.exist_collection():
            self.create_collection(vector_dim)
        self.load_collection(replica_number)

    def insert_data(self, data):
        """Insert ``data`` into the current collection.

        Returns:
            The value returned by ``MilvusClient.insert`` so that callers can
            inspect the outcome (it was previously discarded).
        """
        return self.client.insert(
            collection_name=self.collection_name,
            data=data
        )

    def search(self, data, filter="", limit=5, output_fields=None, search_params=None,
               timeout=None, partition_names=None, anns_field=None, **kwargs):
        """Run a vector search on the current collection.

        Args:
            data: Query vectors.
            filter: Boolean expression forwarded as ``expression``.
            limit: Maximum number of hits per query vector.
            output_fields: Entity fields to return with each hit.
            search_params: Search parameters; ``{}`` when omitted.
            timeout: Forwarded to the underlying search call.
            partition_names: Forwarded to the underlying search call.
            anns_field: Vector field to search; ``""`` when omitted.
            **kwargs: Forwarded unchanged to the underlying search call.

        Returns:
            ``{"result": [[hit_dict, ...], ...], "cost": res.cost}`` with one
            inner list per query vector.
        """
        # NOTE(review): ``_get_connection`` is a private MilvusClient method;
        # confirm it exists in the installed pymilvus version.
        conn = self.client._get_connection()
        res = conn.search(
            self.collection_name,
            data,
            anns_field or "",
            search_params or {},
            expression=filter,
            limit=limit,
            output_fields=output_fields,
            partition_names=partition_names,
            timeout=timeout,
            **kwargs,
        )
        return dict(
            result=[[hit.to_dict() for hit in hits] for hits in res],
            cost=res.cost,
        )

    def flush(self):
        """Flush the current collection through the underlying connection."""
        conn = self.client._get_connection()
        conn.flush([self.collection_name])

    @staticmethod
    def uuid(message):
        """Return the MD5 hex digest of ``str(message)``.

        Despite the name this is a deterministic content hash, not a random
        UUID: equal string representations give equal ids.
        """
        import hashlib  # local import: hashlib is not imported at file level

        return hashlib.md5(str(message).encode('utf-8')).hexdigest()

2.初始化milvus

milvus_client = MilvusRepository(uri="xxx",token="xxx",db_name="xxx")  # server URI, account/password token, database name
vector_dim = 768  # dimension of the vectors to insert
milvus_client.set_collection('xxx', vector_dim)   # collection_name

3.数据插入,使用numpy随机生成向量模拟数据。

import numpy as np
# Build ten demo rows from random vectors and write them to the collection.
data = []
for _ in range(10):
    vector = np.random.random(vector_dim)
    data.append(
        {
            # The id is derived from the very vector that is stored, so the
            # two always correspond to each other.
            'unique_id': milvus_client.uuid(vector),
            'vector': vector,
            'maker_type': '123abc',
        }
    )
milvus_client.insert_data(data)
milvus_client.flush()

入库数据展示
在这里插入图片描述

4.数据搜索

# Query with one random vector and ask for the stored id, label and vector of
# every hit (field names match the rows written by the insert example).
res = milvus_client.search(
    [np.random.random(vector_dim)],
    output_fields=['unique_id', 'maker_type', 'vector'],
)
print(res)

以图搜图服务(部分重要代码)

服务参数配置:

1.vit模型初始化

# Directory that contains this script; the ViT weights are expected in its
# "weight/vit-model" sub-directory.
model_path = os.path.dirname(os.path.realpath(__file__))
extract = full_image_vit_feature_extract(
    device=0,
    model_path=f"{model_path}/weight/vit-model",
)

def detect_and_extract(imgs):
    """Extract features for ``imgs`` and pair them with their positions.

    Returns:
        A tuple ``(idx_map, features)``: ``idx_map`` is
        ``[0, 1, ..., len(imgs) - 1]`` and ``features`` is the value returned
        by ``extract.inference(imgs)``.
    """
    feats = extract.inference(imgs)
    positions = list(range(len(imgs)))
    return positions, feats

def full_retrieval(imgs):
    """Search the Milvus collection with the feature of the first image.

    Args:
        imgs: Images passed to ``extract.inference``; only the feature of the
            first one is used as the query vector.

    Returns:
        The result of ``milvus_client.search`` (requesting ``unique_id`` and
        ``maker_type``), or ``None`` when no feature was produced.
    """
    features = extract.inference(imgs)
    # An empty batch has no features[0]; report "no result" instead of
    # failing with an IndexError.
    if features is None or len(features) == 0:
        return None
    return milvus_client.search([features[0]], output_fields=['unique_id', 'maker_type'])

2.milvus初始化

# All four connection settings are mandatory string options.
parser = argparse.ArgumentParser()
for _flag, _help in (
    ('--milvus_uri', 'milvus uri'),
    ('--milvus_token', 'milvus token'),
    ('--collection_name', 'collection name'),
    ('--milvus_dbname', 'milvus_db_name'),
):
    parser.add_argument(_flag, type=str, required=True, help=_help)
opt = parser.parse_args()

# Connect with the parsed settings, then select (and load) the collection.
milvus_client = MilvusRepository(
    uri=opt.milvus_uri,
    token=opt.milvus_token,
    db_name=opt.milvus_dbname,
)
vector_dim = 768
milvus_client.set_collection(opt.collection_name, vector_dim)

以图搜图数据入库服务
向milvus库中插入图片及图片的标识。服务接收两个参数,一个是已经转换成base64格式的图片数据,一个是图片的标志type。

@app.route('/img/insert/', methods=['post'])
......
        request_info = request.get_json()
        images = request_info.get('image', [])
        total_data = len(images)
        result["total_data"] = total_data
        types = request_info.get('type', [])
        for img ,tp in zip(images, types):
            img = base64.b64decode(img)
            frame = cv2.imdecode(np.fromstring(img, np.uint8), cv2.IMREAD_COLOR)
            frames.append(frame)
            idxs, img_features  = detect_and_extract([frame])
            if img_features is not None:
                features = extract.inference([frame])
                if len(features) == 0:
                    total_data -= 1
                else:
                    if features is not None: 
                        data.append(
                            {'unique_id': milvus_client.uuid(features[0]), 'vector': features[0], 'maker_type': tp}
                        )
                    else:
                        total_data -= 1
        if data:
            result["insert_data"] = total_data
            milvus_client.insert_data(data)
            milvus_client.flush()

以图搜图搜索服务
传入一张图片,并在milvus库中搜索相似图片。返回相似图片的置信度conf和图片标志type,默认返回相似度最高的前五张图片。

@app.route('/img/full_image/', methods=['post'])
........
for i, img in enumerate(images):
                similar_img_id = []
                similar_img_conf = []
                similar_img_maker_type = []
                img = base64.b64decode(img)
                frame = cv2.imdecode(np.fromstring(img, np.uint8), cv2.IMREAD_COLOR)
            
                idxs, features = detect_and_extract([frame])
                if features is not None:
                    idex_result = full_retrieval([frame])
                    for i_ in idxs:
                        tmp = []
                        for idex_all in idex_result["result"]:  # idex_all is list (first_image_result, ...)
                            tmp = []
                            for idex_s in idex_all:  # idex_s is dict : like this input_images(5)
                                distance = idex_s["distance"]
                                if distance > 0.5:
                                    similar_img_id.append(idex_s["entity"]["unique_id"])
                                    similar_img_conf.append(distance)
                                    maker_type = idex_s["entity"].get("maker_type", "unknown")
                                    similar_img_maker_type.append(maker_type)
                        
                result.append(
                    {
                     "similar_img_conf": similar_img_conf,
                     "similar_img_maker_type": similar_img_maker_type 
                    }
                )
  return json.dumps(dict(status=0, message="success", result=result)

  # Script entry point: serve the app on all network interfaces (0.0.0.0),
  # port 10888, with debug mode switched off.
  if __name__ == '__main__':
    app.run(host='0.0.0.0', port=10888, debug=False)

服务测试代码(部分重要代码)

1.以图搜图数据插入测试代码

def https_image_insert_test():
    """Post one base64-encoded image with its label to the insert endpoint.

    Prints the round-trip time of the request and the ``result`` field of the
    JSON response.
    """
    url = 'http://xx.xx.x.x:10888/img/insert/'
    img_path = 'xxxx'

    # The endpoint expects parallel lists: images and their labels.
    with open(img_path, 'rb') as f:
        encoded = base64.b64encode(f.read()).decode('utf-8')
    data = {"image": [encoded], "type": ["xxx"]}

    start_time = time.time()
    response = requests.post(url, json=data)  # send the POST request
    cost_time = time.time() - start_time
    insert_rs = response.json()["result"]

    print("===" * 10)
    print(f"insert 1 pic ====>>>{cost_time}s")
    print(f"insert result ====>>> {insert_rs}")
 

2.以图搜图搜索测试代码

def https_full_image_test():
    """Post base64-encoded images to the search endpoint and print the reply.

    Prints the HTTP status code and the decoded JSON body of the response.
    """
    url = 'http://xx.xx.xx.x:10888/img/full_image/'
    # A list of file paths: iterating a bare string would yield its single
    # characters instead of paths.
    images = ['xx']
    imgs = []

    for image in images:
        with open(image, 'rb') as f:
            imgs.append(base64.b64encode(f.read()).decode('utf-8'))
    data = {"image": imgs}

    response = requests.post(url, json=data)  # send the POST request
    print('POST Response:')
    print(response.status_code)  # HTTP status code
    print(response.json())  # decoded JSON body

你可能感兴趣的:(分类算法,pytorch,milvus)