# While building RAG projects I repeatedly load local embedding and rerank models;
# every demo run spent ~10 seconds on model loading, so this module wraps the
# models behind an HTTP API that keeps them resident for direct calls.
import os
import numpy as np
import logging
import uvicorn
import datetime
from fastapi import FastAPI, Security, HTTPException
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from pydantic import Field, BaseModel, validator
from typing import Optional, List
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# FastAPI application and bearer-token auth scheme shared by all endpoints.
app = FastAPI()
security = HTTPBearer()
# Default token; overridden from the ACCESS_TOKEN env var in the __main__ block.
env_bearer_token = 'ACCESS_TOKEN'
def get_device() -> torch.device:
    """Return the best available torch device: CUDA > Apple MPS > CPU."""
    if torch.cuda.is_available():
        return torch.device('cuda')
    # Older/CPU-only torch builds may not expose the MPS backend at all,
    # so probe for it instead of assuming torch.backends.mps exists.
    mps = getattr(torch.backends, 'mps', None)
    if mps is not None and mps.is_available():
        return torch.device('mps')
    return torch.device('cpu')
class QADocs(BaseModel):
    """Request body for /v1/rerank: a query plus candidate documents to score."""
    # Explicit None defaults: pydantic v1 implied them for Optional fields,
    # pydantic v2 does not, so spelling them out keeps both versions working.
    query: Optional[str] = None
    documents: Optional[List[str]] = None
class Query(BaseModel):
    """Request body for /v1/embed_query: a single query string."""
    # Explicit None default for pydantic v1/v2 compatibility.
    query: Optional[str] = None
class Documents(BaseModel):
    """Request body for /v1/embed_docs: a batch of document strings."""
    # Explicit None default for pydantic v1/v2 compatibility.
    documents: Optional[List[str]] = None
class Singleton(type):
    """Metaclass caching one instance per class (models are slow to load).

    The cache lookup uses the class's own __dict__ rather than hasattr():
    hasattr() also finds attributes inherited from a parent class, which
    would make a subclass silently reuse its parent's singleton instead of
    constructing its own instance.
    """
    def __call__(cls, *args, **kwargs):
        if '_instance' not in cls.__dict__:
            cls._instance = super().__call__(*args, **kwargs)
        return cls._instance
class ReRanker(metaclass=Singleton):
    """Cross-encoder reranker scoring (query, document) pairs.

    Loaded once per process (Singleton metaclass) because model loading
    takes seconds. The model is moved to the best available device,
    matching what Embeddings does.
    """
    def __init__(self, model_path):
        self.device = get_device()
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
        self.model.to(self.device)
        self.model.eval()

    def compute_score(self, pairs: List[tuple[str, str]]) -> List[float]:
        """Return one relevance score per (query, document) pair.

        An empty input yields an empty list (instead of None) so callers
        can iterate the result unconditionally.
        """
        if not pairs:
            return []
        with torch.no_grad():
            inputs = self.tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=512)
            # Inputs must live on the same device as the model.
            inputs = {name: tensor.to(self.device) for name, tensor in inputs.items()}
            scores = self.model(**inputs, return_dict=True).logits.view(-1).float().tolist()
        # Debug-level logging instead of an unconditional print on every request.
        logging.debug("rerank scores: %s", scores)
        return scores
class Embeddings(metaclass=Singleton):
    """BGE embedding model wrapper, loaded once via the Singleton metaclass."""

    def __init__(self, model_path):
        # normalize_embeddings=True so dot products equal cosine similarity.
        self.embeddings = HuggingFaceBgeEmbeddings(
            model_name=model_path,
            model_kwargs={'device': get_device()},
            encode_kwargs={'normalize_embeddings': True},
        )

    def encode(self, texts: List[str]):
        """Embed a batch of document strings."""
        return self.embeddings.embed_documents(texts)

    def encode_query(self, text: str):
        """Embed a single query string."""
        return self.embeddings.embed_query(text)
