完整代码:
import os
from functools import lru_cache
from operator import itemgetter

from dotenv import load_dotenv
from langchain.chains import create_sql_query_chain
from langchain.chat_models import init_chat_model
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
from langchain_community.utilities import SQLDatabase
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.tools import tool
from pyprojroot import here
load_dotenv()
class TravelSQLAgentTool:
    """Answer natural-language questions against a SQL database via an LLM.

    The pipeline built in ``__init__`` has three stages: the LLM writes a SQL
    query for the user's question, the query is executed on the SQLite
    database, and the LLM turns the question, the query and the raw result
    into the final answer.

    Attributes:
        sql_agent_llm: Chat model returned by ``init_chat_model``; used both to
            write the SQL query and to phrase the final answer.
        system_role (str): Prompt template of the answering step; it has the
            placeholders ``question``, ``query`` and ``result``.
        db (SQLDatabase): Database wrapper the generated queries run against.
        chain: Runnable pipeline. Invoke it with ``{"question": "..."}``; the
            last stage is a ``StrOutputParser``, so the result is a string.
    """

    def __init__(
        self,
        llm: str,
        sqldb_directory: str,
        llm_temerature: float,
        model_provider: str = "groq",
        top_k: int = 55,
    ) -> None:
        """Set up the model, the database connection and the answering chain.

        Args:
            llm: Name of the chat model, e.g. "llama-3.3-70b-specdec".
            sqldb_directory: Path of the SQLite database file; it is
                interpolated into a ``sqlite:///`` URI.
            llm_temerature: Sampling temperature for the model. The misspelt
                name is kept because callers pass it by keyword.
            model_provider: Provider passed to ``init_chat_model``.
            top_k: Value substituted for ``{top_k}`` in the SQL-writing prompt.
        """
        self.sql_agent_llm = init_chat_model(
            llm, model_provider=model_provider, temperature=llm_temerature
        )
        self.db = SQLDatabase.from_uri(f"sqlite:///{sqldb_directory}")

        # Custom prompt used to generate the SQL query.
        query_prompt = PromptTemplate(
            input_variables=["dialect", "input", "table_info", "top_k"],
            template=(
                "You are a SQL expert using {dialect}.\n"
                "Given the following table schema:\n"
                "{table_info}\n"
                'Generate a syntactically correct SQL query to answer the question: "{input}".\n'
                "Do not Limit {top_k} the results.\n"
                "Return only the SQL query without any additional commentary or Markdown formatting.\n"
            ),
        )

        # create_sql_query_chain fills {dialect} and {table_info} from the
        # database and {top_k} from `k`. Binding those names onto the finished
        # chain (as this class used to do) never reaches the prompt, so the
        # limit has to be passed here.
        write_query = create_sql_query_chain(
            self.sql_agent_llm, self.db, prompt=query_prompt, k=top_k
        )
        # NOTE(review): this executes LLM-generated SQL verbatim; point it at a
        # database / credentials where that is acceptable.
        execute_query = QuerySQLDataBaseTool(db=self.db)

        # Prompt of the answering step.
        self.system_role = (
            "Given the following user question, corresponding SQL query, and SQL result, answer the user question.\n\n"
            "Question: {question}\n\n"
            "SQL Query: {query}\n\n"
            "SQL Result: {result}\n\n"
            "Answer:\n"
        )
        answer = (
            PromptTemplate.from_template(self.system_role)
            | self.sql_agent_llm
            | StrOutputParser()
        )

        # Pipeline:
        #   1. write_query produces the SQL and stores it under "query";
        #   2. the pass-through step prints that SQL for debugging and hands
        #      the dict on unchanged;
        #   3. the SQL is executed and its output stored under "result";
        #   4. question, query and result are turned into the final answer.
        self.chain = (
            RunnablePassthrough.assign(query=write_query)
            | RunnablePassthrough(self._log_query)
            | RunnablePassthrough.assign(result=itemgetter("query") | execute_query)
            | answer
        )

    @staticmethod
    def _log_query(data: dict) -> None:
        """Print the SQL produced by the query-writing step."""
        print("write_query execution result:", data["query"])
# Location of the SQLite database file, obtained from pyprojroot's `here`
# for the relative path "data/Chinook.db".
sqldb_directory = here("data/Chinook.db")
# Sample question used by the demo call at the end of the script.
query = "and calculate the number of all Playlist"
@lru_cache(maxsize=1)
def _get_travel_sql_agent() -> "TravelSQLAgentTool":
    """Build the SQL agent on first use and reuse it for later tool calls.

    Constructing the agent creates the chat-model client, opens the database
    and assembles the chain, so doing it once avoids repeating that work on
    every tool invocation.
    """
    return TravelSQLAgentTool(
        llm="llama3-70b-8192",  # TOOLS_CFG.travel_sqlagent_llm
        sqldb_directory=sqldb_directory,  # TOOLS_CFG.travel_sqldb_directory
        llm_temerature=0,  # TOOLS_CFG.travel_sqlagent_llm_temperature
    )


@tool
def query_travel_sqldb(query: str) -> str:
    """Query the Swiss Airline SQL Database and access all the company's information. Input should be a search query."""
    # The docstring above is left verbatim: @tool exposes it as the tool's
    # description, so editing it would change what the calling model sees.
    return _get_travel_sql_agent().chain.invoke({"question": query})
# `query_travel_sqldb` is a tool object after @tool; run it through `invoke`
# rather than the deprecated direct call.
print(query_travel_sqldb.invoke({"query": query}))
下面我将分步骤、用通俗易懂的语言详细解释这段代码的含义和作用,并举例说明各部分的工作流程。
这段代码的主要目的是构造一个工具(TravelSQLAgentTool),它能利用大语言模型(LLM,例如 Llama 模型)来完成以下任务:
1. 根据用户的自然语言问题生成对应的 SQL 查询;
2. 在 SQLite 数据库上执行这条 SQL 查询;
3. 结合用户问题、SQL 查询和查询结果,生成最终的自然语言回答。
同时,为了调试方便,在生成 SQL 查询后会打印出这条查询语句,这样你就能看到 LLM 生成的 SQL 语句内容。
from langchain_core.tools import tool
from langchain_community.utilities import SQLDatabase
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from operator import itemgetter
from pyprojroot import here
from langchain.chains import create_sql_query_chain
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
import os
from dotenv import load_dotenv
load_dotenv()
导入各个模块:
这些模块主要用于创建提示模板、构造数据处理链、连接 SQL 数据库以及注册工具函数。
加载环境变量:load_dotenv() 用于加载 .env 文件中的环境变量,方便管理敏感信息(如 API 密钥)。
这个类封装了整个查询流程。我们逐行看它的初始化方法 __init__。
self.sql_agent_llm = init_chat_model(llm, model_provider="groq", temperature=llm_temerature)
这里使用 init_chat_model 初始化一个 LLM 模型,传入的 llm 参数(例如 "llama3-70b-8192")指定使用哪个模型。
model_provider="groq":表示使用 Groq 后端;
temperature=llm_temerature:温度参数决定了模型回答的随机性(0 表示确定性很高)。
self.db = SQLDatabase.from_uri(f"sqlite:///{sqldb_directory}")
这一行根据传入的数据库路径(例如 "data/Chinook.db")构造一个 SQLite 数据库连接实例,后续执行 SQL 查询时会使用它。
custom_prompt = PromptTemplate(
input_variables=["dialect", "input", "table_info", "top_k"],
template="""You are a SQL expert using {dialect}.
Given the following table schema:
{table_info}
Generate a syntactically correct SQL query to answer the question: "{input}".
Do not Limit {top_k} the results.
Return only the SQL query without any additional commentary or Markdown formatting.
"""
)
这个提示模板包含四个占位符:
{dialect}:SQL 的方言(例如 SQLite);
{table_info}:数据库中各表的结构信息;
{input}:用户的问题;
{top_k}:限制查询返回记录条数的参数(不过这里实际上是说明“不要限制”)。
例如,对于“统计所有 Playlist 的数量”这样的问题,模型会根据该模板生成类似 SELECT COUNT(*) FROM Playlist; 的 SQL 查询。
write_query = create_sql_query_chain(self.sql_agent_llm, self.db, prompt=custom_prompt)
execute_query = QuerySQLDataBaseTool(db=self.db)
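write_query:由 create_sql_query_chain 构造的链,负责根据用户问题和表结构让 LLM 生成 SQL 查询,例如:
SELECT COUNT(*) FROM Playlist;
execute_query:一个 QuerySQLDataBaseTool 工具,负责在数据库上执行生成的 SQL 查询并返回结果。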
self.system_role = """Given the following user question, corresponding SQL query, and SQL result, answer the user question.\n
Question: {question}\n
SQL Query: {query}\n
SQL Result: {result}\n
Answer:
"""
answer_prompt = PromptTemplate.from_template(self.system_role)
answer = answer_prompt | self.sql_agent_llm | StrOutputParser()
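system_role:回答阶段使用的提示模板,其中的 {question}、{query}、{result} 分别对应用户问题、SQL 查询和查询结果。
answer 链:把提示模板、LLM 和 StrOutputParser 依次串联,根据上述三项信息生成最终的文本回答。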
debug_chain = RunnablePassthrough(lambda data: (print("write_query execution result:", data["query"]), data)[1])
这个调试链会打印 data["query"] 的内容,即打印出由 write_query 生成的 SQL 查询语句,然后把原始数据继续返回给后续链的步骤。
举例来说,假设 data 为:{"query": "SELECT COUNT(*) FROM Playlist;"}
那么这个 lambda 函数会先打印:write_query execution result: SELECT COUNT(*) FROM Playlist;
然后返回原始数据 {"query": "SELECT COUNT(*) FROM Playlist;"},不对数据做任何修改。
chain_ex = (
RunnablePassthrough.assign(query=write_query)
| debug_chain
| RunnablePassthrough.assign(result=itemgetter("query") | execute_query)
| answer
)
这段代码构造了一个数据处理流水线,每一步的含义如下:
1. RunnablePassthrough.assign(query=write_query):调用 write_query 生成 SQL 查询,并将生成的查询结果存储到数据字典的 "query" 键中。例如:{"query": "SELECT COUNT(*) FROM Playlist;"}
2. | debug_chain:数据经过 debug_chain,打印出 SQL 查询,同时不改变数据。
3. | RunnablePassthrough.assign(result=itemgetter("query") | execute_query):先用 itemgetter("query") 从数据字典中提取 SQL 查询语句,然后将其传递给 execute_query 工具,执行 SQL 查询,并将执行结果存储到数据字典的 "result" 键中。例如:{"query": "SELECT COUNT(*) FROM Playlist;", "result": 25}
4. | answer:最后将问题、SQL 查询以及查询结果传递给 answer 链,生成最终回答。
bound_chain = chain_ex.bind(
dialect=self.db.dialect,
table_info=self.db.get_table_info(),
top_k=55
)
self.chain = bound_chain
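bind 方法将一些固定参数(如 SQL 的方言、数据库表结构信息、以及 top_k 参数)绑定到流水线中,确保每次调用链时这些参数都自动传递进去。
(注意:bind 绑定的是调用链时附带的关键字参数,并不会填充提示模板中的变量;模板里的 dialect、table_info 和 top_k 实际上由 create_sql_query_chain 在内部填充,其中 top_k 取自它的参数 k,默认值为 5。如果希望 top_k 为 55,应写成 create_sql_query_chain(..., k=55)。)
最后把绑定后的链保存到 self.chain,这样后续调用就会按照这个步骤顺序执行。
sqldb_directory = here("data/Chinook.db")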
query = "and calculate the number of all Playlist"
这里使用 here 函数定位到数据库文件所在的路径(相对路径 "data/Chinook.db"),并准备了一个示例问题 query;
@tool
def query_travel_sqldb(query: str) -> str:
"""Query the Swiss Airline SQL Database and access all the company's information. Input should be a search query."""
agent = TravelSQLAgentTool(
llm="llama3-70b-8192", # 指定使用的语言模型
sqldb_directory= sqldb_directory, # 数据库文件路径
llm_temerature=0 # LLM 温度设为 0,表示回答比较确定,不引入随机性
)
response = agent.chain.invoke({"question": query})
return response
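query_travel_sqldb 函数通过 @tool 装饰器注册成工具函数,方便外部调用。
函数内部先创建 TravelSQLAgentTool 实例,再调用其链(agent.chain)的 invoke 方法,把用户的问题(键名为 "question")传入整个链进行处理,最终得到回答。
例如,当调用 query_travel_sqldb("and calculate the number of all Playlist") 时:
1. LLM 会先生成类似下面的 SQL 查询:
SELECT COUNT(*) FROM Playlist;
2. 调试链打印这条 SQL 查询;
3. 在数据库上执行该查询,并由 LLM 根据查询结果生成最终回答。
print(query_travel_sqldb(query))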
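这段代码构建了一个基于 LLM 的 SQL 查询代理工具,其工作流程为:
1. 接收用户的自然语言问题;
2. 由 LLM 根据表结构生成 SQL 查询,并打印出来便于调试;
3. 在 SQLite 数据库上执行该 SQL 查询;
4. 由 LLM 根据问题、SQL 查询和查询结果生成最终回答。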
这种设计使得查询流程高度自动化,同时又便于调试和检查中间步骤,帮助理解每一步是如何工作的。