InstructGPT高效实践——【DeepSpeed-Chat】源码详解(1/3):基本概念、数据集管理

目录

  • 前言
  • 0 基本概念与数据集设计
    • 0.1 InstructGPT提出的训练三段式
    • 0.2 DeepSpeed-Chat的数据集设计
      • 0.2.1 数据格式基本概念
      • 0.2.2 DeepSpeed-Chat的数据读取流
      • 0.2.3 关键代码详解
        • 0.2.3.1 自定义PromptRawDataset类
        • 0.2.3.2 阶段数据集处理过程
    • 0.3 版块相关问题
  • 后续

前言

InstructGPT高效实践——【DeepSpeed-Chat】源码详解(1/3):基本概念、数据集管理_第1张图片

  早些时候微软发布了遵从InstructGPT训练逻辑的训练框架DeepSpeed-Chat,旨在通过良好的DeepSpeed生态降低类ChatGPT模型昂贵的训练成本,为了能更直接地理解有关技术原理,我对其中实现训练相关的代码进行了详细剖析,考虑到目前还没有太多相关文章对此进行过深入介绍,因此我将在本博客中探讨这个框架的实现细节,以帮助有需要的人能更好地理解和使用它。另外,我也非常欢迎大家在评论区分享出自己对这个框架的看法以及使用经验,或是提出对本文的建议。

框架源码地址:https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat

  如果你只是想要知道如何使用自定义的数据集快速上手DeepSpeed-Chat训练,该系列内容对你来说可能过于繁杂,那么完全可以期待一下我后续将要更新的快速上手引导(已经在新建文件夹了哈哈)。
  如果你只对具体的训练细节感兴趣,该篇主要讲述的数据集内容可能不是你所想要了解的,请直接跳至【中篇】进行阅读。

  本文将根据DeepSpeed-Chat的数据集设计以及三个训练阶段(可分别简称phase1、phase2、phase3)共四个部分,将主要内容大体划分为四个版块、三个篇章,而每个版块都会以动态的时序图视角展示一套完整的工作流,然后深入详解其中与InstructGPT所述理论相关的代码实现,最后还将针对相关阶段涉及的一些具体问题加以阐述,以此完成对一个阶段源码的解析。在阅读过程中如果发现某些环节使你产生困惑,不妨跳转至【版块相关问题】,或许可以从中获得启发,如果你无法通过该部分找到所需答案,请随时留下你的评论,以便进行共同交流。

  此外,本文的重点在于源码解析,其中所涉及的ChatGPT背景知识、原理等将不再做过多推导式叙述。倘若你有原理基础,那么这篇文章肯定能够让你对各种相关原理如具体的RM结构、RM的训练方式、PPO迭代的具体顺序等实现细节拥有更加深刻的理解,并获得一定的实践启发;但假如你刚开始了解这项技术,也不必担心,我会使用尽可能简练的描述、尽可能直球的例子来对相应的部分进行说明。

  本篇为上中下三篇章中的【上篇】,主要针对DeepSpeed的数据集管理进行介绍。DeepSpeed提供了良好的数据流管道对数据集进行了规范化处理和标准化操作,用户在了解其中的细节后可以更加高效地实现模型训练。

0 基本概念与数据集设计

  现有的训练框架多数都是基于InstructGPT论文中所介绍的pipeline来实现,但一些更具体的细节,比如数据集的处理、奖励取值设计等,论文中没有进一步阐述,故而不同框架在某些细节的实现上会存在些许差异,因此在开始尝试使用DeepSpeed-Chat前我认为还是有必要了解一些框架内部既定的“范式”,这对后续理解某些具体细节将会有所帮助。

0.1 InstructGPT提出的训练三段式

  InstructGPT提出了大型问答模型的训练范式,分别为有监督微调训练、基于人类偏好的奖励模型训练、基于人类偏好的强化学习训练,最终模型将因此具备“根据用户输入,生成符合人类偏好答复”的能力。

阶段 相关模型 赋能
0 具备基本生成能力的基座模型(通常为CausalLM) -
1 有监督微调模型(SFT) 使用“prompt-response”数据(通俗含义上的“问答对”)对基座进行训练,基座将获得“根据指令生成出对应响应”的能力。
2 奖励模型(RM) 使用具有偏好评价的“prompt-response”数据(问答对)以及排序目标对预选模型进行训练,将得到具备“为指令数据做出人类偏好分值评估”能力的奖励模型。
3 SFT、Actor、RM、Critic 使用“prompt”数据(通俗含义上的“问句”),以第1、第2阶段训练得到的模型作为基线进行强化学习训练,最终得到具备“根据用户输入,生成符合人类偏好答复”能力的Actor模型。

