X-R1 项目代码文件的详细剖析并精读rewards、grpo、x_grpo_trainer(src/x_r1)

这个项目名为[X-R1](https://github.com/dhcode-cpp/X-R1),是一个基于强化学习的训练框架,旨在构建一个易于使用、低成本的训练框架,以加速Scaling Post-Training的开发。以下是对该项目的详细解释:

项目结构

项目的主要目录结构如下:

X-R1/
├── .gitignore
├── LICENSE
├── Makefile
├── README.md
├── requirements.txt
├── setup.cfg
├── setup.py
├── src/
│   └── x_r1/
├── README.assets/
│   ├── X-R1-0.5B-acc-result.png
│   ├── X-R1-log.png
│   └── aha_moment_0.5B.png
└── recipes/
    ├── README.md
    ├── X_R1_test_env_single.yaml
    ├── X_R1_zero_0dot5B_config.yaml
    ├── X_R1_zero_1dot5B_config.yaml
    ├── X_R1_zero_7B_config.yaml
    ├── zero1.yaml
    ├── zero2.yaml
    └── zero3.yaml
  • .gitignore:指定了Git需要忽略的文件和目录。
  • LICENSE:项目的许可证文件。
  • Makefile:包含了一些常用的构建和运行命令。
  • README.md:项目的说明文档,包含项目的介绍、功能、安装方法等信息。
  • requirements.txt:项目的依赖库列表,使用pip install -r requirements.txt可以安装所有依赖。
  • setup.cfgsetup.py:用于打包和分发项目的配置文件。
  • src/x_r1/:项目的源代码目录。
  • README.assets/:存放项目说明文档中使用的图片等资产文件。
  • recipes/:包含了不同的训练配置文件,如X_R1_test_env_single.yamlX_R1_zero_0dot5B_config.yaml等。

主要功能和特性

  1. 低成本训练

    • 4x3090/4090 GPUs训练1小时,成本小于10美元,10分钟37步即可输出“Aha Moment”。
    • 支持0.5B规模的模型进行强化学习训练,并且可以支持更大规模的模型,如1.5B/7B/32B等。
  2. 数据集支持

    • 提供了0.75k/1.5k/7.5k的数据集,用于快速训练循环。
  3. 日志记录

    • 记录GRPO在线采样数据到日志文件。

配置文件

项目中的配置文件(如X_R1_zero_0dot5B_config.yaml等)主要包含以下几部分的配置:

  • 模型参数:如model_name_or_pathmodel_revisiontorch_dtype等,指定了模型的名称、版本和数据类型。
  • 数据训练参数:如dataset_namedataset_configsnum_processes等,指定了训练使用的数据集和进程数。
  • GRPO训练器配置:如use_vllmoutput_dirgradient_accumulation_steps等,配置了训练器的相关参数。

安装和运行

  1. 安装依赖
    • 首先需要安装CUDA版本大于12.4。
    • 创建并激活一个名为xr1的Python 3.11虚拟环境:
conda create -n xr1 python=3.11
conda activate xr1
- 安装项目所需的依赖库:
pip install -r requirements.txt
- 创建输出目录:
mkdir output
  1. 运行示例
    • 单GPU运行
ACCELERATE_LOG_LEVEL=info \
accelerate launch \
--config_file recipes/zero1.yaml \
--num_processes=1 \
src/x_r1/grpo.py \
--config recipes/X_R1_test_env_single.yaml \
> ./output/x_r1_test_sampling.log 2>&1
- **多GPU运行**:
ACCELERATE_LOG_LEVEL=info \
accelerate launch \
--config_file recipes/accelerate_configs/zero3.yaml \
--num_processes=1 \
src/x_r1/grpo.py \
--config recipes/x_r1_test_sampling.yaml \
> ./output/test.log 2>&1

代码文件

  • register_lighteval_task函数:位于X-R1/src/x_r1/utils/evaluation.py文件中,用于注册LightEval任务配置。
  • push_to_hub_revision函数:位于X-R1/src/x_r1/utils/hub.py文件中,用于将模型推送到Hub仓库的指定分支。

相关链接

项目受到了DeepSeek-R1和open-r1的启发。

代码解读:

x_grpo_trainer.py剖析

1. 模块导入
import os
import textwrap
import warnings
from collections import defaultdict
from typing import Any, Callable, Optional, Union
from unittest.mock import patch

import torch
import torch.utils.data
import transformers
from accelerate.utils import broadcast_object_list, gather_object
from transformers import (
    PreTrainedModel,
    Trainer,
)
from trl.trainer import GRPOTrainer
from trl.data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
from trl.models import unwrap_model_for_generation
from trl.trainer.grpo_config import GRPOConfig
from trl.trainer.utils import  pad
  • 导入了一系列必要的模块,包括ostorchtransformers等,这些模块用于文件操作、深度学习计算、数据处理等。
  • trl库中导入了GRPOTrainerGRPOConfig等相关类和函数,表明该文件可能是在trl库的基础上进行扩展。
2. 奖励函数类型定义
# What we call a reward function is a callable that takes a list of prompts and completions and returns a list of
# rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model.
RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
  • 定义了一个RewardFunc类型别名,它可以是字符串(模型ID)、预训练模型或一个可调用对象(接受提示和完成列表并返回奖励列表)。
3. XGRPOTrainer类定义
class XGRPOTrainer(GRPOTrainer):
    # base trl GRPO_trainer
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        if return_outputs:
            raise ValueError("The GRPOTrainer does not support returning outputs")
  • XGRPOTrainer类继承自GRPOTrainer,并重写了compute_loss方法。
  • 如果return_outputsTrue,会抛出ValueError,表明该训练器不支持返回输出。
4. 数据处理和生成
        device = self.accelerator.device
        prompts = [x["prompt"] for x in inputs]
        prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
        prompt_inputs = self.processing_class(
            prompts_text, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False
        )
        prompt_inputs = super()._prepare_inputs(prompt_inputs)

        if self.max_prompt_length is not None:
            prompt_inputs["input_ids"] = prompt_inputs["input_ids"][:, -self.max_prompt_length :]
            prompt_inputs["attention_mask"] = prompt_inputs["attention_mask"][:, -self.max_prompt_length :]
  • 获取设备信息。
  • 从输入中提取提示信息,并应用聊天模板(如果需要)。
  • 使用processing_class对提示进行处理,生成输入张量。
  • 如果设置了max_prompt_length,则对输入进行截断。
5. 生成完成信息
        # Generate completions using either vLLM or regular generation
        if self.args.use_vllm:
            # First, have main process load weights if needed
            if self.state.global_step != self._last_loaded_step:
                with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
                    state_dict = unwrapped_model.state_dict()
                if self.accelerator.is_main_process:
                    llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
                    llm_model.load_weights(state_dict.items())
                self._last_loaded_step = self.state.global_step

            # Generate completions using vLLM: gather all prompts and use them in a single call in the main process
            all_prompts_text = gather_object(prompts_text)
            if self.accelerator.is_main_process:
                outputs = self.llm.generate(all_prompts_text, sampling_params=self.sampling_params, use_tqdm=False)
                completion_ids = [out.token_ids for completions in outputs for out in completions.outputs]
                for output in outputs:
                    print('-'*100)
                    print('\n\n\n')
                    prompt = output.prompt
                    for output_t in  output.outputs:
                        # print(completion_ids)
                        print('='*100)
                        generated_text = output_t.text
                        print("【USER】: ", prompt )
                        print("\n【ASSISTANT】:", generated_text)
            else:
                completion_ids = [None] * len(all_prompts_text) * self.num_generations

            # Broadcast the completions from the main process to all processes, ensuring each process receives its
            # corresponding slice.
            completion_ids = broadcast_object_list(completion_ids, from_process=0)
            process_slice = slice(
                self.accelerator.process_index * len(prompts) * self.num_generations,
                (self.accelerator.process_index + 1) * len(prompts) * self.num_generations,
            )
            completion_ids = completion_ids[process_slice]

            # Pad the completions, and concatenate them with the prompts
            completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
            completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id)
            prompt_inputs_repeated = torch.repeat_interleave(prompt_inputs["input_ids"], self.num_generations, dim=0).to(device)
            prompt_completion_ids = torch.cat([prompt_inputs_repeated, completion_ids], dim=1)
        else:
            # Regular generation path
            with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
                prompt_inputs['input_ids'].to(device)
                prompt_inputs['attention_mask'].to(device)

                prompt_completion_ids = unwrapped_model.generate(
                    **prompt_inputs, generation_config=self.generation_config
                )
  • 根据self.args.use_vllm的值,选择使用vLLM或常规生成方法生成完成信息。
  • 如果使用vLLM,需要在主进程中加载权重,并在主进程中生成完成信息,然后将结果广播到所有进程。
  • 对生成的完成信息进行填充和拼接。
