使用FastAPI部署bge-base和bge-reranker

最近在做 RAG 项目,会频繁用到本地 embedding 模型和 rerank 模型,但每次跑 demo 都要花 10 秒左右加载模型,非常慢,所以封装成了 HTTP 接口,便于直接调用。

import os
import numpy as np
import logging
import uvicorn
import datetime
from fastapi import FastAPI, Security, HTTPException
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from pydantic import Field, BaseModel, validator
from typing import Optional, List
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer


app = FastAPI()
security = HTTPBearer()
# Default bearer token; overridden by the ACCESS_TOKEN env var in __main__.
env_bearer_token = 'ACCESS_TOKEN'


def get_device():
    """Pick the best available torch device: CUDA > Apple MPS > CPU.

    Returns:
        torch.device: the selected device.
    """
    if torch.cuda.is_available():
        return torch.device('cuda')
    # Older torch builds (and some non-macOS wheels) lack the mps backend
    # attribute entirely, so probe defensively before asking if it is available.
    mps_backend = getattr(torch.backends, 'mps', None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device('mps')
    return torch.device('cpu')


class QADocs(BaseModel):
    """Request body for /v1/rerank: a query plus candidate documents."""

    # Explicit None defaults so the model also works under pydantic v2,
    # where Optional fields are no longer implicitly optional.
    query: Optional[str] = None
    documents: Optional[List[str]] = None


class Query(BaseModel):
    """Request body for /v1/embed_query: a single query string."""

    # Explicit None default for pydantic v2 compatibility.
    query: Optional[str] = None


class Documents(BaseModel):
    """Request body for /v1/embed_docs: a batch of document strings."""

    # Explicit None default for pydantic v2 compatibility.
    documents: Optional[List[str]] = None


class Singleton(type):
    """Metaclass caching exactly one instance per class.

    The cache is checked via ``cls.__dict__`` rather than ``hasattr`` so a
    subclass does not find (and wrongly reuse) an instance already created
    for its parent class. Constructor args of the second and later calls
    are ignored, as is conventional for this pattern.
    """

    def __call__(cls, *args, **kwargs):
        if '_instance' not in cls.__dict__:
            cls._instance = super().__call__(*args, **kwargs)
        return cls._instance


class ReRanker(metaclass=Singleton):
    """Cross-encoder reranker (e.g. bge-reranker) scoring (query, doc) pairs.

    Loaded once per process via the Singleton metaclass because model
    initialization takes several seconds.
    """

    def __init__(self, model_path):
        # Place the model on the best available device, consistent with
        # how Embeddings uses get_device().
        self.device = get_device()
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
        self.model.to(self.device)
        self.model.eval()

    def compute_score(self, pairs: List[tuple[str, str]]):
        """Return one relevance score per (query, document) pair.

        Args:
            pairs: list of (query, document) string tuples.

        Returns:
            List of float scores (higher = more relevant), or ``None`` when
            ``pairs`` is empty — callers guard against the empty case.
        """
        if not pairs:
            return None
        with torch.no_grad():
            inputs = self.tokenizer(pairs, padding=True, truncation=True,
                                    return_tensors='pt', max_length=512)
            inputs = {name: tensor.to(self.device) for name, tensor in inputs.items()}
            logits = self.model(**inputs, return_dict=True).logits
            # No debug printing here: scores are returned to the endpoint.
            return logits.view(-1).float().cpu().tolist()


class Embeddings(metaclass=Singleton):
    """Wraps a HuggingFace BGE embedding model, loaded once per process."""

    def __init__(self, model_path):
        # normalize_embeddings=True so dot products equal cosine similarity.
        self.embeddings = HuggingFaceBgeEmbeddings(
            model_name=model_path,
            model_kwargs={'device': get_device()},
            encode_kwargs={'normalize_embeddings': True},
        )

    def encode(self, texts: List[str]):
        """Embed a batch of document strings."""
        return self.embeddings.embed_documents(texts)

    def encode_query(self, text: str):
        """Embed a single query string."""
        return self.embeddings.embed_query(text)


class Chat(object):
    """Facade bundling the reranker and embedding model singletons."""

    def __init__(self, rerank_model_path: str = "BAAI/bge-reranker-v2-m3", embedding_model_path: str = "BAAI/bge-base-en"):
        self.reranker = ReRanker(rerank_model_path)
        self.embeddings = Embeddings(embedding_model_path)

    def fit_query_answer_rerank(self, query_docs: "QADocs") -> List:
        """Score every document against the query and return them ranked.

        Args:
            query_docs: object exposing ``query`` (str) and ``documents``
                (list of str) attributes.

        Returns:
            ``[{"index": int, "relevance_score": float}, ...]`` sorted by
            descending score; ``index`` is the position in the original
            ``documents`` list. Empty list for missing/empty input.
        """
        # Guard None as well as [] — both fields are Optional on the model,
        # so documents may legitimately be absent.
        if query_docs is None or not query_docs.documents:
            return []
        pairs = [(query_docs.query, doc) for doc in query_docs.documents]
        scores = self.reranker.compute_score(pairs)
        # Stable sort keeps the original order for tied scores, matching
        # the previous dict-building implementation.
        ranked = sorted(enumerate(scores), key=lambda item: item[1], reverse=True)
        return [{"index": idx, "relevance_score": score} for idx, score in ranked]

    def embed_query(self, query: str):
        """Embed a single query string."""
        return self.embeddings.encode_query(query)

    def embed_docs(self, docs: List[str]):
        """Embed a batch of document strings."""
        return self.embeddings.encode(docs)


@app.post('/v1/embed_query')
async def handle_embed_query(query: Query, credentials: HTTPAuthorizationCredentials = Security(security)):
    """Embed a single query string; requires a valid bearer token."""
    token = credentials.credentials
    if env_bearer_token is not None and token != env_bearer_token:
        raise HTTPException(status_code=401, detail="Invalid token")
    # Singleton-backed, so this is cheap after the first (model-loading) call.
    chat = Chat()
    try:
        query_embedding = chat.embed_query(query.query)
        return {"result": query_embedding}
    except Exception:
        # Log the full traceback instead of printing to stdout; keep the
        # legacy 200-with-error-body response shape for existing clients.
        logging.exception("embed_query failed")
        return {"error": "embed出错"}


@app.post('/v1/embed_docs')
async def handle_embed_docs(docs: Documents, credentials: HTTPAuthorizationCredentials = Security(security)):
    """Embed a batch of documents; requires a valid bearer token."""
    token = credentials.credentials
    if env_bearer_token is not None and token != env_bearer_token:
        raise HTTPException(status_code=401, detail="Invalid token")
    # Singleton-backed, so this is cheap after the first (model-loading) call.
    chat = Chat()
    try:
        result = chat.embed_docs(docs.documents)
        return {"result": result}
    except Exception:
        # Log the full traceback instead of printing to stdout; keep the
        # legacy 200-with-error-body response shape for existing clients.
        logging.exception("embed_docs failed")
        return {"error": "embed出错"}


@app.post('/v1/rerank')
async def handle_post_request(docs: QADocs, credentials: HTTPAuthorizationCredentials = Security(security)):
    """Rerank documents against a query; requires a valid bearer token."""
    token = credentials.credentials
    if env_bearer_token is not None and token != env_bearer_token:
        raise HTTPException(status_code=401, detail="Invalid token")
    # Singleton-backed, so this is cheap after the first (model-loading) call.
    chat = Chat()
    try:
        results = chat.fit_query_answer_rerank(docs)
        return {"results": results}
    except Exception:
        # Log the full traceback instead of printing to stdout; keep the
        # legacy 200-with-error-body response shape for existing clients.
        logging.exception("rerank failed")
        return {"error": "重排出错"}


if __name__ == "__main__":
    # ACCESS_TOKEN env var overrides the hard-coded default bearer token.
    token = os.getenv("ACCESS_TOKEN")
    if token is not None:
        env_bearer_token = token
    # Host/port are also env-overridable (useful in container deployments);
    # defaults keep the original behavior.
    host = os.getenv("HOST", "0.0.0.0")
    port = int(os.getenv("PORT", "5000"))
    try:
        uvicorn.run(app, host=host, port=port)
    except Exception as e:
        print(f"API启动失败!\n报错:\n{e}")

本代码参考FastGPT项目:https://github.com/labring/FastGPT/blob/main/python/bge-rerank/bge-reranker-v2-m3/app.py

你可能感兴趣的:(fastapi,python,开发语言,RAG,rerank)