0.2 DeepSpeed-Chat的数据集设计

  上述各阶段训练的一个关键就是数据集的设计,每个阶段需要使用的数据格式都不尽相同,DeepSpeed-Chat根据其中存在的共性,将数据集设计成了相对统一的形式,由Dataset类进行统一管理,然后再根据不同的训练阶段细化处理出相应的数据格式。

0.2.1 数据格式基本概念

数据格式名称 说明 样例
prompt 对当前情境的描述,为模型生成提供指令输入信息,可以理解为通俗含义上的“问句”,适用于phase3 "Human: Please tell me about Microsoft in a few sentence? Assistant: "
response/answer 对上下文prompt的响应、回答、应答,可以理解为通俗含义上的“答句”。 “Microsoft is a software company that develops, licenses, and supports software products,including Windows, Office, and Windows Phone. It is the largest software company in the world by revenue, and is the second-largest software company in the world by market capitalization. Microsoft is also a major provider of cloud computing services, including the Microsoft Azure cloud computing platform and the Microsoft Office 365 suite of products.”
chosen 应答中的一种,人类所偏好的应答。 “Microsoft is a software company that develops, licenses, and supports software products,including Windows, Office, and Windows Phone. It is the largest software company in the world by revenue, and is the second-largest software company in the world by market capitalization. Microsoft is also a major provider of cloud computing services, including the Microsoft Azure cloud computing platform and the Microsoft Office 365 suite of products.”
rejected 应答中的一种,人类所排斥的应答。 “I’m not sure what you mean.”
conversation 完整对话,由prompt衔接应答response得到。 Human: Please tell me about Microsoft in a few sentence? Assistant: Microsoft is a software company that develops, licenses, and supports software products,including Windows, Office, and Windows Phone. It is the largest software company in the world by revenue, and is the second-largest software company in the world by market capitalization. Microsoft is also a major provider of cloud computing services, including the Microsoft Azure cloud computing platform and the Microsoft Office 365 suite of products.”
chosen_sentence 人类偏好的完整对话,由prompt衔接偏好应答chosen得到,适用于phase1和phase2 Human: Please tell me about Microsoft in a few sentence? Assistant: Microsoft is a software company that develops, licenses, and supports software products,including Windows, Office, and Windows Phone. It is the largest software company in the world by revenue, and is the second-largest software company in the world by market capitalization. Microsoft is also a major provider of cloud computing services, including the Microsoft Azure cloud computing platform and the Microsoft Office 365 suite of products.”
reject_sentence 人类排斥的完整对话,由prompt衔接排斥应答rejected得到,适用于phase2 Human: Please tell me about Microsoft in a few sentence? Assistant: I’m not sure what you mean.”
unsup 无监督语料,符合自然语言要素的文本,适用于自回归语言模型的无监督训练 "Wikipedia is a free online encyclopedia that is maintained and edited collaboratively by volunteers from around the world. It contains articles on a wide range of topics, from history and science to popular culture and current events. Anyone can create, edit, or contribute to Wikipedia articles, making it an open and decentralized platform for knowledge sharing and dissemination. One of the key features of Wikipedia is its commitment to neutrality, with contributors striving to present information in an objective and unbiased manner. "

  DeepSpeed-Chat设计的数据格式是直接服务于阶段训练的:

  • phase1:采用chosen_sentence作为训练数据,进行自回归语言建模训练。chosen_sentence在通俗含义上代表“有效的问答数据”,有助于模型学习到理解指令并做出正确响应的能力。而reject_sentence作为“相对无效的问答数据”,其响应部分往往是“反人类”的,并不利于模型进行学习,因此在这个阶段采用了chosen_sentence作为训练数据。
  • phase2:采用chosen_sentence和reject_sentence作为训练数据,进行成对排序训练(pairwise ranking loss),chosen_sentence和reject_sentence将分别作为成对数据中的较好者和较差者被送入模型中,模型将学习到其中的排序思路,从而给出更为合理的奖励评分。这部分其实与InstructGPT中所述有些差别,InstructGPT是针对同个prompt构造了更多的conversations(如4至7个),通过排列组合的方式,这些conversations将两两组成更多的成对数据被送入模型中进行训练。总的来说,DeepSpeed-Chat与InstructGPT的训练思想是一致的。
  • phase3:采用prompt作为基本数据,调用中间模型(Actor、SFT、Critic、RM)根据基本数据构造出经验数据,使用强化学习中的PPO算法进行训练。更具体的内容可见相关代码解析。
  • 无监督训练:采用无监督语料数据,进行无监督的自回归语言建模训练。InstructGPT提出,进行phase3的RLHF训练时,为使得模型在学习人类偏好的过程中仍能保有预训练模型解决任务的性能,引入了传统的自回归语言建模进行联合训练。

  并且假如你确定要进行完整的三阶段训练,DeepSpeed-Chat鼓励使用可以同时适应三阶段的数据集,即同时具备prompt、chosen、rejected的数据集,他们认为在不同阶段中,使用不同的数据集则面临着数据分布差异问题,尤其是第二、第三阶段使用与第一阶段不同的数据集,这将可能导致第二、第三阶段训练出的模型质量变差。

