文本到SQL (Text2SQL) 任务旨在将自然语言查询转换为可执行的SQL查询。得益于大规模语言模型 (LLMs) 的应用,该领域取得了显著进展。然而,模型的可扩展性、生成空间的限制以及SQL生成过程中的连贯性问题仍然存在。为了解决这些问题,我们提出了SQL-o1,一种基于自奖励的启发式搜索方法,旨在增强LLMs在SQL查询生成中的推理能力。SQL-o1结合了蒙特卡洛树搜索 (MCTS) 进行过程级搜索,并构建了一个Schema-Aware数据集,以帮助模型更好地理解数据库模式。广泛的实验表明,SQL-o1在复杂的Bird数据集上的执行准确率比最新的基线方法提高了10.8%,甚至超过了基于GPT-4的方法。此外,SQL-o1在少量样本学习场景中表现出色,并且具有强大的跨模型迁移能力。我们的代码已公开可用1。
Text2SQL是指将自然语言问题转换为结构化查询语言 (SQL) 的过程,作为非专业用户使用自然语言与数据库交互的有效方法。该领域的发展可以分为三个阶段:首先,使用预训练模型或抽象语法树对输入序列进行编码和解码 (Wang et al., 2020);其次,采用序列到序列方法 (Xie et al., 2022);最近,大规模语言模型 (LLMs) (Zhang et al., 2025) 被证明是Text2SQL的有效解决方案。然而,准确地将自然语言查询与数据库中的数据对齐仍然是一个重大挑战。
最近关于基于LLMs的Text2SQL的研究主要集中在通过上下文学习提示策略和特定领域的监督微调来提高模型性能。该领域的关键方法包括Schema Linking、Self-correction和Chain-of-Thought (CoT) (Tai et al., 2023),旨在增强模型对模式的理解、提高其推理能力,并帮助生成更准确的SQL查询。
然而,仍存在三个主要挑战:1. 这些方法通常受到模型规模的限制。较小的LLMs在理解复杂指令方面的能力有限,导致在处理复杂任务时泛化能力较差。2. 端到端生成方法受到生成空间的限制。由于缺乏逐步验证和灵活调整的机会,模型在生成过程中难以探索更多潜在路径,限制了输出的多样性和准确性。3. 在SQL生成过程中存在推理过程的连贯性问题。如果在任何步骤中出现错误,通常会影响后续步骤的正确性,导致最终生成的SQL查询无法正确执行。
受Process-supervised Reward Model (Luo et al., 2023) 的启发,我们提出了SQL-o1,一种基于自奖励的启发式搜索方法,如图 1 所示。首先,我们广泛挖掘数据库模式,收集表列字段、代表性实体等信息,构建一个Schema-Aware数据集,用于微调大规模语言模型 (LLMs)。此外,我们引入蒙特卡洛树搜索 (MCTS) (Swiechowski et al. ´ , 2023) 作为推理媒介,利用过程级推理和自奖励减少LLMs生成过程中的逻辑错误。通过扩展生成空间并克服SQL生成的一致性挑战,我们显著增强了LLMs的推理能力。
我们在Bird和Spider数据集及其三个变体上进行了实验。实验结果表明,SQL-o1结合常见的开源模型(如Llama 3 (Touvron et al., 2023) 和Qwen 2.5 (Yang et al., 2024a))显著优于大多数现有方法,甚至超过了其他基于GPT-4的方法。此外,我们在少量样本微调场景中应用SQL-o1,结果表明当样本量达到2000时,几乎所有性能指标都超过了在完整数据集上微调的模型。最后,我们还讨论了SQL-o1的迁移能力和其组件的贡献。我们的贡献可以总结如下:
Text2SQL任务近年来取得了显著进展,主要集中在大规模语言模型 (LLMs) 上,其出色的推理能力为Text2SQL任务提供了新的方向和机会。目前,基于LLMs的方法可以大致分为两类:提示工程和基于代理的LLMs交互。
在LLMs的早期阶段,一种直接有效的方法是精心设计有效的提示,以更好地利用LLMs的潜力,这同样适用于Text2SQL任务。通过Chain of Thought (Zhang et al., 2023) 增强LLMs的推理能力是一个有前景的尝试。一些方法 (Wang et al., 2024; Pourreza and Rafiei, 2023; Li et al., 2024a) 利用模式链接将自然语言问题与数据库模式元素结合,取得了令人满意的结果。其中,DAIL-SQL (Gao et al., 2024) 系统地研究了提示工程在基于LLMs的文本到SQL方法中的应用,包括问题表示、提示组件、示例选择和示例组织。
最近,一些研究将注意力从提示工程(例如GPT-4和其他封闭源模型)转向了LLMs的微调。SENSE (Yang et al., 2024b) 合成强数据,并对弱LLM生成的弱数据进行直接偏好优化 (DPO),而ROUTE (Qin et al., 2024) 提出了一种多任务协同微调方法,减少了SQL生成中的潜在错误,取得了更好的结果。
基于代理的交互方法 (Chen et al., 2024b) 通过设计反馈信号引导LLMs生成准确的SQL查询。早期的工作 (Shi et al., 2022) 专注于基于执行结果改进SQL,通过执行SQL查询并根据执行风险选择最准确的翻译。其他工作 (Chen et al., 2024a; Guo et al., 2023) 利用LLMs检查结果并纠正生成的SQL与真实SQL查询之间的差异。MAC-SQL (Wang et al., 2024) 引入了多代理框架和其他新颖的交互方法 (Xiong et al., 2024)。然而,这些方法大多依赖高质量的外部反馈,这在实际应用中往往不可用,并且主要依赖封闭源LLMs,忽视了开源LLMs在推理方面的潜力。
给定一个Text2SQL数据集 D = {(Di , Qi , Si)} N i=1,其中每个样本包含一个SQL数据库 Di 、一个自然语言问题 Qi 和相应的 ground-truth SQL 查询 Si ,Text2SQL任务的目标是使用大规模语言模型生成一个SQL查询 Qi ′,并确保其执行结果与 Si 匹配。
自奖励启发式动态搜索主要由一系列状态 O = {o0, o1, o2, …, ot−1} 和基于这些状态生成的动作序列 A = {a1, a2, …, at} 组成。每次执行动作 at 时,模型将收到相应的奖励 Rt ∈ R。奖励和动作均由模型 π 生成。
在本节中,我们将介绍SQL-o1的三个组成部分:Schema-Aware数据构建、渐进式SQL生成和自奖励启发式动态搜索。
SQL-o1需要在进行启发式动态搜索之前准确理解数据库结构和查询条件。因此,我们设计了提取表字段类型和样本数据条目的策略,以帮助模型更好地掌握数据库模式,从而优化启发式搜索过程。
列的数据类型决定了字段中可以存储的值以及这些值的处理方式。在构建Text2SQL提示时,指定列的数据类型至关重要,因为不同的数据类型需要不同的处理。例如,数值 (NUMBER) 数据支持加法和平均等数学运算,而文本 (TEXT) 数据通常用于过滤和匹配。这些类型指示符有助于模型正确生成SQL查询。
示例数据库条目是指数据库表中的小部分数据,帮助模型理解数据的内容和结构。在Text2SQL任务中,提示中的示例数据帮助模型将自然语言查询映射到特定的数据库条目。例如,生成查询 “orders.order_date BETWEEN ‘2022-01-01’ AND ‘2022-12-31’” 时,模型需要理解 “order_date” 列的日期格式。同样,对于 “products.category” 字段,模型应识别 “category” 列中的特定值,如 “Electronics” 或 “Clothing”。通过提供代表性示例数据,模型可以更好地理解列的内容和格式,从而更准确地生成SQL查询。
主键和外键定义了数据库表之间的关系。例如,表A中的 “ID” 和 “Type” 列作为主键,而表B中的 “ID” 连接到表A中的 “ID”,形成关系 A.ID = B.ID。这种关系有助于模型理解如何连接表并正确检索数据,是识别表之间的依赖关系和连接条件的关键。
渐进式SQL生成 (PSG) 是监督微调 (SFT) 的一种变体,核心思想是在训练过程中在特定关键字处截断完整的SQL查询,模型的任务是根据提示重建完整的查询。我们主要关注预训练的大规模语言模型中预测错误或复杂语法结构的SQL查询。例如,在查询 ‘SELECT name, age FROM employees WHERE Department = ‘HR’ AND salary > 50000’ 中,截断发生在关键字 ‘WHERE’ 或 ‘AND’ 处,而不是任意位置。如果截断发生在 ‘SELECT name, age FROM employees WHERE’,模型需要从这个片段生成完整的查询。
这种增量生成方法利用了LLMs的连续生成能力,帮助模型更好地理解查询结构和语法,减少生成错误,特别是在处理多个连接或复杂过滤条件时。
基于上述内容,我们为LLMs开发了一个基本的微调数据集,主要包括第 4.1 和 4.2 节的内容。我们表示构建的数据集为:
D s = { σ p ( D i , Q i ) , S i } i = 1 N s , D_s = \{ \sigma_p(\mathcal{D}_i, \mathcal{Q}_i), \mathcal{S}_i \}_{i=1}^{N_s}, Ds={σp(Di,Qi),Si}i=1Ns,
其中 σp 表示我们定义的提示构建函数,Ns 表示数据集中的样本总数。
本节提出的方法结合了强化学习框架、蒙特卡洛树搜索 (MCTS) 和自奖励评估,以指导模型在SQL查询生成过程中的决策。根据算法的组成部分,该方法主要分为:SQL生成规划、自奖励评估和启发式动态搜索。
我们将SQL查询生成任务定义为一个顺序决策任务,模型的目标是根据当前上下文选择下一个SQL片段(如表名、列名或SQL关键字)。这被视为一个策略生成问题,目标是教会模型一种策略,以最大化生成正确SQL查询的可能性:
a t = argmax a t ′ π ( a t ′ ∣ o t − 1 ) . (1) a_t = \underset{a_t'}{\text{argmax}} \,\pi(a_t' \mid o_{t-1}).\tag{1} at=at′argmaxπ(at′∣ot−1).(1)
方程 (1) 描述了策略模型 π 如何根据前一个状态 ot−1(即前一个步骤生成的SQL片段)选择最优动作 at(即第 t 步生成的SQL片段)。具体来说,模型选择一个可能的SQL片段 a ′ t,以最大化概率 π(a ′ t | ot−1)。
该任务的目标是根据当前状态评估生成的SQL查询片段的质量和有效性,提供奖励和反馈信号以指导决策过程。具体来说,我们提出了一种评分函数 Rπ,利用 π 的对数概率值评估给定输入 x 生成输出 y 的可能性:
R π ( y ∣ x ) = β + α log π ( y ∣ x ) , ( 2 ) R_{\pi}(y \mid x) = \beta + \alpha \log \pi(y \mid x), \qquad (2) Rπ(y∣x)=β+αlogπ(y∣x),(2)
其中 β 是定义的满分,设为100,α 是一个正温度值,用于控制分数的差异。
蒙特卡洛树搜索 (MCTS) 是一种强大的决策算法,广泛应用于博弈论(如AlphaGo)和规划问题。如图 2 所示,我们使用MCTS作为启发式搜索方法,以指导SQL查询生成。它逐步探索和生成SQL查询序列,模拟结果,并根据自奖励指导优化搜索路径。
选择。MCTS的选择阶段从根节点开始,遍历子节点直到到达叶节点。每个节点代表SQL查询生成过程中的一个决策点,模型根据方程 (1) 选择下一个有效的SQL标记,逐步生成查询。在关键的语法和语义决策点,模型使用启发式截断扩展部分查询。然后应用UCT算法指导节点选择,平衡未访问查询结构的探索和高奖励路径的利用:
n t = argmax n ∈ N ( o t − 1 ) [ Q ( o t − 1 + n ) + w ⋅ ln N ( o t − 1 ) N ( o t − 1 + n ) ] (3) n_t = \underset{n \in \mathcal{N}(o_{t-1})}{\text{argmax}} \left[ Q(o_{t-1} + n) + w \cdot \frac{\sqrt{\ln N(o_{t-1})}}{N(o_{t-1} + n)} \right] \tag{3} nt=n∈N(ot−1)argmax[Q(ot−1+n)+w⋅N(ot−1+n)lnN(ot−1)](3)
其中 N (.) 表示给定状态 ot−1 的候选扩展路径,Q(.) 表示当前状态的Q值,反映执行动作的预期回报。N(.) 表示代理状态的访问次数。
扩展。选择过程选择最相关的SQL查询作为候选扩展。当最大查询深度 L 未达到时,模型继续通过探索下一个可能的SQL操作或子句来扩展查询:
{ n t ( b ) } b = 1 B ∼ π ( n t ∣ o t − 1 ) B e a m , ( 4 ) \left\{ n_t^{(b)} \right\}_{b=1}^{B} \sim \pi \left( n_t | o_{t-1} \right)_{Beam},\qquad(4) {nt(b)}b=1B∼π(nt∣ot−1)Beam,(4)
其中,π(.)Beam 表示束搜索算法,B 是束宽度。然后,模型根据与前一个查询片段的语义相似性选择最相关的SQL操作进行扩展:
N ( o t − 1 ) = { n t ( i ) } i = 1 d ← argmax d R π ( { n t ( b ) } b = 1 B ∣ o t − 1 ) , ( 5 ) \begin{aligned} \mathcal{N}(o_{t-1}) &= \left\{ n_t^{(i)} \right\}_{i=1}^d \leftarrow \operatorname*{argmax}_d \\ R_\pi \left( \left\{ n_t^{(b)} \right\}_{b=1}^B \mid o_{t-1} \right), \quad (5) \end{aligned} N(ot−1)Rπ({nt(b)}b=1B∣ot−1),(5)={nt(i)}i=1d←dargmax
其中 Rπ 表示评估每个候选扩展质量的奖励函数,d < B。例如,如果当前状态是 “user” 表的部分查询,模型可能会生成 “SELECT user.id” 或 “SELECT user.name” 并根据其与输入问题的语义相关性选择候选。
模拟和回溯。扩展节点后,模型为所有新添加的子节点分配分数,如方程 (2) 和 (6) 所示。根据方程 (7),选择得分最高的节点进行进一步模拟,直到达到最终状态,从而生成完整的SQL查询生成轨迹。
Q ( o l ( n ) ) = δ R π ( n l ∣ o l − 1 ( n ) ) + ( 1 − δ ) R π ( S ∣ Q ) , (6) Q(o_l^{(n)}) = \delta R_\pi(n_l \mid o_{l-1}^{(n)}) + (1 - \delta)R_\pi(\mathcal{S} \mid \mathcal{Q}),\tag{6} Q(ol(n))=δRπ(nl∣ol−1(n))+(1−δ)Rπ(S∣Q),(6)
其中 δ 是一个介于 (0, 1) 之间的参数,用于平衡过程得分和总得分,通常设为0.5。算法通过更新从叶节点到根节点的所有节点的Q值进行回溯。
$$Q(o_t^{(n)}) = \max_{j=1}^n \left( \frac{\sum_{i=l}^t Q(o_i^{(j)})}{l - t### 4.3 自奖励启发式动态搜索
本节提出的方法结合了强化学习框架、蒙特卡洛树搜索 (MCTS) 和自奖励评估,以指导模型在SQL查询生成过程中的决策。根据算法的组成部分,该方法主要分为:SQL生成规划、自奖励评估和启发式动态搜索。
我们将SQL查询生成任务定义为一个顺序决策任务,模型的目标是根据当前上下文选择下一个SQL片段(如表名、列名或SQL关键字)。这被视为一个策略生成问题,目标是教会模型一种策略,以最大化生成正确SQL查询的可能性:
a t = argmax a t ′ π ( a t ′ ∣ o t − 1 ) . (1) a_t = \underset{a_t'}{\text{argmax}} \,\pi(a_t' \mid o_{t-1}).\tag{1} at=at′argmaxπ(at′∣ot−1).(1)
方程 (1) 描述了策略模型 π 如何根据前一个状态 ( o_{t-1} )(即前一个步骤生成的SQL片段)选择最优动作 ( a_t )(即第 t 步生成的SQL片段)。具体来说,模型选择一个可能的SQL片段 ( a_t’ ),以最大化概率 ( \pi(a_t’ \mid o_{t-1}) )。
该任务的目标是根据当前状态评估生成的SQL查询片段的质量和有效性,提供奖励和反馈信号以指导决策过程。具体来说,我们提出了一种评分函数 ( R_\pi ),利用 π 的对数概率值评估给定输入 ( x ) 生成输出 ( y ) 的可能性:
R π ( y ∣ x ) = β + α log π ( y ∣ x ) , ( 2 ) R_{\pi}(y \mid x) = \beta + \alpha \log \pi(y \mid x), \qquad (2) Rπ(y∣x)=β+αlogπ(y∣x),(2)
其中 ( \beta ) 是定义的满分,设为100,( \alpha ) 是一个正温度值,用于控制分数的差异。
蒙特卡洛树搜索 (MCTS) 是一种强大的决策算法,广泛应用于博弈论(如AlphaGo)和规划问题。如图 2 所示,我们使用MCTS作为启发式搜索方法,以指导SQL查询生成。它逐步探索和生成SQL查询序列,模拟结果,并根据自奖励指导优化搜索路径。
选择。MCTS的选择阶段从根节点开始,遍历子节点直到到达叶节点。每个节点代表SQL查询生成过程中的一个决策点,模型根据方程 (1) 选择下一个有效的SQL标记,逐步生成查询。在关键的语法和语义决策点,模型使用启发式截断扩展部分查询。然后应用UCT算法指导节点选择,平衡未访问查询结构的探索和高奖励路径的利用:
n t = argmax n ∈ N ( o t − 1 ) [ Q ( o t − 1 + n ) + w ⋅ ln N ( o t − 1 ) N ( o t − 1 + n ) ] (3) n_t = \underset{n \in \mathcal{N}(o_{t-1})}{\text{argmax}} \left[ Q(o_{t-1} + n) + w \cdot \frac{\sqrt{\ln N(o_{t-1})}}{N(o_{t-1} + n)} \right] \tag{3} nt=n∈N(ot−1)argmax[Q(ot−1+n)+w⋅N(ot−1+n)lnN(ot−1)](3)
其中 ( N(\cdot) ) 表示给定状态 ( o_{t-1} ) 的候选扩展路径,( Q(\cdot) ) 表示当前状态的Q值,反映执行动作的预期回报。( N(\cdot) ) 表示代理状态的访问次数。
扩展。选择过程选择最相关的SQL查询作为候选扩展。当最大查询深度 ( L ) 未达到时,模型继续通过探索下一个可能的SQL操作或子句来扩展查询:
{ n t ( b ) } b = 1 B ∼ π ( n t ∣ o t − 1 ) Beam , ( 4 ) \left\{ n_t^{(b)} \right\}_{b=1}^{B} \sim \pi \left( n_t | o_{t-1} \right)_{\text{Beam}},\qquad(4) {nt(b)}b=1B∼π(nt∣ot−1)Beam,(4)
其中,( \pi(\cdot)_{\text{Beam}} ) 表示束搜索算法,( B ) 是束宽度。然后,模型根据与前一个查询片段的语义相似性选择最相关的SQL操作进行扩展:
N ( o t − 1 ) = { n t ( i ) } i = 1 d ← argmax d R π ( { n t ( b ) } b = 1 B ∣ o t − 1 ) , ( 5 ) \begin{aligned} \mathcal{N}(o_{t-1}) &= \left\{ n_t^{(i)} \right\}_{i=1}^d \leftarrow \operatorname*{argmax}_d \\ R_\pi \left( \left\{ n_t^{(b)} \right\}_{b=1}^B \mid o_{t-1} \right), \quad (5) \end{aligned} N(ot−1)Rπ({nt(b)}b=1B∣ot−1),(5)={nt(i)}i=1d←dargmax
其中 ( R_\pi ) 表示评估每个候选扩展质量的奖励函数,( d < B )。例如,如果当前状态是 “user” 表的部分查询,模型可能会生成 “SELECT user.id” 或 “SELECT user.name” 并根据其与输入问题的语义相关性选择候选。
模拟和回溯。扩展节点后,模型为所有新添加的子节点分配分数,如方程 (2) 和 (6) 所示。根据方程 (7),选择得分最高的节点进行进一步模拟,直到达到最终状态,从而生成完整的SQL查询生成轨迹。
Q ( o l ( n ) ) = δ R π ( n l ∣ o l − 1 ( n ) ) + ( 1 − δ ) R π ( S ∣ Q ) , (6) Q(o_l^{(n)}) = \delta R_\pi(n_l \mid o_{l-1}^{(n)}) + (1 - \delta)R_\pi(\mathcal{S} \mid \mathcal{Q}),\tag{6} Q(ol(n))=δRπ(nl∣ol−1(n))+(1−δ)Rπ(S∣Q),(6)
其中 ( \delta ) 是一个介于 (0, 1) 之间的参数,用于平衡过程得分和总得分,通常设为0.5。算法通过更新从叶节点到根节点的所有节点的Q值进行回溯。
Q ( o t ( n ) ) = max j = 1 n ( ∑ i = l t Q ( o i ( j ) ) l − t ) (7) Q(o_t^{(n)}) = \max_{j=1}^n \left( \frac{\sum_{i=l}^t Q(o_i^{(j)})}{l - t} \right) \tag{7} Q(ot(n))=j=1maxn(l−t∑i=ltQ(oi(j)))(7)
一个大规模数据库支持的文本到SQL基准. In 第37届神经信息处理系统国际会议论文集, NIPS '23, 纽约州红钩, 美国. Curran Associates Inc.
阿联酋阿布扎比, 2022年12月7-11日, 页602–631. 计算语言学协会.
参考 Paper:https://arxiv.org/pdf/2502.11741