langchain embedding 自定义模型(bge)实现

模型下载

国内 huggingface 镜像

https://hf-mirror.com/

# 安装依赖
pip install -U huggingface_hub
# 设置环境变量 linux/mac
export HF_ENDPOINT=https://hf-mirror.com
# 设置环境变量 windows powershell
$env:HF_ENDPOINT = "https://hf-mirror.com"
# 下载模型
huggingface-cli download --resume-download BAAI/bge-m3 --local-dir ~/Documents/models/BAAI/bge_m3
huggingface-cli download --resume-download BAAI/bge-reranker-v2-m3 --local-dir ~/Documents/models/BAAI/bge_reranker_v2_m3

示例实现

from typing import List
from langchain.embeddings.base import Embeddings
from langchain.schema import Document
from langchain.vectorstores import FAISS
from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification
import torch


# ======================
# 自定义 Embedding 类
# ======================
class CustomTransformerEmbeddings(Embeddings):
    """自定义 Transformers 嵌入模型"""

    def __init__(self, model_name: str, device: str = "cuda:0"):
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name).to(device)
        self.model.eval()

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """文档嵌入"""
        return self._embed(texts)

    def embed_query(self, text: str) -> List[float]:
        """查询嵌入"""
        return self._embed([text])[0]

    def _embed(self, texts: List[str]) -> List[List[float]]:
        """实际嵌入实现"""
        inputs = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            return_tensors="pt",
            max_length=512
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs)
            embeddings = self._mean_pooling(outputs, inputs['attention_mask'])

        return embeddings.cpu().numpy().tolist()

    def _mean_pooling(self, model_output, attention_mask):
        """池化方法"""
        token_embeddings = model_output.last_hidden_state
        input_mask_expanded = (
            attention_mask
            .unsqueeze(-1)
            .expand(token_embeddings.size())
            .float()
        )
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)


# ======================
# 自定义 Rerank 类
# ======================
class CrossEncoderReranker:
    """基于交叉编码器的重排序器"""

    def __init__(self, model_name: str, device: str = "cuda:0"):
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name).to(device)
        self.model.eval()

    def rerank(self, query: str, documents: List[Document], top_k: int = 5) -> List[Document]:
        """重排序文档"""
        pairs = [(query, doc.page_content) for doc in documents]
        scores = self._predict(pairs)

        # 组合文档与分数
        scored_docs = list(zip(documents, scores))
        # 按分数降序排序
        scored_docs.sort(key=lambda x: x[1], reverse=True)
        # 返回前k个文档
        return [doc for doc, score in scored_docs[:top_k]]

    def _predict(self, pairs: List[tuple]) -> List[float]:
        """计算相关性分数"""
        inputs = self.tokenizer(
            pairs,
            padding=True,
            truncation=True,
            return_tensors="pt",
            max_length=512
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs)
            scores = torch.sigmoid(outputs.logits).squeeze().cpu().numpy()

        return scores.tolist()


# ======================
# 使用示例
# ======================
if __name__ == "__main__":
    # 初始化组件
    embedding_model = CustomTransformerEmbeddings("~/Documents/models/BAAI/bge_m3", "cpu")
    reranker = CrossEncoderReranker("~/Documents/models/BAAI/bge_reranker_v2_m3", "cpu")

    # 示例文档
    documents = [
        Document(page_content="LangChain is a framework for developing LLM applications"),
        Document(page_content="Transformers provides state-of-the-art NLP models"),
        Document(page_content="FAISS is a library for efficient similarity search"),
        Document(page_content="FAISS is a library for efficient similarity search"),
    ]

    # 创建 FAISS 向量库
    vector_store = FAISS.from_documents(
        documents=documents,
        embedding=embedding_model,
        # normalize_L2=True  # 建议对相似性搜索进行归一化
    )

    # 检索流程
    query = "What is LangChain?"

    # 第一步:向量检索
    retrieved_docs = vector_store.similarity_search(query, k=10)

    # 第二步:重排序
    reranked_docs = reranker.rerank(query, retrieved_docs, top_k=3)

    # 输出结果
    print("### 重排序结果 ###")
    for i, doc in enumerate(reranked_docs):
        print(f"Rank {i + 1}: {doc.page_content}...")

关键实现细节说明:

  1. Embedding 类优化
  • 使用 mean pooling 处理变长文本
  • 支持批量处理提升效率
  • 自动处理设备分配(CPU/GPU)
  • 兼容 LangChain 的接口标准
  1. Reranker 类特点
  • 使用交叉编码器进行精准相关性计算
  • 支持动态 top-k 截取
  • 保留原始文档元数据
  • 兼容 LangChain 的文档格式
  1. 最佳实践
  • 向量存储时进行 L2 归一化(normalize_L2=True)
  • 使用小批量推理提升 GPU 利用率
  • 分离检索和重排序阶段
  • 使用适合任务的模型:
    • 检索:双编码器(如 BGE、GTE)
    • 重排序:交叉编码器(如 MS-Marco 系列)
  1. 扩展建议
  • 添加缓存机制提升性能
  • 支持混合搜索(关键词+向量)
  • 添加异步处理支持
  • 实现结果解释功能

典型工作流程:

原始查询 → 向量检索(FAISS) → 初步召回 → 重排序 → 最终结果

性能优化技巧:

  • 对长文本使用动态分块策略
  • 使用量化技术加速推理
  • 对高频查询添加缓存层
  • 使用 ONNX 或 TensorRT 加速模型

模型选择建议:

  • Embedding:BAAI/bge系列、thenlper/gte系列
  • Reranker:BAAI/bge-reranker系列、cross-encoder/ms-marco系列

该实现方案在保持 LangChain 兼容性的同时,提供了灵活的自定义能力,可以方便地替换不同的 Transformer 模型,适应各种业务场景需求。

你可能感兴趣的:(embedding,langchain,transformers,langchain,embedding)