One thing to note is that: If you plan to only do step 1 SFT, adding more single-response datasets is definitely beneficial. However, if you do plan to do steps 2 and 3, then adding too many single-response datasets during SFT could backfire: these data could be different from the data used for steps 2/3, generating different distributions which could cause training instability/worse model quality during step 2/3. That is part of the reason why we focused on trying the datasets with two responses and the preference, and always split a dataset into all 3 steps.

  基于上述情况(三阶段使用同一数据集)考虑,DeepSpeed-Chat在构建数据集时提供了一个叫做“data_split”的传参,当你使用一个适用三阶段的数据集时,通过对该传参进行列表赋值,可以对三阶段数据比例进行设置,如“[6,2,2]”,这表示对于当前数据集,将分配全量数据的 6 6 + 2 + 2 \frac{6}{6+2+2} 6+2+26提供给第一阶段用于训练、验证,同理,分配全量的 2 6 + 2 + 2 \frac{2}{6+2+2} 6+2+22提供给第二阶段用于训练、验证,分配全量的 2 6 + 2 + 2 \frac{2}{6+2+2} 6+2+22提供给第三阶段用于训练、验证。

0.2.2 DeepSpeed-Chat的数据读取流

在此简单讲述UML时序图的元素含义:
- 箭头表示信息传递:实线表示调用,虚线表示返回;
- alt表示假设分支,其后方“[]”中的内容表示“条件”;
- loop表示循环;
- 淡蓝色区域即为高亮部分。
main.py data_utils.py raw_datasets.py create_prompt_dataset() 1 create_dataset() 2 get_raw_dataset() 3 PromptRawDataset 4 PromptRawDataset 5 raw_dataset 6 get_raw_dataset_split_index() 7 save(index) 8 alt [cache_inde- x==False] index 9 create_dataset_split() 10 PromptDataset 11 PromptDataset 12 save(PromptDataset) 13 alt [cache_found==False] PromptDataset 14 main.py data_utils.py raw_datasets.py

  总的来说,在训练的主流程代码main.py中,供DataLoader调用的Dataset(PromptDataset)将通过函数“create_prompt_dataset()”进行获取,其中将涉及到预存机制:如果已经存在预存数据,则直接通过torch.load()进行载入得到Dataset;如果未存在预存数据,则需要进行一系列预处理和预存的操作。获取Dataset的过程大致为(“括号序号”与UML时序图的“圈序号”对应):

  1. 查看是否存在数据集缓存(1):
    1. 如果存在则直接读取返回(14);
    2. 如果不存在则构建缓存(2-13):
      1. 读取全量数据集PromptRawDataset(3-6);
      2. 查看是否缓存有切分后的index(该index可作为索引,从原始数据中取出对应数据构成子数据集),如果有则直接读取返回,如果没有则进行构建(此处并不十分重要,故不再加以叙述)(7-9);
      3. 根据index从全量数据集中取出子数据集,将子数据集进一步处理成对应phase所需的格式(如chosen_sentence、reject_sentence等),并且使用tokenizer提前进行encode,将encode后的内容使用Dataset类进行维护,得到最终所需的Dataset实例PromptDataset(10-12);
      4. 将Dataset实例进行存储(13)。