class Chat(object):
    """Facade bundling the reranker and embedding singletons.

    Cheap to construct repeatedly: both members are singletons, so the
    underlying models load only once per process.
    """
    def __init__(self, rerank_model_path: str = "BAAI/bge-reranker-v2-m3", embedding_model_path: str = "BAAI/bge-base-en"):
        self.reranker = ReRanker(rerank_model_path)
        self.embeddings = Embeddings(embedding_model_path)

    def fit_query_answer_rerank(self, query_docs: QADocs) -> List:
        """Score each document against the query.

        Returns [{"index": ..., "relevance_score": ...}] sorted by
        descending score. Returns [] when the request is missing or has
        no documents — the pydantic fields are Optional, so `documents`
        may be None, not just empty.
        """
        if query_docs is None or not query_docs.documents:
            return []
        pairs = [(query_docs.query, doc) for doc in query_docs.documents]
        scores = self.reranker.compute_score(pairs)
        results = [{"index": i, "relevance_score": score} for i, score in enumerate(scores)]
        # Stable sort keeps the original order among equal scores.
        results.sort(key=lambda item: item["relevance_score"], reverse=True)
        return results

    def embed_query(self, query: str):
        """Embed a single query string."""
        return self.embeddings.encode_query(query)

    def embed_docs(self, docs: List[str]):
        """Embed a batch of document strings."""
        return self.embeddings.encode(docs)
@app.post('/v1/embed_query')
async def handle_embed_query(query: Query, credentials: HTTPAuthorizationCredentials = Security(security)):
    """Embed a single query string. Requires a valid bearer token (401 otherwise)."""
    token = credentials.credentials
    if env_bearer_token is not None and token != env_bearer_token:
        raise HTTPException(status_code=401, detail="Invalid token")
    chat = Chat()  # cheap after the first request: models are singletons
    try:
        query_embedding = chat.embed_query(query.query)
        return {"result": query_embedding}
    except Exception:
        # Log with traceback instead of print; keep the original 200 + error
        # payload so existing clients do not break.
        logging.exception("embed_query failed")
        return {"error": "embed出错"}
@app.post('/v1/embed_docs')
async def handle_embed_docs(docs: Documents, credentials: HTTPAuthorizationCredentials = Security(security)):
    """Embed a batch of documents. Requires a valid bearer token (401 otherwise)."""
    token = credentials.credentials
    if env_bearer_token is not None and token != env_bearer_token:
        raise HTTPException(status_code=401, detail="Invalid token")
    chat = Chat()  # cheap after the first request: models are singletons
    try:
        result = chat.embed_docs(docs.documents)
        return {"result": result}
    except Exception:
        # Log with traceback instead of print; keep the original 200 + error
        # payload so existing clients do not break.
        logging.exception("embed_docs failed")
        return {"error": "embed出错"}
@app.post('/v1/rerank')
async def handle_post_request(docs: QADocs, credentials: HTTPAuthorizationCredentials = Security(security)):
    """Rerank documents against a query. Requires a valid bearer token (401 otherwise)."""
    token = credentials.credentials
    if env_bearer_token is not None and token != env_bearer_token:
        raise HTTPException(status_code=401, detail="Invalid token")
    chat = Chat()  # cheap after the first request: models are singletons
    try:
        results = chat.fit_query_answer_rerank(docs)
        return {"results": results}
    except Exception:
        # Log with traceback instead of print; keep the original 200 + error
        # payload so existing clients do not break.
        logging.exception("rerank failed")
        return {"error": "重排出错"}
if __name__ == "__main__":
token = os.getenv("ACCESS_TOKEN")
if token is not None:
env_bearer_token = token
try:
uvicorn.run(app, host='0.0.0.0', port=5000)
except Exception as e:
print(f"API启动失败!\n报错:\n{e}")
# Based on the FastGPT project:
# https://github.com/labring/FastGPT/blob/main/python/bge-rerank/bge-reranker-v2-m3/app.py