The project is [X-R1](https://github.com/dhcode-cpp/X-R1), a reinforcement-learning training framework designed to be easy to use and low cost, with the goal of accelerating development work on scaling post-training. The rest of this post walks through the project in detail.
The main directory layout is as follows:
```
X-R1/
├── .gitignore
├── LICENSE
├── Makefile
├── README.md
├── requirements.txt
├── setup.cfg
├── setup.py
├── src/
│   └── x_r1/
├── README.assets/
│   ├── X-R1-0.5B-acc-result.png
│   ├── X-R1-log.png
│   └── aha_moment_0.5B.png
└── recipes/
    ├── README.md
    ├── X_R1_test_env_single.yaml
    ├── X_R1_zero_0dot5B_config.yaml
    ├── X_R1_zero_1dot5B_config.yaml
    ├── X_R1_zero_7B_config.yaml
    ├── zero1.yaml
    ├── zero2.yaml
    └── zero3.yaml
```
Running `pip install -r requirements.txt` installs all dependencies. The `recipes/` directory contains the training recipes, such as `X_R1_test_env_single.yaml` and `X_R1_zero_0dot5B_config.yaml`. The framework's highlighted features are:

- Low-cost training
- Dataset support
- Logging
The configuration files in the project (e.g. `X_R1_zero_0dot5B_config.yaml`) cover three groups of settings:

- Model: `model_name_or_path`, `model_revision`, `torch_dtype`, etc., specifying the model name, revision, and data type.
- Data: `dataset_name`, `dataset_configs`, `num_processes`, etc., specifying the training dataset and the number of processes.
- Trainer: `use_vllm`, `output_dir`, `gradient_accumulation_steps`, etc., configuring the trainer itself.
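To make the mapping concrete, here is a minimal sketch (not taken from the repository) of how such a YAML recipe ends up in Python objects. `grpo.py` does the same thing with its own `GRPOScriptArguments` subclass; the recipe path below is just one of the files listed above.

```python
# Sketch only: TrlParser routes YAML keys into three dataclasses --
# model_* keys -> ModelConfig, dataset_* keys -> ScriptArguments,
# trainer keys (use_vllm, output_dir, ...) -> GRPOConfig.
from trl import GRPOConfig, ModelConfig, ScriptArguments, TrlParser

parser = TrlParser((ScriptArguments, GRPOConfig, ModelConfig))
script_args, training_args, model_args = parser.parse_args_and_config(
    ["--config", "recipes/X_R1_zero_0dot5B_config.yaml"]
)
print(model_args.model_name_or_path, script_args.dataset_name, training_args.use_vllm)
```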
To set up the environment:

- Create a Python 3.11 virtual environment named `xr1`:

```bash
conda create -n xr1 python=3.11
conda activate xr1
```
- Install the required dependencies:

```bash
pip install -r requirements.txt
```
- Create the output directory:

```bash
mkdir output
```
- **Single-GPU run**:

```bash
ACCELERATE_LOG_LEVEL=info \
accelerate launch \
--config_file recipes/zero1.yaml \
--num_processes=1 \
src/x_r1/grpo.py \
--config recipes/X_R1_test_env_single.yaml \
> ./output/x_r1_test_sampling.log 2>&1
```
- **Multi-GPU run**:

```bash
ACCELERATE_LOG_LEVEL=info \
accelerate launch \
--config_file recipes/accelerate_configs/zero3.yaml \
--num_processes=1 \
src/x_r1/grpo.py \
--config recipes/x_r1_test_sampling.yaml \
> ./output/test.log 2>&1
```
Two utility functions are worth noting:

- `register_lighteval_task`, in `X-R1/src/x_r1/utils/evaluation.py`, registers LightEval task configurations.
- `push_to_hub_revision`, in `X-R1/src/x_r1/utils/hub.py`, pushes the model to a specified branch of a Hub repository.

The project is inspired by DeepSeek-R1 and open-r1.
**Dissecting `x_grpo_trainer.py`**

```python
import os
import textwrap
import warnings
from collections import defaultdict
from typing import Any, Callable, Optional, Union
from unittest.mock import patch

import torch
import torch.utils.data
import transformers
from accelerate.utils import broadcast_object_list, gather_object
from transformers import (
    PreTrainedModel,
    Trainer,
)
from trl.trainer import GRPOTrainer
from trl.data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
from trl.models import unwrap_model_for_generation
from trl.trainer.grpo_config import GRPOConfig
from trl.trainer.utils import pad
```
The file imports standard modules such as `os`, `torch`, and `transformers` for file handling, deep-learning computation, and data processing. From the `trl` library it imports `GRPOTrainer`, `GRPOConfig`, and related classes and functions, which shows that this file extends `trl` rather than reimplementing it.

```python
# What we call a reward function is a callable that takes a list of prompts and completions and returns a list of
# rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model.
RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
```
`RewardFunc` is a type alias: a reward can be a string (a model ID), a pretrained model, or a callable that takes lists of prompts and completions and returns a list of per-completion rewards.
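A custom callable reward corresponds to the third variant of this alias. The function below is a toy sketch, not part of X-R1; it only illustrates the expected signature (keyword arguments `prompts` and `completions`, plus any extra dataset columns passed through `**kwargs`).

```python
# Toy reward callable matching the RewardFunc contract: one float per completion.
def length_penalty_reward(prompts, completions, **kwargs):
    """Hypothetical example: reward completions shorter than 2048 characters."""
    contents = [completion[0]["content"] for completion in completions]
    return [1.0 if len(content) < 2048 else 0.0 for content in contents]
```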
**The `XGRPOTrainer` class**

```python
class XGRPOTrainer(GRPOTrainer):
    # base trl GRPO_trainer
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        if return_outputs:
            raise ValueError("The GRPOTrainer does not support returning outputs")
```
`XGRPOTrainer` inherits from `GRPOTrainer` and overrides its `compute_loss` method. If `return_outputs` is `True` it raises a `ValueError`, since this trainer does not support returning outputs.

Inside `compute_loss`, the prompts are first prepared:

```python
device = self.accelerator.device
prompts = [x["prompt"] for x in inputs]
prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
prompt_inputs = self.processing_class(
    prompts_text, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False
)
prompt_inputs = super()._prepare_inputs(prompt_inputs)

if self.max_prompt_length is not None:
    prompt_inputs["input_ids"] = prompt_inputs["input_ids"][:, -self.max_prompt_length :]
    prompt_inputs["attention_mask"] = prompt_inputs["attention_mask"][:, -self.max_prompt_length :]
```
Each prompt is run through the chat template and tokenized with `processing_class` (left padding, no special tokens) to produce the input tensors. If `max_prompt_length` is set, the inputs are truncated from the left so that only the most recent tokens are kept.
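As a quick standalone illustration (toy tensor, not from the repository), left-truncation keeps the tokens closest to where generation will start:

```python
import torch

# Keep only the last max_prompt_length tokens of an over-long prompt.
max_prompt_length = 4
input_ids = torch.tensor([[101, 7, 8, 9, 10, 11]])
print(input_ids[:, -max_prompt_length:])  # tensor([[ 8,  9, 10, 11]])
```

The completions are then generated: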
```python
# Generate completions using either vLLM or regular generation
if self.args.use_vllm:
    # First, have main process load weights if needed
    if self.state.global_step != self._last_loaded_step:
        with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
            state_dict = unwrapped_model.state_dict()
        if self.accelerator.is_main_process:
            llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
            llm_model.load_weights(state_dict.items())
        self._last_loaded_step = self.state.global_step

    # Generate completions using vLLM: gather all prompts and use them in a single call in the main process
    all_prompts_text = gather_object(prompts_text)
    if self.accelerator.is_main_process:
        outputs = self.llm.generate(all_prompts_text, sampling_params=self.sampling_params, use_tqdm=False)
        completion_ids = [out.token_ids for completions in outputs for out in completions.outputs]
        for output in outputs:
            print('-' * 100)
            print('\n\n\n')
            prompt = output.prompt
            for output_t in output.outputs:
                # print(completion_ids)
                print('=' * 100)
                generated_text = output_t.text
                print("【USER】: ", prompt)
                print("\n【ASSISTANT】:", generated_text)
    else:
        completion_ids = [None] * len(all_prompts_text) * self.num_generations

    # Broadcast the completions from the main process to all processes, ensuring each process receives its
    # corresponding slice.
    completion_ids = broadcast_object_list(completion_ids, from_process=0)
    process_slice = slice(
        self.accelerator.process_index * len(prompts) * self.num_generations,
        (self.accelerator.process_index + 1) * len(prompts) * self.num_generations,
    )
    completion_ids = completion_ids[process_slice]

    # Pad the completions, and concatenate them with the prompts
    completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
    completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id)
    prompt_inputs_repeated = torch.repeat_interleave(prompt_inputs["input_ids"], self.num_generations, dim=0).to(device)
    prompt_completion_ids = torch.cat([prompt_inputs_repeated, completion_ids], dim=1)
else:
    # Regular generation path
    with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
        prompt_inputs['input_ids'].to(device)
        prompt_inputs['attention_mask'].to(device)
        prompt_completion_ids = unwrapped_model.generate(
            **prompt_inputs, generation_config=self.generation_config
        )
```
Depending on `self.args.use_vllm`, completions are generated either with vLLM or with the regular `generate` path. With vLLM, the main process first reloads the current policy weights into the vLLM engine (only when the global step has advanced), gathers the prompts from all processes, generates every completion in a single call, and then broadcasts the token IDs back so that each process can slice out the part corresponding to its own prompts.
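The slicing arithmetic is easier to see with toy numbers (the values below are made up for illustration): with 2 processes, 2 prompts per process, and `num_generations = 4`, the main process produces 16 completions and process 1 keeps elements 8..15.

```python
# Toy illustration of the broadcast-then-slice pattern used above.
num_processes = 2
prompts_per_process = 2
num_generations = 4

all_completions = list(range(num_processes * prompts_per_process * num_generations))  # 16 items

process_index = 1
process_slice = slice(
    process_index * prompts_per_process * num_generations,
    (process_index + 1) * prompts_per_process * num_generations,
)
print(all_completions[process_slice])  # [8, 9, 10, 11, 12, 13, 14, 15]
```

Back in `compute_loss`, the completion tokens are separated from the prompt and per-token log probabilities are computed: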
```python
prompt_length = prompt_inputs["input_ids"].size(1)
completion_ids = prompt_completion_ids[:, prompt_length:]

# Get the per-token log probabilities for the completions for the model and the reference model
def get_per_token_logps(model, input_ids, num_logits_to_keep):
    # We add 1 to `num_logits_to_keep` because the last logits of the sequence is later excluded
    logits = model(input_ids, num_logits_to_keep=num_logits_to_keep + 1).logits  # (B, L, V)
    logits = logits[:, :-1, :]  # (B, L-1, V), exclude the last logit: it corresponds to the next token pred

    # Compute the log probabilities for the input tokens. Use a loop to reduce memory peak.
    per_token_logps = []
    for logits_row, input_ids_row in zip(logits, input_ids[:, -num_logits_to_keep:]):
        log_probs = logits_row.log_softmax(dim=-1)
        token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
        per_token_logps.append(token_log_prob)
    return torch.stack(per_token_logps)

num_logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens
per_token_logps = get_per_token_logps(model, prompt_completion_ids, num_logits_to_keep)

with torch.inference_mode():
    if self.ref_model is not None:
        ref_per_token_logps = get_per_token_logps(self.ref_model, prompt_completion_ids, num_logits_to_keep)
    else:
        with self.accelerator.unwrap_model(model).disable_adapter():
            ref_per_token_logps = get_per_token_logps(model, prompt_completion_ids, num_logits_to_keep)

# Compute the KL divergence between the model and the reference model
per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
```
The helper `get_per_token_logps` computes per-token log probabilities for the completion tokens, once with the policy model and once with the reference model (or with adapters disabled when there is no separate reference). The per-token KL is then estimated with the unbiased k3 estimator `exp(d) - d - 1`, where `d = ref_logp - logp`. Note that this excerpt omits the step, present in the full source, that decodes `completions` and builds `completion_mask` (masking tokens after the first EOS); both are used below.
```python
# Compute the rewards
prompts = [prompt for prompt in prompts for _ in range(self.num_generations)]

rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
for i, (reward_func, reward_processing_class) in enumerate(
    zip(self.reward_funcs, self.reward_processing_classes)
):
    if isinstance(reward_func, PreTrainedModel):
        if is_conversational(inputs[0]):
            messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
            texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
        else:
            texts = [p + c for p, c in zip(prompts, completions)]
        reward_inputs = reward_processing_class(
            texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
        )
        reward_inputs = super()._prepare_inputs(reward_inputs)
        with torch.inference_mode():
            rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0]  # Shape (B*G,)
    else:
        # Repeat all input columns (but "prompt" and "completion") to match the number of generations
        reward_kwargs = {key: [] for key in inputs[0].keys() if key not in ["prompt", "completion"]}
        for key in reward_kwargs:
            for example in inputs:
                # Repeat each value in the column for `num_generations` times
                reward_kwargs[key].extend([example[key]] * self.num_generations)
        output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
        rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)

# Sum the rewards from all reward functions
rewards = rewards_per_func.sum(dim=1)

# Compute grouped-wise rewards
mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)

# Normalize the rewards to compute the advantages
mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)

# x - x.detach() allows for preserving gradients from x
per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
per_token_loss = -(per_token_loss - self.beta * per_token_kl)
loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()

# Log the metrics
completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
self._metrics["completion_length"].append(completion_length)

reward_per_func = self.accelerator.gather_for_metrics(rewards_per_func).mean(0)
for i, reward_func in enumerate(self.reward_funcs):
    if isinstance(reward_func, PreTrainedModel):
        reward_func_name = reward_func.config._name_or_path.split("/")[-1]
    else:
        reward_func_name = reward_func.__name__
    self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())

self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item())
self._metrics["reward_std"].append(self.accelerator.gather_for_metrics(std_grouped_rewards).mean().item())

mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())

return loss
```
The rewards from all reward functions are summed per completion, normalized within each group of `num_generations` samples to form the advantages, and combined with the per-token KL penalty (weighted by `self.beta`) into the final loss. Completion length, per-function rewards, total reward, reward standard deviation, and KL are all accumulated in `self._metrics`.
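To make the group-wise normalization concrete, here is a small standalone sketch with made-up rewards (one prompt, four generations); it mirrors the `view` / `repeat_interleave` arithmetic above.

```python
import torch

# Made-up rewards for one prompt with num_generations = 4 completions.
num_generations = 4
rewards = torch.tensor([1.0, 0.0, 1.0, 0.0])

mean_grouped = rewards.view(-1, num_generations).mean(dim=1).repeat_interleave(num_generations, dim=0)
std_grouped = rewards.view(-1, num_generations).std(dim=1).repeat_interleave(num_generations, dim=0)
advantages = (rewards - mean_grouped) / (std_grouped + 1e-4)
print(advantages)  # roughly [+0.87, -0.87, +0.87, -0.87]: above-average completions get positive advantage
```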
**Reading `grpo.py`**

```python
@dataclass
class GRPOScriptArguments(ScriptArguments):
    reward_funcs: list[str] = field(
        default_factory=lambda: ["accuracy", "format"],
        # default_factory=lambda: ["accuracy", ],
        metadata={
            "help": f"List of reward functions. Possible values: {', '.join(REWARD_FUNCS_REGISTRY.keys())}"
        },
    )
```
`GRPOScriptArguments` inherits from `ScriptArguments`. Its `reward_funcs` field is a list of strings with the default `["accuracy", "format"]`; the field's `metadata` provides help text listing the available reward function names.

```python
SYSTEM_PROMPT = (
"A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
"first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
"process and answer are enclosed within and tags, respectively, i.e., "
" reasoning process here answer here "
)
`SYSTEM_PROMPT` describes the expected conversation pattern: the assistant reasons inside `<think>` tags and gives its final answer inside `<answer>` tags.

```python
def main(script_args, training_args, model_args):
    # Set seed for reproducibility
    set_seed(training_args.seed)
```
The `main` function takes `script_args`, `training_args`, and `model_args` as parameters and starts by calling `set_seed` to make the run reproducible. Logging is configured next:

```python
    ###############
    # Setup logging
    ###############
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()
```
`logging.basicConfig` sets the log and date formats and writes logs to standard output; the process log level is propagated to the `datasets` and `transformers` libraries, with their default handler and explicit formatting enabled.

```python
    # Log on each process a small summary
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
        + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Model parameters {model_args}")
    logger.info(f"Script parameters {script_args}")
    logger.info(f"Data parameters {training_args}")

    # Check for last checkpoint
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir):
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
    if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
        logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")

    # Load the dataset
    dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
```
Each process logs a short summary, `output_dir` is checked for a checkpoint to resume from, and `load_dataset` loads the dataset whose name and configuration come from the script arguments. The requested reward functions are then looked up by name:

```python
    # Get reward functions
    reward_funcs = [REWARD_FUNCS_REGISTRY[func] for func in script_args.reward_funcs]
```
Each name in `script_args.reward_funcs` is resolved through `REWARD_FUNCS_REGISTRY` to the corresponding function.

```python
    # Format into conversation
    def make_conversation(example):
        return {
            "prompt": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": example["problem"]},
            ],
        }

    dataset = dataset.map(make_conversation)
    for split in dataset:
        if "messages" in dataset[split].column_names:
            dataset[split] = dataset[split].remove_columns("messages")
```
`make_conversation` formats each dataset example as a conversation (system prompt plus the user's problem), `dataset.map` applies it to every split, and any existing `messages` column is removed.
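For illustration, this is what one mapped example looks like (the problem text is made up; `make_conversation` and `SYSTEM_PROMPT` are as defined above):

```python
example = {"problem": "What is 2 + 3?"}
print(make_conversation(example))
# -> {'prompt': [{'role': 'system', 'content': SYSTEM_PROMPT},   # printed as the full system prompt text
#                {'role': 'user',   'content': 'What is 2 + 3?'}]}
```

Next, the model keyword arguments are assembled: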
```python
    logger.info("*** Initializing model kwargs ***")
    torch_dtype = (
        model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
    )
    training_args.gradient_checkpointing = True
    model_kwargs = dict(
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
        attn_implementation=model_args.attn_implementation,
        torch_dtype=torch_dtype,
        use_cache=False if training_args.gradient_checkpointing else True,
    )
```
`torch_dtype` is resolved from `model_args.torch_dtype`: `"auto"` and `None` pass through unchanged, any other string is looked up as an attribute of `torch` (e.g. `"bfloat16"` becomes `torch.bfloat16`). Gradient checkpointing is forced on, so `use_cache` is set to `False`.

```python
    model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, load_in_4bit=False, **model_kwargs)
    print(model_args.model_name_or_path)
```
`AutoModelForCausalLM.from_pretrained` loads the pretrained model without 4-bit quantization (`load_in_4bit=False`).

```python
    #############################
    # Initialize the XGRPO trainer
    #############################
    trainer = XGRPOTrainer(
        # model=model_args.model_name_or_path,
        model=model,
        reward_funcs=reward_funcs,
        args=training_args,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
        callbacks=get_callbacks(training_args, model_args),
    )
```
The `XGRPOTrainer` is constructed with the model, the reward functions, the training arguments, the train split, the eval split (only if `eval_strategy` is not `"no"`), and the callbacks.

```python
    ###############
    # Training loop
    ###############
    logger.info("*** Train ***")
    checkpoint = None
    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint
    elif last_checkpoint is not None:
        checkpoint = last_checkpoint
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
    metrics = train_result.metrics
    metrics["train_samples"] = len(dataset[script_args.dataset_train_split])
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()
```
Training is launched with `trainer.train` (resuming from a checkpoint when one is available), and the resulting metrics are logged and saved along with the trainer state.

```python
    ##################################
    # Save model and create model card
    ##################################
    logger.info("*** Save model ***")
    trainer.save_model(training_args.output_dir)
    logger.info(f"Model saved to {training_args.output_dir}")

    # Save everything else on main process
    kwargs = {
        "dataset_name": script_args.dataset_name,
        "tags": ["X-R1"],
    }
    if trainer.accelerator.is_main_process:
        trainer.create_model_card(**kwargs)
        # Restore k,v cache for fast inference
        trainer.model.config.use_cache = True
        trainer.model.config.save_pretrained(training_args.output_dir)
```
Finally, `save_model` writes the model to `output_dir`, a model card is created on the main process, and the KV cache is re-enabled in the saved config for fast inference.

```python
if __name__ == "__main__":
    parser = TrlParser((GRPOScriptArguments, GRPOConfig, ModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()
    main(script_args, training_args, model_args)
```
The entry point builds a `TrlParser` over `GRPOScriptArguments`, `GRPOConfig`, and `ModelConfig`, parses the command-line arguments and config file, and calls `main`.

**Reading `rewards.py`**

```python
"""Reward functions for GRPO training."""

import re

from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify
```
The `re` module is used for regular expressions; `NormalizationConfig` comes from `latex2sympy2_extended`, and `LatexExtractionConfig`, `parse`, and `verify` come from `math_verify`.

```python
def accuracy_reward(completions, solution, **kwargs):
"""Reward function that checks if the completion is the same as the ground truth."""
contents = [completion[0]["content"] for completion in completions]
rewards = []
for content, sol in zip(contents, solution):
gold_parsed = parse(
sol,
extraction_mode="first_match",
extraction_config=[LatexExtractionConfig()],
)
if len(gold_parsed) != 0:
# print('latex gold parsed')
# We require the answer to be provided in correct latex (no malformed operators)
answer_parsed = parse(
content,
extraction_config=[
LatexExtractionConfig(
normalization_config=NormalizationConfig(
nits=False,
malformed_operators=False,
basic_latex=True,
equations=True,
boxed="all",
units=True,
),
# Ensures that boxed is tried first
boxed_match_priority=0,
try_extract_without_anchor=False,
)
],
extraction_mode="first_match",
)
# Reward 1 if the content is the same as the ground truth, 0 otherwise
reward = float(verify(answer_parsed, gold_parsed))
# print('\nprompt:', prompt)
print('-'*100)
print('\nanswer_parsed:', answer_parsed, '\ngold_parsed:', gold_parsed, '\nreward:', reward)
else:
reward = 1.0
print("Failed to parse gold solution: ", sol)
rewards.append(reward)
print('\naccuracy rewards:', rewards)
return rewards
`accuracy_reward` checks whether the generated answer matches the ground truth. The gold solution and the model output are both parsed as LaTeX, and `verify` compares them: a match scores 1.0, a mismatch 0.0. If the gold solution itself cannot be parsed, the reward defaults to 1.0 and a warning is printed.
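As a standalone illustration of the `parse`/`verify` calls (the inputs are made up, and the exact equivalence judgment is up to `math_verify`):

```python
from math_verify import LatexExtractionConfig, parse, verify

gold_parsed = parse(r"$\frac{1}{2}$", extraction_mode="first_match", extraction_config=[LatexExtractionConfig()])
answer_parsed = parse(r"The final answer is $0.5$.", extraction_mode="first_match", extraction_config=[LatexExtractionConfig()])
print(float(verify(answer_parsed, gold_parsed)))  # 1.0 if math_verify treats 0.5 and 1/2 as equal, else 0.0
```

Next comes the format reward: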
"""Reward function that checks if the completion has a specific format."""
pattern = r"^.*? .*? $"
completion_contents = [completion[0]["content"] for completion in completions]
matches = [re.match(pattern, content) for content in completion_contents]
rewards = [1.0 if match else 0.0 for match in matches]
print('-'*100)
print('\nformat rewards:', rewards)
return rewards
`format_reward` checks whether a completion follows the required `<think>...</think><answer>...</answer>` format: a regex match scores 1.0, otherwise 0.0.

```python
def reasoning_steps_reward(completions, **kwargs):
"""Reward function that checks for clear step-by-step reasoning.
Regex pattern:
Step \d+: - matches "Step 1:", "Step 2:", etc.
^\d+\. - matches numbered lists like "1.", "2.", etc. at start of line
\n- - matches bullet points with hyphens
\n\* - matches bullet points with asterisks
First,|Second,|Next,|Finally, - matches transition words
"""
pattern = r"(Step \d+:|^\d+\.|\n-|\n\*|First,|Second,|Next,|Finally,)"
completion_contents = [completion[0]["content"] for completion in completions]
matches = [len(re.findall(pattern, content)) for content in completion_contents]
# Magic nubmer 3 to encourage 3 steps and more, otherwise partial reward
return [min(1.0, count / 3) for count in matches]
`reasoning_steps_reward` checks whether the completion contains clear step-by-step reasoning, counting markers such as "Step 1:", numbered lists, bullet points, and transition words; three or more markers earn the full reward of 1.0, fewer earn partial credit (`count / 3`).
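A quick sketch of the partial-credit behaviour on a made-up completion:

```python
import re

pattern = r"(Step \d+:|^\d+\.|\n-|\n\*|First,|Second,|Next,|Finally,)"
content = "First, compute 2 + 3.\n- That gives 5."
count = len(re.findall(pattern, content))
print(count, min(1.0, count / 3))  # 2 markers -> partial reward of about 0.67
```

Finally, the registry ties everything together: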
```python
REWARD_FUNCS_REGISTRY = {
    "accuracy": accuracy_reward,
    "format": format_reward,
    "reasoning_steps": reasoning_steps_reward,
}
```
`REWARD_FUNCS_REGISTRY` maps reward-function names to their implementations, which is how the `reward_funcs` list in the training config gets resolved to actual Python functions.
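To close, here is a toy sketch (made-up inputs) of how trainer-style inputs look when they reach these reward functions: each completion is a list of chat messages, and extra dataset columns such as `solution` arrive as keyword arguments. The import path below is assumed from `src/x_r1/rewards.py`.

```python
from x_r1.rewards import accuracy_reward, format_reward  # path assumed from src/x_r1/rewards.py

completions = [
    [{"role": "assistant", "content": "<think> 2 + 3 = 5 </think><answer> 5 </answer>"}],
    [{"role": "assistant", "content": "The final answer is $5$."}],
]

print(format_reward(completions))                             # -> [1.0, 0.0] with the pattern shown above
print(accuracy_reward(completions, solution=["$5$", "$5$"]))  # each output is parsed and compared to the gold answer
```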