0.2.3 关键代码详解

上述过程存在几个值得关注的地方(即文字描述加粗、UML时序图高亮的部分):

  • 全量数据集PromptRawDataset类的定义,源码中预定义了几个开源数据集的PromptRawDataset类,当用户想要使用自定义数据集进行训练时,就必然需要自行编写与数据集情况相关的PromptRawDataset类,从而实现自定义数据的读取;
  • 另一部分则是子数据集的处理过程,理解子数据集的处理过程及其形式,将能更加透彻地理解各个阶段模型所需的输入形式。

以下将对两个部分的源码进行详细介绍。

0.2.3.1 自定义PromptRawDataset类

UML时序图(3-6)

# applications/DeepSpeed-Chat/training/utils/data/raw_datasets.py
class PromptRawDataset(object):

    def __init__(self, output_path, seed, local_rank, dataset_name):
        """
        初始化
        :param output_path: 输出缓存路径。
        :param seed: 随机种子。
        :param local_rank: 当前进程序号。
        :param dataset_name: 数据集名称,后续指定所需读取的数据集时将以名称为准。
        """
        self.dataset_name = dataset_name
        self.dataset_clean_name = dataset_clean_name
        self.output_path = output_path
        self.seed = seed
        self.local_rank = local_rank
        # load_dataset源自datasets库,该方法支持读取csv/json/text等多种文件格式的数据
        self.raw_datasets = load_dataset(dataset_name)

    def get_train_data(self):
        """
        获取训练集
        :return: dataset数据格式
        """
        return

    def get_eval_data(self):
        """
        获取验证集
        :return: dataset数据格式
        """
        return

    # The prompt should be in the format of: " Human: " + actual_prompt_sentence + " Assistant:"
    def get_prompt(self, sample):
        """
        从dataset的sample(单个样本)中获取prompt。
        :param sample: dataset的元素
        :return: prompt。prompt的格式必须为 "Human: {} Assistant:".format(actual_prompt_sentence)
        """
        return

    # The chosen response should be in the format of: " " + actual_response_sentence
    def get_chosen(self, sample):
        """
        从dataset的sample(单个样本)中获取chosen。chosen实际上是“chosen response”,指的是“精选的回复”,即人类所偏好的、高分的回复。
        :param sample: dataset的元素
        :return: chosen。chosen的格式必须为" {}".format(actual_response_sentence)
        """
        return

    # The rejected response should be in the format of: " " + actual_response_sentence
    # If the dataset does not have rejected response, return None
    def get_rejected(self, sample):
        """
        从dataset的sample(单个样本)中获取rejected。rejected实际上是“rejected response”,指的是“排斥的回复”,即人类所厌恶的、低分的回复。
        :param sample: dataset的元素
        :return: rejected。如果数据集中不存在则返回为None;如果存在,则其格式必须为 " {}".format(actual_response_sentence)
        """
        return

    def get_prompt_and_chosen(self, sample):
        """
        从dataset的sample(单个样本)中获取prompt与chosen。
        :param sample: dataset的元素
        :return: prompt与chosen的衔接。同样需要满足上述格式要求,即衔接结果为
        "Human: {} Assistant: {}".format(actual_prompt_sentence, actual_response_sentence)
        """
        return

    def get_prompt_and_rejected(self, sample):
        """
        从dataset的sample(单个样本)中获取prompt与rejected。
        :param sample: dataset的元素
        :return: prompt与rejected的衔接。同样需要满足上述格式要求,即衔接结果为
        "Human: {} Assistant: {}".format(actual_prompt_sentence, actual_response_sentence)
        """
        return

  自定义的数据集可以继承自上述的“PromptRawDataset”类,例如class CustomDataset(PromptRawDataset),然后重写其中的self.dataset_nameself.dataset_clean_name,此处的“dataset_name”即为传参指定数据集时所要填写的名称,例如self.dataset_name=custom,在设置传参--data_path=‘custom’时,将会读取到CustomDataset的数据用于进行训练。另外其中的get_train_data()等实例函数也需要进行重写,主要是实现将原始数据处理成注释所提及格式。
  定义好自定义PromptRawDataset后,还需要对其进行“注册”,具体可见下述代码块。