6. 计算损失
        prompt_length = prompt_inputs["input_ids"].size(1)
        completion_ids = prompt_completion_ids[:, prompt_length:]

        # Get the per-token log probabilities for the completions for the model and the reference model
        def get_per_token_logps(model, input_ids, num_logits_to_keep):
            # We add 1 to `num_logits_to_keep` because the last logits of the sequence is later excluded
            logits = model(input_ids, num_logits_to_keep=num_logits_to_keep + 1).logits  # (B, L, V)
            logits = logits[:, :-1, :]  # (B, L-1, V), exclude the last logit: it corresponds to the next token pred

            # Compute the log probabilities for the input tokens. Use a loop to reduce memory peak.
            per_token_logps = []
            for logits_row, input_ids_row in zip(logits, input_ids[:, -num_logits_to_keep:]):
                log_probs = logits_row.log_softmax(dim=-1)
                token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
                per_token_logps.append(token_log_prob)
            return torch.stack(per_token_logps)

        num_logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens
        per_token_logps = get_per_token_logps(model, prompt_completion_ids, num_logits_to_keep)

        with torch.inference_mode():
            if self.ref_model is not None:
                ref_per_token_logps = get_per_token_logps(self.ref_model, prompt_completion_ids, num_logits_to_keep)
            else:
                with self.accelerator.unwrap_model(model).disable_adapter():
                    ref_per_token_logps = get_per_token_logps(model, prompt_completion_ids, num_logits_to_keep)

        # Compute the KL divergence between the model and the reference model
        per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
  • 计算提示的长度和完成信息的ID。
  • 定义get_per_token_logps函数,用于计算每个令牌的对数概率。
  • 计算模型和参考模型的每个令牌的对数概率。
  • 计算模型和参考模型之间的KL散度。
