服务流程介绍:将图片经过ViT模型提取特征,把特征向量连同对应的唯一id和身份标签一起保存到Milvus库中,用于相似图片搜索;使用相似图片进行搜索时,返回搜索到的图片的身份标签和置信度。服务包括图片数据插入和图片相似搜索两部分。
使用huggingface的ViT模型权重。
https://huggingface.co/tttarun/vision_transformer/tree/main #权重文件地址
基于transformers中vit模型函数编写以图搜图推理函数,返回图片特征。
from transformers import ViTModel, ViTFeatureExtractor
class full_image_vit_feature_extract:
    """Extract whole-image embeddings with a Hugging Face ViT model.

    The embedding of an image is the final hidden state of the [CLS] token.
    """

    def __init__(self, device: int, model_path: str):
        """Load the ViT weights and the matching preprocessor.

        Args:
            device: CUDA device index. When CUDA is not available the model
                is placed on the CPU instead of failing when it is moved.
            model_path: Location passed to ``from_pretrained`` for both the
                model and the feature extractor.
        """
        # Fall back to CPU so the service can still start on GPU-less hosts.
        self.device = f"cuda:{device}" if torch.cuda.is_available() else "cpu"
        self.model = ViTModel.from_pretrained(model_path)
        self.feature_extractor = ViTFeatureExtractor.from_pretrained(model_path)
        self.model.to(self.device)

    def inference(self, image_list):
        """Return the [CLS] embeddings of ``image_list`` as a NumPy array.

        Args:
            image_list: Images in a form accepted by the feature extractor.

        Returns:
            A 2-D NumPy array; row ``i`` is the feature of ``image_list[i]``.
        """
        inputs = self.feature_extractor(images=image_list, return_tensors="pt")
        # Move every preprocessed tensor to the same device as the model.
        inputs = {name: tensor.to(self.device) for name, tensor in inputs.items()}
        with torch.no_grad():
            outputs = self.model(**inputs)
        # Position 0 of the sequence axis is the [CLS] token.
        cls_embeddings = outputs.last_hidden_state[:, 0, :]
        return cls_embeddings.cpu().numpy()
Milvus是一款云原生向量数据库,专为存储和检索高维向量数据设计。它具备高可用、高性能、易拓展的特点,特别适用于处理海量向量数据的实时召回。
Milvus广泛应用于各种AI领域,包括:
1.图像检索:在图像识别和搜索中,Milvus可以帮助快速找到相似的图片。
2.自然语言处理(NLP):用于文本相似度搜索和语义分析。
3.推荐系统:通过分析用户行为和偏好,提供个性化的推荐。
4.异常检测:在金融、网络安全等领域,检测异常行为或数据
Milvus建库、数据入库和搜索流程
1.定义MilvusRepository类。
from pymilvus import MilvusClient
class MilvusRepository(object):
    """Thin wrapper around ``MilvusClient`` bound to one collection.

    NOTE(review): ``set_collection`` calls ``exist_collection``,
    ``create_collection`` and ``load_collection``, which are not defined in
    this snippet; they must be provided elsewhere for the class to work.
    """

    def __init__(self, uri=milvus_uri, token=milvus_token, db_name=milvus_dbname, collection_name=None):
        """Connect to Milvus.

        Args:
            uri: Address of the Milvus server.
            token: Credential passed to ``MilvusClient``.
            db_name: Database to use.
            collection_name: Optional collection to bind immediately.

        Raises:
            Exception: Whatever ``MilvusClient`` raises; ``self.client`` is
                set to ``None`` before the error propagates.
        """
        try:
            self.client = MilvusClient(uri=uri, token=token, db_name=db_name)
        except Exception:
            self.client = None
            # Bare raise keeps the original traceback without making the
            # exception its own __cause__.
            raise
        self.collection_name = collection_name
        self.metric_type = "COSINE"  # the original author notes "L2" as the alternative
        self.index_type = "IVF_FLAT"

    def set_collection(self, collection_name, vector_dim, replica_number=1):
        """Bind to ``collection_name``, creating it first when it is missing.

        Args:
            collection_name: Collection to operate on from now on.
            vector_dim: Vector dimension passed to ``create_collection``.
            replica_number: Value passed to ``load_collection``.
        """
        self.collection_name = collection_name
        if not self.exist_collection():
            self.create_collection(vector_dim)
        self.load_collection(replica_number)

    def insert_data(self, data):
        """Insert ``data`` into the bound collection.

        Args:
            data: Rows forwarded unchanged to ``MilvusClient.insert``.

        Returns:
            The value returned by ``MilvusClient.insert``.
        """
        return self.client.insert(
            collection_name=self.collection_name,
            data=data,
        )

    def search(self, data, filter="", limit=5, output_fields=None, search_params=None,
               timeout=None, partition_names=None, anns_field=None, **kwargs):
        """Run a vector search on the bound collection.

        Args:
            data: Query vectors.
            filter: Expression forwarded as ``expression``.
            limit: Maximum number of hits per query vector.
            output_fields: Fields to return with each hit.
            search_params: Search parameters; ``{}`` when omitted.
            timeout: Forwarded to the underlying search call.
            partition_names: Forwarded to the underlying search call.
            anns_field: Vector field to search; ``""`` when omitted.
            **kwargs: Extra arguments forwarded to the underlying search call.

        Returns:
            ``{"result": [[hit_dict, ...], ...], "cost": res.cost}`` with one
            inner list per query vector.
        """
        # NOTE(review): _get_connection is a private MilvusClient API.
        conn = self.client._get_connection()
        res = conn.search(
            self.collection_name,
            data,
            anns_field or "",
            search_params or {},
            expression=filter,
            limit=limit,
            output_fields=output_fields,
            partition_names=partition_names,
            timeout=timeout,
            **kwargs,
        )
        ret = dict(result=[], cost=0)
        for hits in res:
            ret['result'].append([hit.to_dict() for hit in hits])
        ret['cost'] = res.cost
        return ret

    def flush(self):
        """Flush the bound collection through the underlying connection."""
        conn = self.client._get_connection()
        conn.flush([self.collection_name])

    @staticmethod
    def uuid(message):
        """Return the hex MD5 digest of ``str(message)`` encoded as UTF-8.

        NOTE(review): for NumPy arrays the id depends on the array's printed
        form, so it is sensitive to NumPy print options.
        """
        md5 = hashlib.md5()
        md5.update(str(message).encode('utf-8'))
        return md5.hexdigest()
