推理型大语言模型现在确实火了。这类模型的特点是会先对问题做充分思考,然后再给出答案,而不是直接回复。
虽然早期训练推理型 LLM 的方法多半被各家公司当作核心机密,但最近的DeepSeek-R1、DeepSeekMath、Kimi-k1.5 和 DAPO 这些项目都公开了相关流程。
这些方法让 LLM 在推理过程中生成更长的思维链(Chain-of-Thought,CoT)输出,推理效果因此得到提升。同时它们还引入了改进的强化学习算法,比如 GRPO 和 DAPO,这些算法是对 OpenAI 最初 PPO 方法的高效升级。
这篇文章会先介绍 GRPO(Group Relative Policy Optimization,组相对策略优化)的基本概念,这是目前训练推理型 LLM 最常用的强化学习算法之一。然后我们会动手写代码训练一个推理 LLM,在实践中理解整个流程。
RLHF 与 PPO 的简单回顾
想理解 GRPO,我们得先回到最初用来对齐 LLM 的强化学习算法。
这个算法叫 Proximal Policy Optimization(PPO,近端策略优化),用它将 LLM 对齐到人类偏好的过程叫做 Reinforcement Learning From Human Feedback(RLHF,基于人类反馈的强化学习)。
LLM 对齐与 RLHF 可视化
RLHF 主要包含三个步骤:
步骤 1:训练监督微调策略
从预训练的 LLM 开始,在包含提示与人工编写回答的数据集上做微调。这样得到的模型在 RL 术语中叫"监督策略",它能针对给定提示生成更符合人类偏好的回答。
步骤 2:训练奖励模型
为每个提示收集多个模型输出,让人工标注者对这些输出排序,判断哪个更好。然后用这些数据训练一个"奖励模型",它会对给定输出返回一个标量分数作为人类偏好的代理。
步骤 3:用奖励模型做强化学习
从步骤 1 的监督策略复制一份作为"训练策略",同时保留一份冻结副本叫"参考策略"。给训练策略输入提示,用奖励模型对输出打分,然后用 PPO 基于这个奖励继续微调训练策略。
在 LLM 的强化学习中,"状态"指模型到某个时刻已生成的所有 token(也就是"上下文"),"动作"是下一个要预测的 token。
PPO 训练 LLM 时会用一个"价值模型"(通常从奖励模型初始化)来估计从给定状态出发的未来总期望奖励,叫做"Value"。接着用这个 Value 计算"优势"(Advantage,使用 GAE),它衡量在某状态下采取某动作相对于训练策略期望行为的好坏程度。
PPO 更新训练策略时就用到这个 Advantage。同时价值模型也会在训练过程中不断更新,以便在每个训练步提供更好的未来总期望奖励估计。
PPO 可视化,其中 Q 为查询,O 为训练策略的输出,KL 为训练策略模型与参考模型之间的 KL 散度,R 为奖励,V 为价值,A 为优势。
从 PPO 到 GRPO
GRPO 最初由 DeepSeekMath 论文提出,现在广泛用于训练推理型 LLM。
GRPO 和 PPO 的主要区别是:GRPO 不用价值模型来估计 Advantage。
它通过对同一提示下模型生成的一组输出进行相对打分来计算 Advantage,这也就是 GRPO 中"相对"这词的来源。
PPO 关注的是某个输出是否比价值模型的期望更好。
GRPO 关注的是某个输出是否比同一提示下所有输出的平均水平更好,这个平均值就作为价值的基线或代理。
GRPO 可视化,其中 Q 为查询,O(1..G) 为训练策略的多条输出,KL 为训练策略模型与参考模型之间的 KL 散度,R(1..G) 为每条输出对应的奖励,A(1..G) 为每条输出对应的优势。
用 GRPO 训练推理 LLM
本文的的所有代码都可以在 Google Colaboratory 笔记本中完成,运行环境用的是免费层的 T4 GPU。
基础模型选择 Qwen2.5–3B-Instruct(指令微调版)。
我们用 Unsloth——一个开源 Python 库和平台,专门用来优化和加速 LLM 微调。Unsloth 的好处是你只需定义奖励和训练配置,它会在内部管理参考策略与训练策略以及所有 GPU 操作。这大大简化了 GRPO 训练流水线。
下面和 Unsloth 相关的函数参数基本都不言自明,如果第一次接触可以查阅 Unsloth 文档。
安装依赖
import os
os.environ["UNSLOTH_VLLM_STANDBY"] = "1" # 获取额外 30% 的上下文长度
# 安装依赖
!pip install unsloth_zoo
!pip install — upgrade unsloth vllm==0.9.2 numpy torchvision bitsandbytes xformers
!pip install triton==3.2.0
!pip install transformers==4.55.4
!pip install — no-deps trl==0.22.2
加载模型和分词器,这里加载 Qwen2.5–3B-Instruct 模型及其分词器。
from unsloth import FastLanguageModel
import torch
# 上下文长度
max_seq_length = 1024
# 加载模型与分词器
model, tokenizer = FastLanguageModel.from_pretrained(
model_name = "unsloth/Qwen2.5-3B-Instruct",
max_seq_length = max_seq_length,
load_in_4bit = True, # 启用 4-bit 量化
fast_inference = True, # 启用 vLLM 快速推理
max_lora_rank = 8,
gpu_memory_utilization = 0.9,
)
用 LoRA 做参数高效微调,由于算力资源有限,我们不会训练 LLM 的全部参数,而是用 LoRA(低秩适配)来提升训练效率。
# 使用 LoRA 进行参数高效微调
model = FastLanguageModel.get_peft_model(
model,
r = 8,
# 需要微调的模块
target_modules = [
"q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj",
],
lora_alpha = 8,
use_gradient_checkpointing = "unsloth",
random_state = 1234,
)
用著名的 GSM8K 数据集(小学到初中难度的数学文字题集合)来训练模型的推理能力。下面对数据集中的题目做格式化处理,以便用于训练。
import re
from datasets import load_dataset, Dataset
# 系统提示词
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""
# 包裹推理与答案的模板
XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""
# 从模型输出中抽取 <answer>...</answer> 内文本的函数
def extract_xml_answer(text):
if "<answer>" not in text or "</answer>" not in text:
return ""
return text.split("<answer>", 1)[-1].split("</answer>", 1)[0].strip()
# 从 GSM8K 标签中抽取正确答案,标签形如 '... #### final_answer'
def extract_hash_answer(text):
return text.split("####")[-1].strip() if "####" in text else None
# 加载 GSM8K 数据集并格式化为对话式提示的函数
def get_gsm8k_dataset(split = "train"):
data = load_dataset("openai/gsm8k", "main")[split]
return data.map(
lambda x: {
"prompt": [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": x["question"]},
],
"answer": extract_hash_answer(x["answer"]),
}
)
dataset = get_gsm8k_dataset()
下面定义用来评估推理模型训练效果的奖励函数。
# 奖励函数:检查从补全中抽取的答案
# 是否与给定的真实答案完全一致。
# 一致则返回 2.0,否则返回 0.0。
def correctness_reward_func(prompts, completions, answer, **kwargs):
responses = [completion[0]['content'] for completion in completions]
q = prompts[0][-1]['content']
extracted_responses = [extract_xml_answer(r) for r in responses]
print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]
# 奖励函数:检查抽取的回答是否为整数。
# 若为数字则返回 0.5,否则返回 0.0。
def int_reward_func(completions, **kwargs):
responses = [completion[0]['content'] for completion in completions]
extracted_responses = [extract_xml_answer(r) for r in responses]
return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]
# 奖励函数:强约束的 XML 格式检查,
# 要求响应必须严格匹配以下结构:
# <reasoning>\n...\n</reasoning>\n<answer>\n...\n</answer>\n
# 格式正确返回 0.5,否则返回 0.0。
def strict_format_reward_func(completions, **kwargs):
pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
responses = [completion[0]["content"] for completion in completions]
matches = [re.match(pattern, r) for r in responses]
return [0.5 if match else 0.0 for match in matches]
# 奖励函数:较宽松的 XML 格式检查:
# 响应需包含 <reasoning>...</reasoning> 与 <answer>...</answer>,
# 但允许空格与换行的灵活性。
# 匹配返回 0.5,否则返回 0.0。
def soft_format_reward_func(completions, **kwargs):
pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
responses = [completion[0]["content"] for completion in completions]
matches = [re.match(pattern, r) for r in responses]
return [0.5 if match else 0.0 for match in matches]
# 辅助函数:统计并为 XML 标签计分
def count_xml(text):
count = 0.0
if text.count("<reasoning>\n") == 1:
count += 0.125
if text.count("\n</reasoning>\n") == 1:
count += 0.125
if text.count("\n<answer>\n") == 1:
count += 0.125
count -= len(text.split("\n</answer>\n")[-1])*0.001
if text.count("\n</answer>") == 1:
count += 0.125
count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
return count
# 奖励函数:将 count_xml 应用于补全输出。
# 根据标签正确性进行 XML 结构计分,并对尾部冗余内容施加惩罚。
def xmlcount_reward_func(completions, **kwargs):
contents = [completion[0]["content"] for completion in completions]
return [count_xml(c) for c in contents]
定义 GRPO 训练器的参数。
from trl import GRPOConfig, GRPOTrainer
# 训练参数
training_args = GRPOConfig(
use_vllm = True, # 使用 vLLM 进行快速推理
learning_rate = 5e-6,
adam_beta1 = 0.9,
adam_beta2 = 0.99,
weight_decay = 0.1,
warmup_ratio = 0.1,
lr_scheduler_type = "cosine",
optim = "adamw_8bit",
logging_steps = 1,
per_device_train_batch_size = 4,
gradient_accumulation_steps = 1,
num_generations = 4,
max_prompt_length = 256,
max_completion_length = 200,
max_steps = 250,
save_steps = 250,
max_grad_norm = 0.1,
report_to = "none",
output_dir = "outputs",
)
# GRPO 训练器
trainer = GRPOTrainer(
model = model,
processing_class = tokenizer,
reward_funcs = [
xmlcount_reward_func,
soft_format_reward_func,
strict_format_reward_func,
int_reward_func,
correctness_reward_func,
],
args = training_args,
train_dataset = dataset,
)
用下面的命令启动训练。
# 开始训练
trainer.train()
强化学习模型工作原理是探索解空间,所以训练通常比较慢。LLM 可能需要数百步才能学会更好的推理,这意味着你需要等几个小时才能得到不错的结果。
上图为部分训练日志
保存模型有很多方式,我们主要关心的是保存 LoRA 适配器。
# 保存 LoRA 适配器
model.save_lora("grpo_saved_lora")
最后对比一下训练前后模型的输出。
from vllm import SamplingParams
# 训练前的模型推理
query = "How many r's are in strawberry?"
text = tokenizer.apply_chat_template([
{"role" : "user", "content" : query},
], tokenize = False, add_generation_prompt = True)
sampling_params = SamplingParams(
temperature = 0.8,
top_p = 0.95,
max_tokens = 1024,
)
output = model.fast_generate(
[text],
sampling_params = sampling_params,
lora_request = None,
)[0].outputs[0].text
print(output)
训练前模型的输出:
There are 2 r's in the word "strawberry."
接下来试试经过 GRPO 训练的模型:
# 训练后的模型推理
text = tokenizer.apply_chat_template([
{"role" : "system", "content" : SYSTEM_PROMPT},
{"role" : "user", "content" : query},
], tokenize = False, add_generation_prompt = True)
sampling_params = SamplingParams(
temperature = 0.8,
top_p = 0.95,
max_tokens = 1024,
)
output = model.fast_generate(
text,
sampling_params = sampling_params,
lora_request = model.load_lora("grpo_saved_lora"),
)[0].outputs[0].text
print(output)
训练后模型的输出:
<reasoning>
To find out how many times the letter 'r' appears in the word "strawberry", we can go through the word character by character and count each occurrence of 'r'. In "strawberry", the letter 'r' appears 3 times: once in the beginning, once in the middle, and once at the end of the word.
</reasoning>
<answer>
3
</answer>
效果相当不错!可以看到模型现在会在回答问题前先进行推理,并且给出了正确答案。
下面是使用 GRPO 训练 Qwen 2.5(3B)训练过程的概览示意图:
本文的完整代码:
https://avoid.overfit.cn/post/1506330de8e349eab552ec1000417a27
作者:Dr. Ashish Bamania