7. 计算奖励
        # Compute the rewards
        prompts = [prompt for prompt in prompts for _ in range(self.num_generations)]

        rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
        for i, (reward_func, reward_processing_class) in enumerate(
            zip(self.reward_funcs, self.reward_processing_classes)
        ):
            if isinstance(reward_func, PreTrainedModel):
                if is_conversational(inputs[0]):
                    messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
                    texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
                else:
                    texts = [p + c for p, c in zip(prompts, completions)]
                reward_inputs = reward_processing_class(
                    texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
                )
                reward_inputs = super()._prepare_inputs(reward_inputs)
                with torch.inference_mode():
                    rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0]  # Shape (B*G,)
            else:
                # Repeat all input columns (but "prompt" and "completion") to match the number of generations
                reward_kwargs = {key: [] for key in inputs[0].keys() if key not in ["prompt", "completion"]}
                for key in reward_kwargs:
                    for example in inputs:
                        # Repeat each value in the column for `num_generations` times
                        reward_kwargs[key].extend([example[key]] * self.num_generations)
                output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
                rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)

        # Sum the rewards from all reward functions
        rewards = rewards_per_func.sum(dim=1)
  • 重复提示信息以匹配生成次数。
  • 遍历每个奖励函数,根据奖励函数的类型(预训练模型或可调用对象)计算奖励。
  • 将所有奖励函数的奖励相加。
