Using RouterOutputAgentWorkflow in LlamaIndex

LlamaIndex provides a RouterOutputAgentWorkflow pattern that combines multiple QueryEngineTools and decides, based on the user's input, which QueryEngine to use. A query can therefore be answered from different data sources: questions about well-defined, structured data can be answered from a database, while semantic questions can be answered from a vector store. This article implements two query engines and routes each query to the appropriate one.

Install the MySQL Dependency

pip install mysql-connector-python  
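
The rest of the example also assumes LlamaIndex with its Ollama integrations, the workflow-visualization utility, and Gradio are installed; roughly (adjust to your environment):

pip install llama-index llama-index-llms-ollama llama-index-embeddings-ollama llama-index-utils-workflow gradio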

Query Engines

Define the query engines and initialize two data sources:

  • MySQL as the structured (database) data source
  • A VectorStoreIndex as the semantic-search data source
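
The code below relies on a local llm helper module from the author's project that the post does not show. Here is a minimal sketch of what get_ollama and get_ollama_embbeding might look like, assuming the Ollama integrations installed above and a locally running Ollama server (the embedding model name is only an example):

# llm.py -- hypothetical helper module, matching how it is used below
from llama_index.llms.ollama import Ollama
from llama_index.embeddings.ollama import OllamaEmbedding

def get_ollama(model: str) -> Ollama:
    # Wrap a local Ollama chat model; mistral-nemo and llama3.1 support function calling.
    return Ollama(model=model, base_url="http://localhost:11434", request_timeout=120.0)

def get_ollama_embbeding() -> OllamaEmbedding:
    # Wrap the embedding model served by the local Ollama instance.
    return OllamaEmbedding(model_name="nomic-embed-text", base_url="http://localhost:11434")

With that helper in place, the two data sources are wired up as follows:
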
from pathlib import Path
from llama_index.core.tools import QueryEngineTool
from llama_index.core import VectorStoreIndex
import llm
from llama_index.core import SimpleDirectoryReader
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.query_engine import NLSQLTableQueryEngine
from llama_index.core import Settings
from llama_index.core import SQLDatabase

from sqlalchemy import create_engine, MetaData, Table, Column, String, Integer, select
Settings.llm = llm.get_ollama("mistral-nemo")
Settings.embed_model = llm.get_ollama_embbeding()

engine = create_engine(
    'mysql+mysqlconnector://root:123456@localhost:13306/db_llama', 
    echo=True  
)

def init_db():
    # Initialize the database: create the city_stats table and insert sample rows
    metadata_obj = MetaData()

    table_name = "city_stats"
    city_stats_table = Table(
        table_name,
        metadata_obj,
        Column("city_name", String(16), primary_key=True),
        Column("population", Integer, ),
        Column("state", String(16), nullable=False),
    )

    metadata_obj.create_all(engine)

    from sqlalchemy import insert
    rows = [
        {"city_name": "New York City", "population": 8336000, "state": "New York"},
        {"city_name": "Los Angeles", "population": 3822000, "state": "California"},
        {"city_name": "Chicago", "population": 2665000, "state": "Illinois"},
        {"city_name": "Houston", "population": 2303000, "state": "Texas"},
        {"city_name": "Miami", "population": 449514, "state": "Florida"},
        {"city_name": "Seattle", "population": 749256, "state": "Washington"},
    ]
    for row in rows:
        stmt = insert(city_stats_table).values(**row)
        with engine.begin() as connection:
            connection.execute(stmt)

sql_database = SQLDatabase(engine, include_tables=["city_stats"])
sql_query_engine = NLSQLTableQueryEngine(
    sql_database=sql_database,
    tables=["city_stats"]
)

    
def get_doc_index() -> VectorStoreIndex:
    '''
    Build a vector index over the local PDF document.
    '''
    # Create the OllamaEmbedding instance (specifies the embedding model and the Ollama base URL)
    ollama_embedding = llm.get_ollama_embbeding()

    # Load data/LA.pdf and parse it into document objects
    documents = SimpleDirectoryReader(input_files=[Path(__file__).parent / "data" / "LA.pdf"]).load_data()


    # Build a VectorStoreIndex from the documents, using OllamaEmbedding as the embedding model
    vector_index = VectorStoreIndex.from_documents(documents, embed_model=ollama_embedding, 
                                                   transformations=[SentenceSplitter(chunk_size=1000, chunk_overlap=20)],)
    vector_index.set_index_id("vector_index")  # set the index ID
    vector_index.storage_context.persist("./storage")  # persist the index to "./storage"
    return vector_index

llama_index_query_engine = get_doc_index().as_query_engine()


sql_tool = QueryEngineTool.from_defaults(
    query_engine=sql_query_engine,
    description=(
        "Useful for translating a natural language query into a SQL query over"
        " a table containing: city_stats, containing the population/state of"
        " each city located in the USA."
    ),
    name="sql_tool"
)

llama_cloud_tool = QueryEngineTool.from_defaults(
    query_engine=llama_index_query_engine,
    description=(
        f"Useful for answering semantic questions about certain cities in the US."
    ),
    name="llama_cloud_tool"
)
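
Before wiring the tools into a workflow, both engines can be exercised directly as a quick sanity check (the questions are only illustrative; init_db() must have been run once so city_stats contains data):

if __name__ == "__main__":
    # structured question -> answered by the SQL engine
    print(sql_query_engine.query("Which city has the highest population?"))
    # semantic question -> answered by the vector index
    print(llama_index_query_engine.query("What are five popular travel spots in Los Angeles?"))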


Create the Workflow

The diagram below shows the workflow's nodes. The nodes with a green background are the workflow's steps: for example, the LLM step emits tool-call events, and the tool-call step executes them and returns the results.
[Figure 1: RouterOutputAgentWorkflow node diagram]
The workflow definition code:

from typing import Dict, List, Any, Optional

from llama_index.core.tools import BaseTool
from llama_index.core.llms import ChatMessage
from llama_index.core.llms.llm import ToolSelection, LLM
from llama_index.core.workflow import (
    Workflow,
    Event,
    StartEvent,
    StopEvent,
    step,
    Context
)
from llama_index.core.base.response.schema import Response
from llama_index.core.tools import FunctionTool
from llama_index.utils.workflow import draw_all_possible_flows
from llm import get_ollama

from docs import enable_trace
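# Note: docs.enable_trace is a tracing helper from the author's local project, not a LlamaIndex API;
# it can be omitted if you are not using that tracing setup.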

enable_trace()

class InputEvent(Event):
    """Input event."""

class GatherToolsEvent(Event):
    """Gather Tools Event"""

    tool_calls: Any

class ToolCallEvent(Event):
    """Tool Call event"""

    tool_call: ToolSelection

class ToolCallEventResult(Event):
    """Tool call event result."""

    msg: ChatMessage

class RouterOutputAgentWorkflow(Workflow):
    """Custom router output agent workflow."""

    def __init__(self,
        tools: List[BaseTool],
        timeout: Optional[float] = 10.0,
        disable_validation: bool = False,
        verbose: bool = False,
        llm: Optional[LLM] = None,
        chat_history: Optional[List[ChatMessage]] = None,
    ):
        """Constructor."""

        super().__init__(timeout=timeout, disable_validation=disable_validation, verbose=verbose)

        self.tools: List[BaseTool] = tools
        self.tools_dict: Optional[Dict[str, BaseTool]] = {tool.metadata.name: tool for tool in self.tools}
        self.llm: LLM = llm
        self.chat_history: List[ChatMessage] = chat_history or []
    

    def reset(self) -> None:
        """Resets Chat History"""

        self.chat_history = []

    @step()
    async def prepare_chat(self, ev: StartEvent) -> InputEvent:
        message = ev.get("message")
        if message is None:
            raise ValueError("'message' field is required.")
        
        # add msg to chat history
        chat_history = self.chat_history
        chat_history.append(ChatMessage(role="user", content=message))
        return InputEvent()

    @step()
    async def chat(self, ev: InputEvent) -> GatherToolsEvent | StopEvent:
        """Appends msg to chat history, then gets tool calls."""

        # Put msg into LLM with tools included
        chat_res = await self.llm.achat_with_tools(
            self.tools,
            chat_history=self.chat_history,
            verbose=self._verbose,
            allow_parallel_tool_calls=True
        )
        tool_calls = self.llm.get_tool_calls_from_response(chat_res, error_on_no_tool_call=False)
        
        ai_message = chat_res.message
        self.chat_history.append(ai_message)
        if self._verbose:
            print(f"Chat message: {ai_message.content}")

        # no tool calls, return chat message.
        if not tool_calls:
            return StopEvent(result=ai_message.content)

        return GatherToolsEvent(tool_calls=tool_calls)

    @step(pass_context=True)
    async def dispatch_calls(self, ctx: Context, ev: GatherToolsEvent) -> ToolCallEvent:
        """Dispatches calls."""

        tool_calls = ev.tool_calls
        await ctx.set("num_tool_calls", len(tool_calls))

        # trigger tool call events
        for tool_call in tool_calls:
            ctx.send_event(ToolCallEvent(tool_call=tool_call))
        
        return None
    
    @step()
    async def call_tool(self, ev: ToolCallEvent) -> ToolCallEventResult:
        """Calls tool."""

        tool_call = ev.tool_call

        # get tool ID and function call
        id_ = tool_call.tool_id

        if self._verbose:
            print(f"Calling function {tool_call.tool_name} with msg {tool_call.tool_kwargs}")

        # call function and put result into a chat message
        tool = self.tools_dict[tool_call.tool_name]
        output = await tool.acall(**tool_call.tool_kwargs)
        msg = ChatMessage(
            name=tool_call.tool_name,
            content=str(output),
            role="tool",
            additional_kwargs={
                "tool_call_id": id_,
                "name": tool_call.tool_name
            }
        )

        return ToolCallEventResult(msg=msg)
    
    @step(pass_context=True)
    async def gather(self, ctx: Context, ev: ToolCallEventResult) -> InputEvent | None:
        """Gathers tool calls."""
        # wait for all tool call events to finish.
        tool_events = ctx.collect_events(ev, [ToolCallEventResult] * await ctx.get("num_tool_calls"))
        if not tool_events:
            return None
        
        for tool_event in tool_events:
            # append tool call chat messages to history
            self.chat_history.append(tool_event.msg)
        
        # after all tool calls finish, send an InputEvent back to restart the agent loop
        return InputEvent()

from muti_agent import sql_tool, llama_cloud_tool
wf = RouterOutputAgentWorkflow(tools=[sql_tool, llama_cloud_tool], verbose=True, timeout=120, llm=get_ollama("mistral-nemo"))
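
# Optional: render the workflow graph shown in the figure above as an interactive HTML page.
# draw_all_possible_flows(RouterOutputAgentWorkflow, filename="router_workflow.html")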

async def main():
    result = await wf.run(message="Which city has the highest population?")
    print("RSULT ===============", result)


# if __name__ == "__main__":
#     import asyncio

#     asyncio.run(main())


import gradio as gr

async def workflow_response(message, history):
    wf.reset()
    result = await wf.run(message=message)
    print("RESULT ===============", result)
    return result

demo = gr.ChatInterface(workflow_response, clear_btn=None, title="RouterOutputAgentWorkflow")


demo.launch()

When the input question is "What are five popular travel spots in Los Angeles?", the query is automatically routed to the vector index.
[Figure 2: the question is routed to the vector index]
When the input question is "which city has the most population", the database is queried instead.
[Figure 3: the question is routed to the SQL database]

Summary

With this workflow, LlamaIndex routes between query engines automatically, choosing the appropriate engine based on the user's input. One point to watch: the model must support function calling. When running inference with a local Ollama model, not every model supports function calling; Llama 3.1 and mistral-nemo do, so either can be used here.
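
Swapping in a different tool-capable model only means changing the LLM handed to the workflow. A minimal sketch (the model name is just an example and must already be pulled in Ollama):

from llama_index.llms.ollama import Ollama

# llama3.1 also supports function calling through Ollama and can replace mistral-nemo here.
tool_llm = Ollama(model="llama3.1", request_timeout=120.0)
wf = RouterOutputAgentWorkflow(tools=[sql_tool, llama_cloud_tool], llm=tool_llm, timeout=120)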