# applications/DeepSpeed-Chat/training/utils/data/data_utils.py
def get_raw_dataset(dataset_name, output_path, seed, local_rank):

    if "Dahoas/rm-static" in dataset_name:
        return raw_datasets.DahoasRmstaticDataset(output_path, seed,
                                                  local_rank, dataset_name)
    elif "Dahoas/full-hh-rlhf" in dataset_name:
        return raw_datasets.DahoasFullhhrlhfDataset(output_path, seed,
                                                    local_rank, dataset_name)
    ···
    """
    将自定义的PromptRawDataset在此处进行注册
    届时在传参“--data_path”中赋值“custom”即可读取到相应的数据集
	"""
    elif "custom" in dataset_name:
    	return raw_datasets.CustomDataset(output_path, seed,
                                          local_rank, dataset_name)
    else:
      raise RuntimeError(
          f"We do not have configs for dataset {dataset_name}, but you can add it by yourself in raw_datasets.py."
      )

  至此完成自定义数据集的设置。理论上来说,只要实例函数能完全按照注释要求对原始数据进行处理,那么后续的数据流基本也无需再进行任何额外修改也能顺畅运行了。

0.2.3.2 阶段数据集处理过程

UML时序图(10-12)
  这部分处理得到的数据形式,基本接近于数据传入阶段模型前的最终形式,因此通过理解这部分的数据处理过程,可以直接了解到模型所需要的输入形式。

# applications/DeepSpeed-Chat/training/utils/data/data_utils.py
def create_dataset_split(current_dataset, raw_dataset, train_phase, tokenizer,
                         end_of_conversation_token, max_seq_len):
    """
    将根据不同的阶段(train_phase)对数据集进行处理,主要是调用原先在PromptRawDataset类中定义的实例函数来实现。
    """
    prompt_dataset = []
    chosen_dataset = []
    reject_dataset = []
    if train_phase == 1:
        # 因为phase1只需要用到chosen数据,所以只取chosen进行处理
        for i, tmp_data in enumerate(current_dataset):
            # 获取chosen_sentence,即是将prompt和chosen拼接起来形成完整对话
            # 具体样例可参照“数据格式基本概念”中的样例
            chosen_sentence = raw_dataset.get_prompt_and_chosen(
                tmp_data)
            if chosen_sentence is not None:
            	# 在对话末尾加入对话终止符
                chosen_sentence += end_of_conversation_token
                # 使用tokenizer处理chosen_sentence,采取截断truncation
                chosen_token = tokenizer(chosen_sentence,
                                         max_length=max_seq_len,
                                         padding="max_length",
                                         truncation=True,
                                         return_tensors="pt")
                # 去掉batch维度
                chosen_token["input_ids"] = chosen_token["input_ids"].squeeze(
                    0)
                chosen_token["attention_mask"] = chosen_token[
                    "attention_mask"].squeeze(0)
                # 存储tokenize结果至列表chosen_dataset
                chosen_dataset.append(chosen_token)

    elif train_phase == 2:
        # phase2需要用到chosen_sentence和reject_sentence
        # 所以需要对两者都进行处理
        for i, tmp_data in enumerate(current_dataset):
            # 获取chosen_sentence,即是将prompt和chosen拼接起来形成完整对话
            # 具体样例可参照“数据格式基本概念”中的样例
            chosen_sentence = raw_dataset.get_prompt_and_chosen(
                tmp_data)  # the accept response
            # 获取reject_sentence,即是将prompt和rejeced拼接起来形成完整对话
            # 具体样例可参照“数据格式基本概念”中的样例
            reject_sentence = raw_dataset.get_prompt_and_rejected(
                tmp_data)
            if chosen_sentence is not None and reject_sentence is not None:
            	# 在对话末尾加入对话终止符
                chosen_sentence += end_of_conversation_token  # the accept response
                reject_sentence += end_of_conversation_token
                # 使用tokenizer处理,采取截断truncation
                chosen_token = tokenizer(chosen_sentence,
                                         max_length=max_seq_len,
                                         padding="max_length",
                                         truncation=True,
                                         return_tensors="pt")
                reject_token = tokenizer(reject_sentence,
                                         max_length=max_seq_len,
                                         padding="max_length",
                                         truncation=True,
                                         return_tensors="pt")
                chosen_token["input_ids"] = chosen_token["input_ids"]
                chosen_token["attention_mask"] = chosen_token["attention_mask"]
                # 存储tokenize结果至列表chosen_dataset
                chosen_dataset.append(chosen_token)

                reject_token["input_ids"] = reject_token["input_ids"]
                reject_token["attention_mask"] = reject_token["attention_mask"]
                # 存储tokenize结果至列表reject_dataset
                reject_dataset.append(reject_token)

    elif train_phase == 3:
        # phase3用到prompt,prompt将被用来生成经验数据
        for i, tmp_data in enumerate(current_dataset):
        	# 直接获取prompt
        	# 具体样例可参照“数据格式基本概念”中的样例
            prompt = raw_dataset.get_prompt(tmp_data)
            if prompt is not None:
                prompt_token = tokenizer(prompt, return_tensors="pt")
                prompt_token["input_ids"] = prompt_token["input_ids"]
                prompt_token["attention_mask"] = prompt_token["attention_mask"]
                for key_word in ["input_ids", "attention_mask"]:
                    # 获取当前文本token的实际长度
                    length = prompt_token[key_word].size()[-1]
                    # phase3此处的max_seq_len其实是max_prompt_len,默认只有256
                    if length > max_seq_len:
                        # 如果当前文本token长度比max_prompt_len还长
                        # 那么就截断文本前面的部分,保留后面max_prompt_len长度的部分文本
                        # 然后将token进行flip(翻转/倒序),之后在data_collator中再将其flip回来
                        y = prompt_token[key_word].squeeze(0)[length -
                                                              (max_seq_len -
                                                               1):].flip(0)
                    else:
                        # 将token进行flip(翻转/倒序),之后在data_collator中再将其flip回来
                        y = prompt_token[key_word].squeeze(0).flip(0)
                    prompt_token[key_word] = y
                prompt_dataset.append(prompt_token)

    # 返回PromptDataset实例,该实例相当于torch中的Dataset,可供DataLoader调用
    return PromptDataset(prompt_dataset, chosen_dataset, reject_dataset,
                         tokenizer.pad_token_id, train_phase)

  此处的处理部分很大程度依赖于原先所定义的PromptRawDataset实例函数,由此可见,只要正确编写实例函数,后续过程基本也不会出现什么问题。流程大致就是取出对应阶段所需的格式数据,然后使用tokenizer进行处理,综上所述:

  • phase1模型所需的输入数据为chosen_sentence的input_ids及attention_mask;
  • phase2模型所需的输入数据为chosen_sentence和reject_sentence的input_ids及attention_mask;
  • phase3模型所需的输入数据为promt的input_ids及attention_mask。