2.初始化milvus
# Connection settings: server address, account/password token, database name.
milvus_client = MilvusRepository(
    uri="xxx",
    token="xxx",
    db_name="xxx",
)
# Dimensionality of the vectors that will be inserted.
vector_dim = 768
# First argument is the collection name.
milvus_client.set_collection('xxx', vector_dim)
3.数据插入,使用numpy随机生成向量模拟数据。
import numpy as np
# Build ten random rows to simulate image features.
data = []
for _ in range(10):
    vector = np.random.random(vector_dim)
    # Store the same vector that the id was derived from; the original drew a
    # second random vector here, so id and stored vector did not match.
    data.append(
        {'unique_id': milvus_client.uuid(vector), 'vector': vector, 'maker_type': '123abc'}
    )
milvus_client.insert_data(data)
milvus_client.flush()
4.数据搜索
# Query with a random vector. The id field is named 'unique_id' in the
# inserted rows, so that is the name to request (not 'uuid').
res = milvus_client.search(
    [np.random.random(vector_dim)],
    output_fields=['unique_id', 'maker_type', 'vector'],
)
print(res)
服务参数配置:
1.vit模型初始化
# Directory that contains this file; the weights live under it.
model_path = os.path.dirname(os.path.realpath(__file__))
# Shared extractor instance, created with device index 0.
extract = full_image_vit_feature_extract(0, f"{model_path}/weight/vit-model")
def detect_and_extract(imgs):
    """Extract features for ``imgs`` and pair them with their positions.

    Returns:
        ``(idx_map, features)`` where ``idx_map`` is ``[0, 1, ..., len(imgs) - 1]``
        and ``features`` is the result of ``extract.inference(imgs)``.
    """
    features = extract.inference(imgs)
    idx_map = list(range(len(imgs)))
    return idx_map, features
def full_retrieval(imgs):
    """Search Milvus with the feature of the first image in ``imgs``.

    Returns:
        The dict produced by ``milvus_client.search``, or ``None`` when
        feature extraction returned ``None``.
    """
    features = extract.inference(imgs)
    if features is None:
        return None
    # Only the first image's feature is used as the query vector.
    return milvus_client.search([features[0]], output_fields=['unique_id', 'maker_type'])
2.milvus初始化
# All four options are mandatory strings; register them from one table.
parser = argparse.ArgumentParser()
for _flag, _help in (
    ('--milvus_uri', 'milvus uri'),
    ('--milvus_token', 'milvus token'),
    ('--collection_name', 'collection name'),
    ('--milvus_dbname', 'milvus_db_name'),
):
    parser.add_argument(_flag, type=str, required=True, help=_help)
opt = parser.parse_args()