8. 计算优势和损失
        # Compute grouped-wise rewards
        mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
        std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)

        # Normalize the rewards to compute the advantages
        mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
        std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
        advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)

        # x - x.detach() allows for preserving gradients from x
        per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
        per_token_loss = -(per_token_loss - self.beta * per_token_kl)
        loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
  • 计算分组奖励的均值和标准差。
  • 归一化奖励以计算优势。
  • 计算每个令牌的损失,并最终计算总损失。
9. 记录指标
        # Log the metrics
        completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
        self._metrics["completion_length"].append(completion_length)

        reward_per_func = self.accelerator.gather_for_metrics(rewards_per_func).mean(0)
        for i, reward_func in enumerate(self.reward_funcs):
            if isinstance(reward_func, PreTrainedModel):
                reward_func_name = reward_func.config._name_or_path.split("/")[-1]
            else:
                reward_func_name = reward_func.__name__
            self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())

        self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item())

        self._metrics["reward_std"].append(self.accelerator.gather_for_metrics(std_grouped_rewards).mean().item())

        mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
        self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())

        return loss
  • 计算完成长度、每个奖励函数的奖励、总奖励、奖励标准差和KL散度等指标,并记录到self._metrics中。
  • 最后返回计算得到的损失。

grpo.py 文件解读

@dataclass
class GRPOScriptArguments(ScriptArguments):
    reward_funcs: list[str] = field(
        default_factory=lambda: ["accuracy", "format"],
        # default_factory=lambda: ["accuracy", ],
        metadata={
            "help": f"List of reward functions. Possible values: {', '.join(REWARD_FUNCS_REGISTRY.keys())}"
        },
    )
  • 定义了一个数据类 GRPOScriptArguments,继承自 ScriptArguments
    • reward_funcs 是一个字符串列表类型的字段,默认值为 ["accuracy", "format"]
    • metadata 中提供了该字段的帮助信息,列出了可能的奖励函数名称。
SYSTEM_PROMPT = (
    "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
    "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
    "process and answer are enclosed within   and   tags, respectively, i.e., "
    " reasoning process here  answer here "
)
  • 定义了一个系统提示信息 SYSTEM_PROMPT,用于描述用户和助手之间的对话模式。
