The training pipelines of InstructGPT and ChatGPT are essentially the same; ChatGPT is an improved version of InstructGPT. For example, InstructGPT is trained on top of GPT-3, while ChatGPT is trained on top of GPT-3.5.
Before studying the InstructGPT paper, if you first want a basic overview of LLMs or the RLHF pipeline, check out the LLM + Text2SQL summary project from the "eosphoros-ai" organization (the open-source community behind DB-GPT, an open-source project that gained 8000+ stars this year): https://github.com/eosphoros-ai/Awesome-Text2SQL. It also collects papers on SFT fine-tuning (LoRA, QLoRA, P-Tuning, etc.) and on RLHF-related methods (RLHF, RRHF, RLTF, RRTF, RLAIF, and so on). It currently has 300+ stars and is continuously updated; you are welcome to take a look and give it a star!
Background
Making language models bigger does not inherently make them better at following a user's intent. For example, large language models can generate outputs that are untruthful, toxic, or simply not helpful to the user.
Contributions / Method
In this paper, the authors show a way to align language models with user intent on a broad range of tasks by fine-tuning them with human feedback.
Results: with roughly 100x fewer parameters, performance is comparable. Truthfulness ⬆️, toxicity ⬇️, accuracy on public NLP benchmarks ⬇️ (slightly).
The results are striking:
Conclusion:
Although InstructGPT still makes some simple mistakes, the results show that fine-tuning with human feedback is a promising direction for aligning language models with human intent.
The paper also runs experiments on public NLP datasets: the InstructGPT models suffer performance regressions there, the so-called "alignment tax", possibly because training focuses on API prompts.
The paper also reports qualitative results: the InstructGPT models generalize quite well; see the original paper for the specific experiments.
Human preferences, human values --> in practice, the preferences of the labelers, of the OpenAI researchers, and of the API users.
Basic Background Knowledge
This figure is a classic and is the reference paradigm for RLHF in practice. RLHF is mainly divided into 3 stages: (1) supervised fine-tuning (SFT), (2) reward model (RM) training, and (3) reinforcement learning against the RM with PPO.
What exactly is the PPO algorithm? (Leaving a placeholder here to expand on later; a minimal statement of its core objective is sketched right after the reference below.)
For details, see the paper: Schulman, J., Wolski, F., Dhariwal, P., Radford, A., and Klimov, O. (2017). Proximal policy optimization algorithms. arXiv preprint arXiv:1707.06347.
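For reference, the heart of that paper is the clipped surrogate objective; stated minimally (in the PPO paper's own notation, not InstructGPT's):

$$
L^{\mathrm{CLIP}}(\theta) = \mathbb{E}_t\left[\min\left(r_t(\theta)\,\hat{A}_t,\ \mathrm{clip}\big(r_t(\theta),\,1-\epsilon,\,1+\epsilon\big)\,\hat{A}_t\right)\right],
\qquad
r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\mathrm{old}}}(a_t \mid s_t)}
$$

where r_t(θ) is the probability ratio between the new and the old policy, Â_t is an advantage estimate, and ε is the clip range. Clipping the ratio keeps each policy update close to the policy that collected the data, which is also the spirit of the KL penalty used later in InstructGPT's RL stage.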
To make this more concrete, take a specific small task such as Text2SQL as an example; the constructed dataset looks like the following:
Source: the Zhihu article Text-to-SQL小白入门(八)RLAIF论文:AI代替人类反馈的强化学习.
{"prompt": "I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n\"\n##Instruction:\ndepartment_management contains tables such as department, head, management. Table department has columns such as Department_ID, Name, Creation, Ranking, Budget_in_Billions, Num_Employees. Department_ID is the primary key.\nTable head has columns such as head_ID, name, born_state, age. head_ID is the primary key.\nTable management has columns such as department_ID, head_ID, temporary_acting. department_ID is the primary key.\nThe head_ID of management is the foreign key of head_ID of head.\nThe department_ID of management is the foreign key of Department_ID of department.\n###Input:\nHow many heads of the departments are older than 56 ?\n\n###Response:","output": "SELECT count(*) FROM head WHERE age > 56"}
The parameters are as follows:
The final SFT model is selected according to the RM score on the validation set.
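To make the SFT stage concrete, here is a minimal training sketch for samples in the format above, assuming a Hugging Face causal LM. The base model name, the file name sft_data.jsonl, and the hyperparameters are illustrative placeholders, not the paper's configuration (the paper fine-tunes GPT-3 for 16 epochs with cosine learning-rate decay):

```python
# Minimal SFT sketch (illustrative only, not the paper's exact setup).
# Assumes JSONL samples in the {"prompt": ..., "output": ...} format shown above.
import json
import torch
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"  # placeholder base model; InstructGPT fine-tunes GPT-3, which is not public
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_name)

def encode(sample, max_len=512):
    # Concatenate prompt and output; only the output tokens contribute to the loss.
    prompt_ids = tokenizer(sample["prompt"], truncation=True, max_length=max_len)["input_ids"]
    output_ids = tokenizer(sample["output"] + tokenizer.eos_token)["input_ids"]
    input_ids = (prompt_ids + output_ids)[:max_len]
    labels = ([-100] * len(prompt_ids) + output_ids)[:max_len]  # -100 masks the prompt tokens
    return {"input_ids": input_ids, "labels": labels}

def collate(batch):
    # Right-pad every sequence to the longest one in the batch.
    max_len = max(len(b["input_ids"]) for b in batch)
    pad = tokenizer.pad_token_id
    input_ids, labels, attention_mask = [], [], []
    for b in batch:
        n = len(b["input_ids"])
        input_ids.append(b["input_ids"] + [pad] * (max_len - n))
        labels.append(b["labels"] + [-100] * (max_len - n))
        attention_mask.append([1] * n + [0] * (max_len - n))
    return {"input_ids": torch.tensor(input_ids),
            "attention_mask": torch.tensor(attention_mask),
            "labels": torch.tensor(labels)}

# "sft_data.jsonl" is a hypothetical file containing samples like the one above.
data = [encode(json.loads(line)) for line in open("sft_data.jsonl")]
loader = DataLoader(data, batch_size=4, shuffle=True, collate_fn=collate)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

model.train()
for epoch in range(3):
    for batch in loader:
        loss = model(**batch).loss  # standard causal-LM cross-entropy on the unmasked labels
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
```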
A surprising point:
Similarly, taking the Text2SQL task as an example, the constructed dataset looks like the following:
{"prompt": "I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n\"\n##Instruction:\ndepartment_management contains tables such as department, head, management. Table department has columns such as Department_ID, Name, Creation, Ranking, Budget_in_Billions, Num_Employees. Department_ID is the primary key.\nTable head has columns such as head_ID, name, born_state, age. head_ID is the primary key.\nTable management has columns such as department_ID, head_ID, temporary_acting. department_ID is the primary key.\nThe head_ID of management is the foreign key of head_ID of head.\nThe department_ID of management is the foreign key of Department_ID of department.\n###Input:\nHow many heads of the departments are older than 56 ?\n\n###Response:","chosen": "SELECT count(*) FROM head WHERE age > 56","rejected":"SELECT COUNT(head_name) FROM head WHERE age > 56;"}
Why is a 6B model chosen for the RM, rather than 175B?
Loss function:
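The pairwise ranking loss used for the reward model in the paper is:

$$
\mathrm{loss}(\theta) = -\frac{1}{\binom{K}{2}}\,
\mathbb{E}_{(x,\,y_w,\,y_l)\sim D}\Big[\log\Big(\sigma\big(r_\theta(x, y_w) - r_\theta(x, y_l)\big)\Big)\Big]
$$

where r_θ(x, y) is the scalar reward for prompt x and response y, y_w is the preferred (chosen) response, y_l is the rejected one, and all C(K, 2) = K(K-1)/2 comparisons among the K responses labeled for a single prompt are trained together as one batch element.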
Finally, the reward is normalized so that the mean reward is 0.
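Below is a minimal sketch of this loss in code; the function and tensor names are illustrative, and the actual RM in the paper is a 6B model whose final unembedding layer is replaced by a head that outputs a scalar reward:

```python
# Pairwise ranking loss for a reward model (illustrative sketch, not OpenAI's code).
import torch
import torch.nn.functional as F

def rm_pairwise_loss(reward_chosen: torch.Tensor, reward_rejected: torch.Tensor) -> torch.Tensor:
    """reward_chosen / reward_rejected: scalar RM scores for the preferred and
    rejected responses to the same prompt, shape (batch,)."""
    # -log(sigmoid(r_w - r_l)), averaged over the batch of comparison pairs
    return -F.logsigmoid(reward_chosen - reward_rejected).mean()

# Example with dummy scores:
r_w = torch.tensor([1.2, 0.3, 0.8])   # rewards for the "chosen" responses
r_l = torch.tensor([0.5, 0.4, -0.1])  # rewards for the "rejected" responses
print(rm_pairwise_loss(r_w, r_l))
```

After training, a constant bias can be subtracted from the RM outputs so that the mean reward is 0, matching the normalization mentioned above.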
The data format is the same as in the SFT stage.
{"prompt": "I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n\"\n##Instruction:\ndepartment_management contains tables such as department, head, management. Table department has columns such as Department_ID, Name, Creation, Ranking, Budget_in_Billions, Num_Employees. Department_ID is the primary key.\nTable head has columns such as head_ID, name, born_state, age. head_ID is the primary key.\nTable management has columns such as department_ID, head_ID, temporary_acting. department_ID is the primary key.\nThe head_ID of management is the foreign key of head_ID of head.\nThe department_ID of management is the foreign key of Department_ID of department.\n###Input:\nHow many heads of the departments are older than 56 ?\n\n###Response:","output": "SELECT count(*) FROM head WHERE age > 56"}
1. The RM and RL stages can be iterated over multiple rounds; this builds more data and brings the model closer and closer to human preferences.
2. In practice, most of the comparison data comes from the SFT-stage data, and a smaller portion comes from comparisons of the RL model's outputs.
Why is π used in the notation? And why does the objective contain a division (a ratio)? These come down to basic reinforcement learning concepts.
The process of going from a state (State) to an action (Action) is called a policy (Policy), usually denoted by π (you can think of it as a function). In the RL stage we need to learn the mapping a = π(s), or equivalently π(a|s), where a is the action and s is the state.
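For reference, the objective maximized in the RL stage of the paper (the PPO-ptx variant) is:

$$
\mathrm{objective}(\phi) =
\mathbb{E}_{(x,y)\sim D_{\pi_\phi^{\mathrm{RL}}}}
\left[ r_\theta(x, y) - \beta \log\frac{\pi_\phi^{\mathrm{RL}}(y \mid x)}{\pi^{\mathrm{SFT}}(y \mid x)} \right]
+ \gamma\, \mathbb{E}_{x\sim D_{\mathrm{pretrain}}}\left[ \log \pi_\phi^{\mathrm{RL}}(x) \right]
$$

Here π_φ^RL is the policy being trained, π^SFT is the SFT model it is initialized from, r_θ is the reward model, β controls the per-token KL penalty that keeps the RL policy from drifting too far from the SFT model (this is where the division, i.e. the ratio of the two policies, comes from), and γ weights the pretraining-gradient term (γ = 0 recovers the plain PPO variant).

In implementation terms, the KL penalty is usually folded into the per-token reward. The sketch below follows common open-source RLHF practice rather than OpenAI's exact code; the function and argument names are hypothetical:

```python
# Per-token reward with KL penalty (sketch; common RLHF practice, not the paper's exact code).
import torch

def penalized_rewards(rm_score: torch.Tensor,
                      logprobs_rl: torch.Tensor,
                      logprobs_sft: torch.Tensor,
                      beta: float = 0.02) -> torch.Tensor:
    """rm_score: (batch,) scalar RM reward for each complete response.
    logprobs_rl / logprobs_sft: (batch, seq_len) per-token log-probs of the sampled
    response under the current RL policy and the frozen SFT policy."""
    kl = logprobs_rl - logprobs_sft   # per-token KL estimate, log(pi_RL / pi_SFT)
    rewards = -beta * kl              # KL penalty applied at every token
    rewards[:, -1] += rm_score        # RM score added only at the final token
    return rewards                    # fed to PPO as the per-token reward signal
```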
I once attended a lecture by a university professor who made an interesting point: many of the innovative models, including the Transformer, mostly came out of Google research, so why is OpenAI stronger than Google at large models? Answer: because OpenAI does a very good job of data cleaning and data-quality control. So data is extremely important!
For training the final InstructGPT models in this paper:
The prompt dataset mainly comes from the OpenAI API: users interact with the API and these interactions are collected (users are informed at the time of use that the data will be collected). The API at this point was serving early InstructGPT models, and data from customers using the API in production was not used.
The distribution of the API data is as follows; there are 9 main categories.
So the question arises: where did the training data for those early InstructGPT models come from?
Some processing was applied to the data collected from the API:
Mainly for training the early InstructGPT models:
Labelers were asked to hand-write the following three types of prompts: plain prompts (arbitrary tasks, for diversity), few-shot prompts (an instruction plus several query/response pairs), and user-based prompts (use cases drawn from waitlist applications to the OpenAI API).
More than 96% of the data is in English; the other 20 languages, such as Chinese, French, and Spanish, together account for less than 4%. This likely means that when InstructGPT/ChatGPT generates text in other languages, the results are noticeably worse than in English.
The paper also includes a large amount of appendix material on data details, such as labeler demographics, data examples, and annotation guidelines; see the original paper. It must be said that OpenAI's data work is thorough: 20 pages of main text plus 48 pages of appendix, 68 pages in total.
Text-to-SQL小白入门(一)综述文章学习
Text-to-SQL小白入门(二)Transformer学习
Text-to-SQL小白入门(三)IRNet:引入中间表示SemQL
Text-to-SQL小白入门(四)指令进化大模型WizardLM
Text-to-SQL小白入门(五)开源代码大模型Code Llama
Text-to-SQL小白入门(六)Awesome-Text2SQL项目介绍
Text-to-SQL小白入门(七)PanGu-Coder2论文——RRTF
Text-to-SQL小白入门(八)RLAIF论文:AI代替人类反馈的强化学习