一图拆解RLHF中TRL的PPO

仔细看了看TRL的code(https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py),step大致流程为先batched_forward_pass,再过minibatch:
一图拆解RLHF中TRL的PPO_第1张图片
再写一写自己的理解:
PPO的loss由以下几部分相加得到:

  • actor的loss,代码里叫pg_loss,pg_loss是由-advantage值*exp(logprobs - old_logprobs)的得到的(刚好就是除的形式),至于advantage就是PPO里TD误差的加权和。这里值得注意的是PPO中的核心改进就是上面exp(logprobs - old_logprobs)控制在[1-\epi和1+\epi]之间做了CLIP。另外需要注意的是advantage是由rewards来得到的,rewards里增加了kl_penalty
  • critic的loss,代码里叫vf_loss(value function loss),我们都有一个由Critic预测出的预期回报值V,以及一个真实的回报值G(returns),这俩MSE就是critic的loss。值得注意的是由大batch得到的values和小batch得到的vpreds也做了clip,values就是由critic模型输出的那个values。values和logits都是由LLM输出,代码中会给LLM加一个PreTrainedModelWrapper
  • (optional) LM的loss,这个在TRL库里没有
  • (optional) actor里可能会加一个entropy loss来让每一步动作更加均匀,但是在TRL的实现中并没有找到这个。

PPO训练中loss下降不能代表全部问题,以下几个指标也经常关注:

  • reward系列: 希望reward可以平稳上升,reward->advantage->pg_loss
  • kl系列:期望和ref模型距离不要太远
  • PPL(perplexity系列):语言模型

PPO过程中有4个模型,这4个都可以放一个大语言模型,导致显存要占4倍,也可以做一些优化:

  • SFT Model: 一般会把LM做一遍SFT,保证Policy Model距离语言模型不要太远
  • Policy Model: 最终得到的LM
  • Reward Model(RM):怎么评估rewards,例如计算器简单来说就是结果对不对,对得1分,不对得0分;对于toxic来说就是有没有毒
  • Value Model:也可以叫Critic Model,评估当前状态下的期望收益。这个可以给Policy Model加一个Linear,省一些模型空间
    下图来自论文Secrets of RLHF in Large Language Models Part I: PPO:
    一图拆解RLHF中TRL的PPO_第2张图片

以下转载自:https://zhuanlan.zhihu.com/p/635757674

你可能感兴趣的:(深度学习,机器学习,人工智能)