def main(script_args, training_args, model_args):
    # Set seed for reproducibility
    set_seed(training_args.seed)
  • 定义了主函数 main,接受 script_argstraining_argsmodel_args 作为参数。
    • 使用 set_seed 函数设置随机种子,以确保实验的可重复性。
    ###############
    # Setup logging
    ###############
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()
  • 配置日志记录:
    • 使用 logging.basicConfig 设置日志的格式和日期格式,并将日志输出到标准输出流。
    • 获取训练参数中的日志级别,并设置日志记录器的级别。
    • 设置 datasetstransformers 库的日志级别,并启用默认处理程序和显式格式。
    # Log on each process a small summary
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
        + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Model parameters {model_args}")
    logger.info(f"Script parameters {script_args}")
    logger.info(f"Data parameters {training_args}")
  • 记录一些关于进程、设备、GPU数量、分布式训练和16位训练的信息,以及模型参数、脚本参数和数据参数。
    # Check for last checkpoint
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir):
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
    if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
        logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")
  • 检查是否有上一次的检查点:
    • 如果输出目录存在,则尝试获取最后一个检查点。
    • 如果找到检查点且没有指定从特定检查点恢复训练,则记录信息并准备从该检查点恢复训练。
    # Load the dataset
    dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
  • 使用 load_dataset 函数加载数据集,数据集的名称和配置由脚本参数指定。
    # Get reward functions
    reward_funcs = [REWARD_FUNCS_REGISTRY[func] for func in script_args.reward_funcs]
  • 根据脚本参数中的奖励函数名称,从奖励函数注册表 REWARD_FUNCS_REGISTRY 中获取相应的奖励函数。
    # Format into conversation
    def make_conversation(example):
        return {
            "prompt": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": example["problem"]},
            ],
        }

    dataset = dataset.map(make_conversation)
    for split in dataset:
        if "messages" in dataset[split].column_names:
            dataset[split] = dataset[split].remove_columns("messages")
  • 定义了一个函数 make_conversation,用于将数据集中的每个示例格式化为对话形式。
    • 使用 dataset.map 函数将该格式化函数应用到整个数据集。
    • 遍历数据集的每个分割,如果存在 messages 列,则将其移除。
    logger.info("*** Initializing model kwargs ***")
    torch_dtype = (
        model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
    )

    training_args.gradient_checkpointing = True
    model_kwargs = dict(
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
        attn_implementation=model_args.attn_implementation,
        torch_dtype=torch_dtype,
        use_cache=False if training_args.gradient_checkpointing else True,
    )
  • 初始化模型的关键字参数:
    • 根据模型参数中的 torch_dtype 确定 torch_dtype 的值。
    • 启用梯度检查点,构建模型关键字参数字典。
    model = AutoModelForCausalLM.from_pretrained( model_args.model_name_or_path, load_in_4bit=False, **model_kwargs)

    print(model_args.model_name_or_path,)
  • 使用 AutoModelForCausalLM.from_pretrained 函数从预训练模型中加载模型,不使用4位量化加载。
    • 打印模型的名称或路径。
    #############################
    # Initialize the XGRPO trainer
    #############################
    trainer = XGRPOTrainer(
        # model=model_args.model_name_or_path,
        model = model,
        reward_funcs=reward_funcs,
        args=training_args,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
        callbacks=get_callbacks(training_args, model_args),
    )
  • 初始化 XGRPOTrainer 训练器,传入模型、奖励函数、训练参数、训练数据集、评估数据集(如果评估策略不是 no)和回调函数。
    ###############
    # Training loop
    ###############
    logger.info("*** Train ***")
    checkpoint = None
    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint
    elif last_checkpoint is not None:
        checkpoint = last_checkpoint
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
    metrics = train_result.metrics
    metrics["train_samples"] = len(dataset[script_args.dataset_train_split])
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()
  • 开始训练循环:
    • 记录训练开始信息。
    • 确定是否从检查点恢复训练。
    • 调用训练器的 train 方法进行训练,获取训练结果。
    • 记录训练指标,保存指标和训练状态。
    ##################################
    # Save model and create model card
    ##################################
    logger.info("*** Save model ***")
    trainer.save_model(training_args.output_dir)
    logger.info(f"Model saved to {training_args.output_dir}")

    # Save everything else on main process
    kwargs = {
        "dataset_name": script_args.dataset_name,
        "tags": ["X-R1"],
    }
    if trainer.accelerator.is_main_process:
        trainer.create_model_card(**kwargs)
        # Restore k,v cache for fast inference
        trainer.model.config.use_cache = True
        trainer.model.config.save_pretrained(training_args.output_dir)
  • 保存模型和创建模型卡片:
    • 记录保存模型信息,调用训练器的 save_model 方法保存模型。
    • 在主进程中创建模型卡片,恢复模型的缓存并保存模型配置。