# Connect, then bind to the requested collection.
milvus_client = MilvusRepository(
    uri=opt.milvus_uri,
    token=opt.milvus_token,
    db_name=opt.milvus_dbname,
)
vector_dim = 768
milvus_client.set_collection(opt.collection_name, vector_dim)
以图搜图数据入库服务
向milvus库中插入图片及图片的标识。服务接收两个参数:一个是已经转换成base64格式的图片数据(image),一个是图片的标志(type)。
@app.route('/img/insert/', methods=['post'])
......
# Body of the insert handler: decode each image, extract its feature and
# queue one row per image for insertion.
# NOTE(review): `result`, `frames` and `data` are expected to be initialised
# in the elided part above.
request_info = request.get_json()
images = request_info.get('image', [])
types = request_info.get('type', [])
total_data = len(images)
result["total_data"] = total_data

inserted = 0
# zip() stops at the shorter list, so an image without a matching type is not
# inserted; successes are therefore counted directly instead of subtracting
# failures from len(images).
for img, tp in zip(images, types):
    try:
        raw = base64.b64decode(img)
    except (ValueError, TypeError):
        # Malformed base64 payload: count it as a failed item.
        continue
    if not raw:
        continue
    # np.frombuffer replaces the deprecated binary mode of np.fromstring.
    frame = cv2.imdecode(np.frombuffer(raw, np.uint8), cv2.IMREAD_COLOR)
    if frame is None:
        # The bytes are not a decodable image.
        continue
    frames.append(frame)
    # One forward pass per image; the original extracted the feature twice.
    _, features = detect_and_extract([frame])
    if features is None or len(features) == 0:
        continue
    data.append(
        {'unique_id': milvus_client.uuid(features[0]), 'vector': features[0], 'maker_type': tp}
    )
    inserted += 1

# Keep the original variable's final meaning: number of rows queued.
total_data = inserted
if data:
    result["insert_data"] = total_data
    milvus_client.insert_data(data)
    milvus_client.flush()
以图搜图搜索服务
传入一张图片,并在milvus库中搜索相似图片。返回相似图片的置信度conf和图片标志type;默认检索相似度最高的前五张图片,并只保留置信度大于0.5的结果。
@app.route('/img/full_image/', methods=['post'])
........
for i, img in enumerate(images):
similar_img_id = []
similar_img_conf = []
similar_img_maker_type = []
img = base64.b64decode(img)
frame = cv2.imdecode(np.fromstring(img, np.uint8), cv2.IMREAD_COLOR)
idxs, features = detect_and_extract([frame])
if features is not None:
idex_result = full_retrieval([frame])
for i_ in idxs:
tmp = []
for idex_all in idex_result["result"]: # idex_all is list (first_image_result, ...)
tmp = []
for idex_s in idex_all: # idex_s is dict : like this input_images(5)
distance = idex_s["distance"]
if distance > 0.5:
similar_img_id.append(idex_s["entity"]["unique_id"])
similar_img_conf.append(distance)
maker_type = idex_s["entity"].get("maker_type", "unknown")
similar_img_maker_type.append(maker_type)
result.append(
{
"similar_img_conf": similar_img_conf,
"similar_img_maker_type": similar_img_maker_type
}
)
return json.dumps(dict(status=0, message="success", result=result)
if __name__ == '__main__':
    # Serve on every network interface, port 10888, with debug mode off.
    app.run(
        host='0.0.0.0',
        port=10888,
        debug=False,
    )
1.以图搜图数据插入测试代码
def https_image_insert_test():
    """Post one base64-encoded image to the insert endpoint and print timing.

    Reads the file at ``img_path``, sends it with a single type label, and
    prints the request duration and the ``result`` field of the JSON reply.
    """
    url = 'http://xx.xx.x.x:10888/img/insert/'
    img_path = 'xxxx'

    with open(img_path, 'rb') as fh:
        encoded = base64.b64encode(fh.read()).decode('utf-8')
    payload = {"image": [encoded], "type": ["xxx"]}

    # Time only the HTTP round trip.
    started = time.time()
    response = requests.post(url, json=payload)  # send the POST request
    cost_time = time.time() - started

    insert_rs = response.json()["result"]
    print("===" * 10)
    print(f"insert 1 pic ====>>>{cost_time}s")
    print(f"insert result ====>>> {insert_rs}")
2.以图搜图搜索测试代码
def https_full_image_test():
    """Post base64-encoded images to the search endpoint and print the reply.

    Prints the HTTP status code followed by the decoded JSON body.
    """
    url = 'http://xx.xx.xx.x:10888/img/full_image/'
    # Must be a list of paths: iterating a bare string would yield single
    # characters and try to open each one as a file.
    images = ['xx']
    imgs = []
    for image in images:
        with open(image, 'rb') as f:
            imgs.append(base64.b64encode(f.read()).decode('utf-8'))
    data = {"image": imgs}
    response = requests.post(url, json=data)  # send the POST request
    print('POST Response:')
    print(response.status_code)  # HTTP status code
    print(response.json())  # JSON body returned by the service