早些时候微软发布了遵从InstructGPT训练逻辑的训练框架DeepSpeed-Chat,旨在通过良好的DeepSpeed生态降低类ChatGPT模型昂贵的训练成本,为了能更直接地理解有关技术原理,我对其中实现训练相关的代码进行了详细剖析,考虑到目前还没有太多相关文章对此进行过深入介绍,因此我将在本博客中探讨这个框架的实现细节,以帮助有需要的人能更好地理解和使用它。另外,我也非常欢迎大家在评论区分享出自己对这个框架的看法以及使用经验,或是提出对本文的建议。
框架源码地址:https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat
如果你只是想要知道如何使用自定义的数据集快速上手DeepSpeed-Chat训练,该系列内容对你来说可能过于繁杂,那么完全可以期待一下我后续将要更新的快速上手引导(已经在新建文件夹了哈哈)。
如果你只对具体的训练细节感兴趣,该篇主要讲述的数据集内容可能不是你所想要了解的,请直接跳至【中篇】进行阅读。
本文将根据DeepSpeed-Chat的数据集设计以及三个训练阶段(可分别简称phase1、phase2、phase3)共四个部分,将主要内容大体划分为四个版块、三个篇章,而每个版块都会以动态的时序图视角展示一套完整的工作流,然后深入详解其中与InstructGPT所述理论相关的代码实现,最后还将针对相关阶段涉及的一些具体问题加以阐述,以此完成对一个阶段源码的解析。在阅读过程中如果发现某些环节使你产生困惑,不妨跳转至【版块相关问题】,或许可以从中获得启发,如果你无法通过该部分找到所需答案,请随时留下你的评论,以便进行共同交流。
此外,本文的重点在于源码解析,其中所涉及的ChatGPT背景知识、原理等将不再做过多推导式叙述。倘若你有原理基础,那么这篇文章肯定能够让你对各种相关原理如具体的RM结构、RM的训练方式、PPO迭代的具体顺序等实现细节拥有更加深刻的理解,并获得一定的实践启发;但假如你刚开始了解这项技术,也不必担心,我会使用尽可能简练的描述、尽可能直球的例子来对相应的部分进行说明。
本篇为上中下三篇章中的【上篇】,主要针对DeepSpeed的数据集管理进行介绍。DeepSpeed提供了良好的数据流管道对数据集进行了规范化处理和标准化操作,用户在了解其中的细节后可以更加高效地实现模型训练。
现有的训练框架多数都是基于InstructGPT论文中所介绍的pipeline来实现,但一些更具体的细节,比如数据集的处理、奖励取值设计等,论文中没有进一步阐述,故而不同框架在某些细节的实现上会存在些许差异,因此在开始尝试使用DeepSpeed-Chat前我认为还是有必要了解一些框架内部既定的“范式”,这对后续理解某些具体细节将会有所帮助。
InstructGPT提出了大型问答模型的训练范式,分别为有监督微调训练、基于人类偏好的奖励模型训练、基于人类偏好的强化学习训练,最终模型将因此具备“根据用户输入,生成符合人类偏好答复”的能力。
阶段 | 相关模型 | 赋能 |
---|---|---|
0 | 具备基本生成能力的基座模型(通常为CausalLM) | - |
1 | 有监督微调模型(SFT) | 使用“prompt-response”数据(通俗含义上的“问答对”)对基座进行训练,基座将获得“根据指令生成出对应响应”的能力。 |
2 | 奖励模型(RM) | 使用具有偏好评价的“prompt-response”数据(问答对)以及排序目标对预选模型进行训练,将得到具备“为指令数据做出人类偏好分值评估”能力的奖励模型。 |
3 | SFT、Actor、RM、Critic | 使用“prompt”数据(通俗含义上的“问句”),以第1、第2阶段训练得到的模型作为基线进行强化学习训练,最终得到具备“根据用户输入,生成符合人类偏好答复”能力的Actor模型。 |
上述各阶段训练的一个关键就是数据集的设计,每个阶段需要使用的数据格式都不尽相同,DeepSpeed-Chat根据其中存在的共性,将数据集设计成了相对统一的形式,由Dataset类进行统一管理,然后再根据不同的训练阶段细化处理出相应的数据格式。
数据格式名称 | 说明 | 样例 |
---|---|---|
prompt | 对当前情境的描述,为模型生成提供指令输入信息,可以理解为通俗含义上的“问句”,适用于phase3。 | "Human: Please tell me about Microsoft in a few sentence? Assistant: " |
response/answer | 对上下文prompt的响应、回答、应答,可以理解为通俗含义上的“答句”。 | “Microsoft is a software company that develops, licenses, and supports software products,including Windows, Office, and Windows Phone. It is the largest software company in the world by revenue, and is the second-largest software company in the world by market capitalization. Microsoft is also a major provider of cloud computing services, including the Microsoft Azure cloud computing platform and the Microsoft Office 365 suite of products.” |
chosen | 应答中的一种,人类所偏好的应答。 | “Microsoft is a software company that develops, licenses, and supports software products,including Windows, Office, and Windows Phone. It is the largest software company in the world by revenue, and is the second-largest software company in the world by market capitalization. Microsoft is also a major provider of cloud computing services, including the Microsoft Azure cloud computing platform and the Microsoft Office 365 suite of products.” |
rejected | 应答中的一种,人类所排斥的应答。 | “I’m not sure what you mean.” |
conversation | 完整对话,由prompt衔接应答response得到。 | “Human: Please tell me about Microsoft in a few sentence? Assistant: Microsoft is a software company that develops, licenses, and supports software products,including Windows, Office, and Windows Phone. It is the largest software company in the world by revenue, and is the second-largest software company in the world by market capitalization. Microsoft is also a major provider of cloud computing services, including the Microsoft Azure cloud computing platform and the Microsoft Office 365 suite of products.” |
chosen_sentence | 人类偏好的完整对话,由prompt衔接偏好应答chosen得到,适用于phase1和phase2。 | “Human: Please tell me about Microsoft in a few sentence? Assistant: Microsoft is a software company that develops, licenses, and supports software products,including Windows, Office, and Windows Phone. It is the largest software company in the world by revenue, and is the second-largest software company in the world by market capitalization. Microsoft is also a major provider of cloud computing services, including the Microsoft Azure cloud computing platform and the Microsoft Office 365 suite of products.” |
reject_sentence | 人类排斥的完整对话,由prompt衔接排斥应答rejected得到,适用于phase2。 | “Human: Please tell me about Microsoft in a few sentence? Assistant: I’m not sure what you mean.” |
unsup | 无监督语料,符合自然语言要素的文本,适用于自回归语言模型的无监督训练。 | "Wikipedia is a free online encyclopedia that is maintained and edited collaboratively by volunteers from around the world. It contains articles on a wide range of topics, from history and science to popular culture and current events. Anyone can create, edit, or contribute to Wikipedia articles, making it an open and decentralized platform for knowledge sharing and dissemination. One of the key features of Wikipedia is its commitment to neutrality, with contributors striving to present information in an objective and unbiased manner. " |
DeepSpeed-Chat设计的数据格式是直接服务于阶段训练的:
并且假如你确定要进行完整的三阶段训练,DeepSpeed-Chat鼓励使用可以同时适应三阶段的数据集,即同时具备prompt、chosen、rejected的数据集,他们认为在不同阶段中,使用不同的数据集则面临着数据分布差异问题,尤其是第二、第三阶段使用与第一阶段不同的数据集,这将可能导致第二、第三阶段训练出的模型质量变差。
One thing to note is that: If you plan to only do step 1 SFT, adding more single-response datasets is definitely beneficial. However, if you do plan to do steps 2 and 3, then adding too many single-response datasets during SFT could backfire: these data could be different from the data used for steps 2/3, generating different distributions which could cause training instability/worse model quality during step 2/3. That is part of the reason why we focused on trying the datasets with two responses and the preference, and always split a dataset into all 3 steps.
基于上述情况(三阶段使用同一数据集)考虑,DeepSpeed-Chat在构建数据集时提供了一个叫做“data_split”的传参,当你使用一个适用三阶段的数据集时,通过对该传参进行列表赋值,可以对三阶段数据比例进行设置,如“[6,2,2]”,这表示对于当前数据集,将分配全量数据的 6 6 + 2 + 2 \frac{6}{6+2+2} 6+2+26提供给第一阶段用于训练、验证,同理,分配全量的 2 6 + 2 + 2 \frac{2}{6+2+2} 6+2+22提供给第二阶段用于训练、验证,分配全量的 2 6 + 2 + 2 \frac{2}{6+2+2} 6+2+22提供给第三阶段用于训练、验证。
在此简单讲述UML时序图的元素含义:
- 箭头表示信息传递:实线表示调用,虚线表示返回;
- alt表示假设分支,其后方“[]”中的内容表示“条件”;
- loop表示循环;
- 淡蓝色区域即为高亮部分。
总的来说,在训练的主流程代码main.py中,供DataLoader调用的Dataset(PromptDataset)将通过函数“create_prompt_dataset()”进行获取,其中将涉及到预存机制:如果已经存在预存数据,则直接通过torch.load()进行载入得到Dataset;如果未存在预存数据,则需要进行一系列预处理和预存的操作。获取Dataset的过程大致为(“括号序号”与UML时序图的“圈序号”对应):
上述过程存在几个值得关注的地方(即文字描述加粗、UML时序图高亮的部分):
以下将对两个部分的源码进行详细介绍。
UML时序图(3-6)
# applications/DeepSpeed-Chat/training/utils/data/raw_datasets.py
class PromptRawDataset(object):
def __init__(self, output_path, seed, local_rank, dataset_name):
"""
初始化
:param output_path: 输出缓存路径。
:param seed: 随机种子。
:param local_rank: 当前进程序号。
:param dataset_name: 数据集名称,后续指定所需读取的数据集时将以名称为准。
"""
self.dataset_name = dataset_name
self.dataset_clean_name = dataset_clean_name
self.output_path = output_path
self.seed = seed
self.local_rank = local_rank
# load_dataset源自datasets库,该方法支持读取csv/json/text等多种文件格式的数据
self.raw_datasets = load_dataset(dataset_name)
def get_train_data(self):
"""
获取训练集
:return: dataset数据格式
"""
return
def get_eval_data(self):
"""
获取验证集
:return: dataset数据格式
"""
return
# The prompt should be in the format of: " Human: " + actual_prompt_sentence + " Assistant:"
def get_prompt(self, sample):
"""
从dataset的sample(单个样本)中获取prompt。
:param sample: dataset的元素
:return: prompt。prompt的格式必须为 "Human: {} Assistant:".format(actual_prompt_sentence)
"""
return
# The chosen response should be in the format of: " " + actual_response_sentence
def get_chosen(self, sample):
"""
从dataset的sample(单个样本)中获取chosen。chosen实际上是“chosen response”,指的是“精选的回复”,即人类所偏好的、高分的回复。
:param sample: dataset的元素
:return: chosen。chosen的格式必须为" {}".format(actual_response_sentence)
"""
return
# The rejected response should be in the format of: " " + actual_response_sentence
# If the dataset does not have rejected response, return None
def get_rejected(self, sample):
"""
从dataset的sample(单个样本)中获取rejected。rejected实际上是“rejected response”,指的是“排斥的回复”,即人类所厌恶的、低分的回复。
:param sample: dataset的元素
:return: rejected。如果数据集中不存在则返回为None;如果存在,则其格式必须为 " {}".format(actual_response_sentence)
"""
return
def get_prompt_and_chosen(self, sample):
"""
从dataset的sample(单个样本)中获取prompt与chosen。
:param sample: dataset的元素
:return: prompt与chosen的衔接。同样需要满足上述格式要求,即衔接结果为
"Human: {} Assistant: {}".format(actual_prompt_sentence, actual_response_sentence)
"""
return
def get_prompt_and_rejected(self, sample):
"""
从dataset的sample(单个样本)中获取prompt与rejected。
:param sample: dataset的元素
:return: prompt与rejected的衔接。同样需要满足上述格式要求,即衔接结果为
"Human: {} Assistant: {}".format(actual_prompt_sentence, actual_response_sentence)
"""
return
自定义的数据集可以继承自上述的“PromptRawDataset”类,例如class CustomDataset(PromptRawDataset)
,然后重写其中的self.dataset_name
及self.dataset_clean_name
,此处的“dataset_name”即为传参指定数据集时所要填写的名称,例如self.dataset_name=custom
,在设置传参--data_path=‘custom’
时,将会读取到CustomDataset的数据用于进行训练。另外其中的get_train_data()
等实例函数也需要进行重写,主要是实现将原始数据处理成注释所提及格式。
定义好自定义PromptRawDataset后,还需要对其进行“注册”,具体可见下述代码块。
# applications/DeepSpeed-Chat/training/utils/data/data_utils.py
def get_raw_dataset(dataset_name, output_path, seed, local_rank):
if "Dahoas/rm-static" in dataset_name:
return raw_datasets.DahoasRmstaticDataset(output_path, seed,
local_rank, dataset_name)
elif "Dahoas/full-hh-rlhf" in dataset_name:
return raw_datasets.DahoasFullhhrlhfDataset(output_path, seed,
local_rank, dataset_name)
···
"""
将自定义的PromptRawDataset在此处进行注册
届时在传参“--data_path”中赋值“custom”即可读取到相应的数据集
"""
elif "custom" in dataset_name:
return raw_datasets.CustomDataset(output_path, seed,
local_rank, dataset_name)
else:
raise RuntimeError(
f"We do not have configs for dataset {dataset_name}, but you can add it by yourself in raw_datasets.py."
)
至此完成自定义数据集的设置。理论上来说,只要实例函数能完全按照注释要求对原始数据进行处理,那么后续的数据流基本也无需再进行任何额外修改也能顺畅运行了。
UML时序图(10-12)
这部分处理得到的数据形式,基本接近于数据传入阶段模型前的最终形式,因此通过理解这部分的数据处理过程,可以直接了解到模型所需要的输入形式。
# applications/DeepSpeed-Chat/training/utils/data/data_utils.py
def create_dataset_split(current_dataset, raw_dataset, train_phase, tokenizer,
end_of_conversation_token, max_seq_len):
"""
将根据不同的阶段(train_phase)对数据集进行处理,主要是调用原先在PromptRawDataset类中定义的实例函数来实现。
"""
prompt_dataset = []
chosen_dataset = []
reject_dataset = []
if train_phase == 1:
# 因为phase1只需要用到chosen数据,所以只取chosen进行处理
for i, tmp_data in enumerate(current_dataset):
# 获取chosen_sentence,即是将prompt和chosen拼接起来形成完整对话
# 具体样例可参照“数据格式基本概念”中的样例
chosen_sentence = raw_dataset.get_prompt_and_chosen(
tmp_data)
if chosen_sentence is not None:
# 在对话末尾加入对话终止符
chosen_sentence += end_of_conversation_token
# 使用tokenizer处理chosen_sentence,采取截断truncation
chosen_token = tokenizer(chosen_sentence,
max_length=max_seq_len,
padding="max_length",
truncation=True,
return_tensors="pt")
# 去掉batch维度
chosen_token["input_ids"] = chosen_token["input_ids"].squeeze(
0)
chosen_token["attention_mask"] = chosen_token[
"attention_mask"].squeeze(0)
# 存储tokenize结果至列表chosen_dataset
chosen_dataset.append(chosen_token)
elif train_phase == 2:
# phase2需要用到chosen_sentence和reject_sentence
# 所以需要对两者都进行处理
for i, tmp_data in enumerate(current_dataset):
# 获取chosen_sentence,即是将prompt和chosen拼接起来形成完整对话
# 具体样例可参照“数据格式基本概念”中的样例
chosen_sentence = raw_dataset.get_prompt_and_chosen(
tmp_data) # the accept response
# 获取reject_sentence,即是将prompt和rejeced拼接起来形成完整对话
# 具体样例可参照“数据格式基本概念”中的样例
reject_sentence = raw_dataset.get_prompt_and_rejected(
tmp_data)
if chosen_sentence is not None and reject_sentence is not None:
# 在对话末尾加入对话终止符
chosen_sentence += end_of_conversation_token # the accept response
reject_sentence += end_of_conversation_token
# 使用tokenizer处理,采取截断truncation
chosen_token = tokenizer(chosen_sentence,
max_length=max_seq_len,
padding="max_length",
truncation=True,
return_tensors="pt")
reject_token = tokenizer(reject_sentence,
max_length=max_seq_len,
padding="max_length",
truncation=True,
return_tensors="pt")
chosen_token["input_ids"] = chosen_token["input_ids"]
chosen_token["attention_mask"] = chosen_token["attention_mask"]
# 存储tokenize结果至列表chosen_dataset
chosen_dataset.append(chosen_token)
reject_token["input_ids"] = reject_token["input_ids"]
reject_token["attention_mask"] = reject_token["attention_mask"]
# 存储tokenize结果至列表reject_dataset
reject_dataset.append(reject_token)
elif train_phase == 3:
# phase3用到prompt,prompt将被用来生成经验数据
for i, tmp_data in enumerate(current_dataset):
# 直接获取prompt
# 具体样例可参照“数据格式基本概念”中的样例
prompt = raw_dataset.get_prompt(tmp_data)
if prompt is not None:
prompt_token = tokenizer(prompt, return_tensors="pt")
prompt_token["input_ids"] = prompt_token["input_ids"]
prompt_token["attention_mask"] = prompt_token["attention_mask"]
for key_word in ["input_ids", "attention_mask"]:
# 获取当前文本token的实际长度
length = prompt_token[key_word].size()[-1]
# phase3此处的max_seq_len其实是max_prompt_len,默认只有256
if length > max_seq_len:
# 如果当前文本token长度比max_prompt_len还长
# 那么就截断文本前面的部分,保留后面max_prompt_len长度的部分文本
# 然后将token进行flip(翻转/倒序),之后在data_collator中再将其flip回来
y = prompt_token[key_word].squeeze(0)[length -
(max_seq_len -
1):].flip(0)
else:
# 将token进行flip(翻转/倒序),之后在data_collator中再将其flip回来
y = prompt_token[key_word].squeeze(0).flip(0)
prompt_token[key_word] = y
prompt_dataset.append(prompt_token)
# 返回PromptDataset实例,该实例相当于torch中的Dataset,可供DataLoader调用
return PromptDataset(prompt_dataset, chosen_dataset, reject_dataset,
tokenizer.pad_token_id, train_phase)
此处的处理部分很大程度依赖于原先所定义的PromptRawDataset实例函数,由此可见,只要正确编写实例函数,后续过程基本也不会出现什么问题。流程大致就是取出对应阶段所需的格式数据,然后使用tokenizer进行处理,综上所述:
max_prompt_len = 5
pad_token_id = 0
prompt_token_ids = [233, 11, 22]
# padding位于后侧 ×
prompt_token_ids.padding() = [233, 11, 22, 0, 0]
prompt_token_ids.flip(0) = [22, 11, 233]
prompt_token_ids.flip(0).padding() = [22, 11, 233, 0, 0]
# padding位于前侧 √
prompt_token_ids.flip(0).padding().flip(0) = [0, 0, 233, 11, 22]
关于训练阶段的具体内容可见【中篇】【下篇】。