0.3 版块相关问题

  1. 【1.2.2.2 阶段数据集处理过程】中,为什么phase3要专门对prompt token进行flip(翻转)操作?
      这个额外操作很好解释,主要是便于进行前侧padding的操作。具体来说,phase3取用prompt的目的在于,将prompt输入至actor模型中,actor将根据prompt自回归地生成后续内容,以此进行经验采集。以基座为opt-125m的actor模型为例,该模型所能支持的最大序列长度(max_seq_len)为512,而phase3还将预设有最大prompt长度(max_prompt_len),通常为max_seq_len的一半,即256,余下的另一半长度将被用于生成。那么当输入的prompt不满足最大prompt长度max_prompt_len时,将需要对该prompt进行padding操作(将在后续phase3的data_collator代码中有所体现),而padding操作通常又是直接于序列后侧加入pad token,padding后的输入将变成[prompt, padding]的形式,自回归生成任务将接续pad_token进行生成——这是不合理的,因此需要先将prompt输入进行flip翻转,翻转后进行padding操作,然后再flip翻转回来,padding后的输入就成为了[padding, prompt]的形式,对于自回归任务来说,接续prompt的内容进行生成才是合理的。
      通过下述伪代码例子应该能更好地理解这个操作的用意。
max_prompt_len = 5
pad_token_id = 0

prompt_token_ids = [233, 11, 22]
# padding位于后侧 ×
prompt_token_ids.padding() = [233, 11, 22, 0, 0]

prompt_token_ids.flip(0) = [22, 11, 233]
prompt_token_ids.flip(0).padding() = [22, 11, 233, 0, 0]
# padding位于前侧 √
prompt_token_ids.flip(0).padding().flip(0) = [0, 0, 233, 11, 22]

后续

  关于训练阶段的具体内容可见【中篇】【下篇】。

你可能感兴趣的:(人工智能,自然语言处理,chatgpt,nlp,深度学习)