if __name__ == "__main__":
    parser = TrlParser((GRPOScriptArguments, GRPOConfig, ModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()
    main(script_args, training_args, model_args)
  • 如果脚本作为主程序运行,则创建 TrlParser 解析器,解析命令行参数和配置,然后调用 main 函数开始执行。

rewards.py 文件解读

"""Reward functions for GRPO training."""
"""Reward functions for GRPO training."""

import re

from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify
  • 文件开头的注释表明该文件包含用于GRPO训练的奖励函数。
    • 导入 re 模块用于正则表达式操作。
    • latex2sympy2_extended 导入 NormalizationConfig,从 math_verify 导入 LatexExtractionConfigparseverify 函数。
def accuracy_reward(completions, solution, **kwargs):
    """Reward function that checks if the completion is the same as the ground truth."""
    contents = [completion[0]["content"] for completion in completions]
    rewards = []
    for content, sol in zip(contents, solution):
        gold_parsed = parse(
            sol,
            extraction_mode="first_match",
            extraction_config=[LatexExtractionConfig()],
        )
        if len(gold_parsed) != 0:
            # print('latex gold parsed')
            # We require the answer to be provided in correct latex (no malformed operators)
            answer_parsed = parse(
                content,
                extraction_config=[
                    LatexExtractionConfig(
                        normalization_config=NormalizationConfig(
                            nits=False,
                            malformed_operators=False,
                            basic_latex=True,
                            equations=True,
                            boxed="all",
                            units=True,
                        ),
                        # Ensures that boxed is tried first
                        boxed_match_priority=0,
                        try_extract_without_anchor=False,
                    )
                ],
                extraction_mode="first_match",
            )
            # Reward 1 if the content is the same as the ground truth, 0 otherwise
            reward = float(verify(answer_parsed, gold_parsed))
            # print('\nprompt:', prompt)
            print('-'*100)
            print('\nanswer_parsed:', answer_parsed, '\ngold_parsed:', gold_parsed, '\nreward:', reward)
        else:
            reward = 1.0
            print("Failed to parse gold solution: ", sol)
        rewards.append(reward)

    print('\naccuracy rewards:', rewards)

    return rewards
  • 定义了一个名为 accuracy_reward 的奖励函数,用于检查生成的答案是否与真实答案相同。
    • 提取每个完成结果的内容。
    • 遍历每个内容和对应的真实答案,尝试解析真实答案和生成的答案。
    • 如果解析成功,则使用 verify 函数验证两者是否相同,相同则奖励为1,否则为0。
    • 如果解析失败,则奖励为1,并打印错误信息。
    • 最后返回奖励列表。
def format_reward(completions, **kwargs):
    """Reward function that checks if the completion has a specific format."""
    pattern = r"^.*?.*?$"
    completion_contents = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, content) for content in completion_contents]

    rewards = [1.0 if match else 0.0 for match in matches]
    print('-'*100)
    print('\nformat rewards:', rewards)
    return rewards
  • 定义了一个名为 format_reward 的奖励函数,用于检查生成的答案是否符合特定格式。
    • 定义了一个正则表达式模式,用于匹配 ... ... 格式。
    • 提取每个完成结果的内容,使用正则表达式进行匹配。
    • 根据匹配结果给出奖励,匹配成功为1,失败为0。
    • 打印奖励列表并返回。
def reasoning_steps_reward(completions, **kwargs):
    """Reward function that checks for clear step-by-step reasoning.
    Regex pattern:
        Step \d+: - matches "Step 1:", "Step 2:", etc.
        ^\d+\. - matches numbered lists like "1.", "2.", etc. at start of line
        \n- - matches bullet points with hyphens
        \n\* - matches bullet points with asterisks
        First,|Second,|Next,|Finally, - matches transition words
    """
    pattern = r"(Step \d+:|^\d+\.|\n-|\n\*|First,|Second,|Next,|Finally,)"
    completion_contents = [completion[0]["content"] for completion in completions]
    matches = [len(re.findall(pattern, content)) for content in completion_contents]

    # Magic nubmer 3 to encourage 3 steps and more, otherwise partial reward
    return [min(1.0, count / 3) for count in matches]
  • 定义了一个名为 reasoning_steps_reward 的奖励函数,用于检查生成的答案是否包含清晰的推理步骤。
    • 定义了一个正则表达式模式,用于匹配步骤编号、列表项和过渡词。
    • 提取每个完成结果的内容,统计匹配到的模式数量。
    • 根据匹配数量计算奖励,鼓励至少有3个步骤,奖励最大为1。
REWARD_FUNCS_REGISTRY = {
    "accuracy": accuracy_reward,
    "format": format_reward,
    "reasoning_steps": reasoning_steps_reward,
}
  • 定义了一个奖励函数注册表 REWARD_FUNCS_REGISTRY,将奖励函数名称映射到对应的函数。

你可能感兴趣的:(人工智能,人工智能,深度学习,学习)