从零开始训练推理模型:GRPO+Unsloth改造Qwen实战指南

本文涉及的产品
智能开放搜索 OpenSearch行业算法版,1GB 20LCU 1个月
实时计算 Flink 版,1000CU*H 3个月
实时数仓Hologres,5000CU*H 100GB 3个月
简介: 推理型大语言模型兴起,通过先思考再作答提升性能。本文介绍GRPO等强化学习算法,详解其原理并动手用Qwen2.5-3B训练推理模型,展示训练前后效果对比,揭示思维链生成的实现路径。

推理型大语言模型现在确实火了。这类模型的特点是会先对问题做充分思考,然后再给出答案,而不是直接回复。

虽然早期训练推理型 LLM 的方法多半被各家公司当作核心机密,但最近的DeepSeek-R1、DeepSeekMath、Kimi-k1.5 和 DAPO 这些项目都公开了相关流程。

这些方法让 LLM 在推理过程中生成更长的思维链(Chain-of-Thought,CoT)输出,推理效果因此得到提升。同时它们还引入了改进的强化学习算法,比如 GRPO 和 DAPO,这些算法是对 OpenAI 最初 PPO 方法的高效升级。

这篇文章会先介绍 GRPO(Group Relative Policy Optimization,组相对策略优化)的基本概念,这是目前训练推理型 LLM 最常用的强化学习算法之一。然后我们会动手写代码训练一个推理 LLM,在实践中理解整个流程。

RLHF 与 PPO 的简单回顾

想理解 GRPO,我们得先回到最初用来对齐 LLM 的强化学习算法。

这个算法叫 Proximal Policy Optimization(PPO,近端策略优化),用它将 LLM 对齐到人类偏好的过程叫做 Reinforcement Learning From Human Feedback(RLHF,基于人类反馈的强化学习)。

LLM 对齐与 RLHF 可视化

RLHF 主要包含三个步骤:

步骤 1:训练监督微调策略

从预训练的 LLM 开始,在包含提示与人工编写回答的数据集上做微调。这样得到的模型在 RL 术语中叫"监督策略",它能针对给定提示生成更符合人类偏好的回答。

步骤 2:训练奖励模型

为每个提示收集多个模型输出,让人工标注者对这些输出排序,判断哪个更好。然后用这些数据训练一个"奖励模型",它会对给定输出返回一个标量分数作为人类偏好的代理。

步骤 3:用奖励模型做强化学习

从步骤 1 的监督策略复制一份作为"训练策略",同时保留一份冻结副本叫"参考策略"。给训练策略输入提示,用奖励模型对输出打分,然后用 PPO 基于这个奖励继续微调训练策略。

在 LLM 的强化学习中,"状态"指模型到某个时刻已生成的所有 token(也就是"上下文"),"动作"是下一个要预测的 token。

PPO 训练 LLM 时会用一个"价值模型"(通常从奖励模型初始化)来估计从给定状态出发的未来总期望奖励,叫做"Value"。接着用这个 Value 计算"优势"(Advantage,使用 GAE),它衡量在某状态下采取某动作相对于训练策略期望行为的好坏程度。

PPO 更新训练策略时就用到这个 Advantage。同时价值模型也会在训练过程中不断更新,以便在每个训练步提供更好的未来总期望奖励估计。

PPO 可视化,其中 Q 为查询,O 为训练策略的输出,KL 为训练策略模型与参考模型之间的 KL 散度,R 为奖励,V 为价值,A 为优势。

从 PPO 到 GRPO

GRPO 最初由 DeepSeekMath 论文提出,现在广泛用于训练推理型 LLM。

GRPO 和 PPO 的主要区别是:GRPO 不用价值模型来估计 Advantage。

它通过对同一提示下模型生成的一组输出进行相对打分来计算 Advantage,这也就是 GRPO 中"相对"这词的来源。

PPO 关注的是某个输出是否比价值模型的期望更好。

GRPO 关注的是某个输出是否比同一提示下所有输出的平均水平更好,这个平均值就作为价值的基线或代理。

GRPO 可视化,其中 Q 为查询,O(1..G) 为训练策略的多条输出,KL 为训练策略模型与参考模型之间的 KL 散度,R(1..G) 为每条输出对应的奖励,A(1..G) 为每条输出对应的优势。

用 GRPO 训练推理 LLM

本文的的所有代码都可以在 Google Colaboratory 笔记本中完成,运行环境用的是免费层的 T4 GPU。

基础模型选择 Qwen2.5–3B-Instruct(指令微调版)。

我们用 Unsloth——一个开源 Python 库和平台,专门用来优化和加速 LLM 微调。Unsloth 的好处是你只需定义奖励和训练配置,它会在内部管理参考策略与训练策略以及所有 GPU 操作。这大大简化了 GRPO 训练流水线。

下面和 Unsloth 相关的函数参数基本都不言自明,如果第一次接触可以查阅 Unsloth 文档。

安装依赖

 import os  
os.environ["UNSLOTH_VLLM_STANDBY"] = "1" # 获取额外 30% 的上下文长度  

# 安装依赖  
!pip install unsloth_zoo  
!pip install — upgrade unsloth vllm==0.9.2 numpy torchvision bitsandbytes xformers  
!pip install triton==3.2.0  
!pip install transformers==4.55.4  
 !pip install — no-deps trl==0.22.2

加载模型和分词器,这里加载 Qwen2.5–3B-Instruct 模型及其分词器。

 from unsloth import FastLanguageModel  
import torch  

# 上下文长度  
max_seq_length = 1024  

# 加载模型与分词器  
model, tokenizer = FastLanguageModel.from_pretrained(  
    model_name = "unsloth/Qwen2.5-3B-Instruct",  
    max_seq_length = max_seq_length,  
    load_in_4bit = True, # 启用 4-bit 量化  
    fast_inference = True, # 启用 vLLM 快速推理  
    max_lora_rank = 8,  
    gpu_memory_utilization = 0.9,  
 )

用 LoRA 做参数高效微调,由于算力资源有限,我们不会训练 LLM 的全部参数,而是用 LoRA(低秩适配)来提升训练效率。

 # 使用 LoRA 进行参数高效微调  
model = FastLanguageModel.get_peft_model(  
    model,  
    r = 8,   
    # 需要微调的模块  
    target_modules = [  
        "q_proj", "k_proj", "v_proj", "o_proj",  
        "gate_proj", "up_proj", "down_proj",  
    ],   
    lora_alpha = 8,  
    use_gradient_checkpointing = "unsloth",  
    random_state = 1234,  
 )

用著名的 GSM8K 数据集(小学到初中难度的数学文字题集合)来训练模型的推理能力。下面对数据集中的题目做格式化处理,以便用于训练。

 import re  
from datasets import load_dataset, Dataset  

# 系统提示词  
SYSTEM_PROMPT = """  
Respond in the following format:  
<reasoning>  
...  
</reasoning>  
<answer>  
...  
</answer>  
"""  

# 包裹推理与答案的模板  
XML_COT_FORMAT = """\  
<reasoning>  
{reasoning}  
</reasoning>  
<answer>  
{answer}  
</answer>  
"""  

# 从模型输出中抽取 <answer>...</answer> 内文本的函数  
def extract_xml_answer(text):  
    if "<answer>" not in text or "</answer>" not in text:  
        return ""  
    return text.split("<answer>", 1)[-1].split("</answer>", 1)[0].strip()  

# 从 GSM8K 标签中抽取正确答案,标签形如 '... #### final_answer'  
def extract_hash_answer(text):  
    return text.split("####")[-1].strip() if "####" in text else None  

# 加载 GSM8K 数据集并格式化为对话式提示的函数  
def get_gsm8k_dataset(split = "train"):  
    data = load_dataset("openai/gsm8k", "main")[split]  
    return data.map(  
        lambda x: {  
            "prompt": [  
                {"role": "system", "content": SYSTEM_PROMPT},  
                {"role": "user", "content": x["question"]},  
            ],  
            "answer": extract_hash_answer(x["answer"]),  
        }  
    )  

 dataset = get_gsm8k_dataset()

下面定义用来评估推理模型训练效果的奖励函数。

 # 奖励函数:检查从补全中抽取的答案  
# 是否与给定的真实答案完全一致。  
# 一致则返回 2.0,否则返回 0.0。  
def correctness_reward_func(prompts, completions, answer, **kwargs):  
    responses = [completion[0]['content'] for completion in completions]  
    q = prompts[0][-1]['content']  
    extracted_responses = [extract_xml_answer(r) for r in responses]  
    print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")  
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]  

# 奖励函数:检查抽取的回答是否为整数。  
# 若为数字则返回 0.5,否则返回 0.0。  
def int_reward_func(completions, **kwargs):  
    responses = [completion[0]['content'] for completion in completions]  
    extracted_responses = [extract_xml_answer(r) for r in responses]  
    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]  

# 奖励函数:强约束的 XML 格式检查,  
# 要求响应必须严格匹配以下结构:  
# <reasoning>\n...\n</reasoning>\n<answer>\n...\n</answer>\n  
# 格式正确返回 0.5,否则返回 0.0。  
def strict_format_reward_func(completions, **kwargs):  
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"  
    responses = [completion[0]["content"] for completion in completions]  
    matches = [re.match(pattern, r) for r in responses]  
    return [0.5 if match else 0.0 for match in matches]  

# 奖励函数:较宽松的 XML 格式检查:  
# 响应需包含 <reasoning>...</reasoning> 与 <answer>...</answer>,  
# 但允许空格与换行的灵活性。  
# 匹配返回 0.5,否则返回 0.0。  
def soft_format_reward_func(completions, **kwargs):  
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"  
    responses = [completion[0]["content"] for completion in completions]  
    matches = [re.match(pattern, r) for r in responses]  
    return [0.5 if match else 0.0 for match in matches]  

# 辅助函数:统计并为 XML 标签计分  
def count_xml(text):  
    count = 0.0  
    if text.count("<reasoning>\n") == 1:  
        count += 0.125  
    if text.count("\n</reasoning>\n") == 1:  
        count += 0.125  
    if text.count("\n<answer>\n") == 1:  
        count += 0.125  
        count -= len(text.split("\n</answer>\n")[-1])*0.001  
    if text.count("\n</answer>") == 1:  
        count += 0.125  
        count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001  
    return count  

# 奖励函数:将 count_xml 应用于补全输出。  
# 根据标签正确性进行 XML 结构计分,并对尾部冗余内容施加惩罚。  
def xmlcount_reward_func(completions, **kwargs):  
    contents = [completion[0]["content"] for completion in completions]  
     return [count_xml(c) for c in contents]

定义 GRPO 训练器的参数。

 from trl import GRPOConfig, GRPOTrainer  

# 训练参数  
training_args = GRPOConfig(  
    use_vllm = True, # 使用 vLLM 进行快速推理  
    learning_rate = 5e-6,  
    adam_beta1 = 0.9,  
    adam_beta2 = 0.99,  
    weight_decay = 0.1,  
    warmup_ratio = 0.1,  
    lr_scheduler_type = "cosine",  
    optim = "adamw_8bit",  
    logging_steps = 1,  
    per_device_train_batch_size = 4,  
    gradient_accumulation_steps = 1,   
    num_generations = 4,   
    max_prompt_length = 256,  
    max_completion_length = 200,  
    max_steps = 250,  
    save_steps = 250,  
    max_grad_norm = 0.1,  
    report_to = "none",  
    output_dir = "outputs",  
)


# GRPO 训练器   
trainer = GRPOTrainer(  
    model = model,  
    processing_class = tokenizer,  
    reward_funcs = [  
        xmlcount_reward_func,  
        soft_format_reward_func,  
        strict_format_reward_func,  
        int_reward_func,  
        correctness_reward_func,  
    ],  
    args = training_args,  
    train_dataset = dataset,  
 )

用下面的命令启动训练。

 # 开始训练   
 trainer.train()

强化学习模型工作原理是探索解空间,所以训练通常比较慢。LLM 可能需要数百步才能学会更好的推理,这意味着你需要等几个小时才能得到不错的结果。

上图为部分训练日志

保存模型有很多方式,我们主要关心的是保存 LoRA 适配器。

 # 保存 LoRA 适配器  
 model.save_lora("grpo_saved_lora")

最后对比一下训练前后模型的输出。

from vllm import SamplingParams  

# 训练前的模型推理  
query = "How many r's are in strawberry?"  

text = tokenizer.apply_chat_template([  
    {"role" : "user", "content" : query},  
], tokenize = False, add_generation_prompt = True)  

sampling_params = SamplingParams(  
    temperature = 0.8,  
    top_p = 0.95,  
    max_tokens = 1024,  
)  

output = model.fast_generate(  
    [text],  
    sampling_params = sampling_params,  
    lora_request = None,  
)[0].outputs[0].text  

print(output)

训练前模型的输出:

There are 2 r's in the word "strawberry."

接下来试试经过 GRPO 训练的模型:

# 训练后的模型推理  
text = tokenizer.apply_chat_template([  
    {"role" : "system", "content" : SYSTEM_PROMPT},  
    {"role" : "user", "content" : query},  
], tokenize = False, add_generation_prompt = True)  

sampling_params = SamplingParams(  
    temperature = 0.8,  
    top_p = 0.95,  
    max_tokens = 1024,  
)  

output = model.fast_generate(  
    text,  
    sampling_params = sampling_params,  
    lora_request = model.load_lora("grpo_saved_lora"),  
)[0].outputs[0].text  

print(output)

训练后模型的输出:

<reasoning>  
To find out how many times the letter 'r' appears in the word "strawberry", we can go through the word character by character and count each occurrence of 'r'. In "strawberry", the letter 'r' appears 3 times: once in the beginning, once in the middle, and once at the end of the word.  
</reasoning>  
<answer>  
3  
</answer>

效果相当不错!可以看到模型现在会在回答问题前先进行推理,并且给出了正确答案。

下面是使用 GRPO 训练 Qwen 2.5(3B)训练过程的概览示意图:

本文的完整代码:

https://avoid.overfit.cn/post/1506330de8e349eab552ec1000417a27
作者:Dr. Ashish Bamania

目录
相关文章
|
4天前
|
弹性计算 关系型数据库 微服务
基于 Docker 与 Kubernetes(K3s)的微服务:阿里云生产环境扩容实践
在微服务架构中,如何实现“稳定扩容”与“成本可控”是企业面临的核心挑战。本文结合 Python FastAPI 微服务实战,详解如何基于阿里云基础设施,利用 Docker 封装服务、K3s 实现容器编排,构建生产级微服务架构。内容涵盖容器构建、集群部署、自动扩缩容、可观测性等关键环节,适配阿里云资源特性与服务生态,助力企业打造低成本、高可靠、易扩展的微服务解决方案。
1106 0
|
3天前
|
机器学习/深度学习 人工智能 前端开发
通义DeepResearch全面开源!同步分享可落地的高阶Agent构建方法论
通义研究团队开源发布通义 DeepResearch —— 首个在性能上可与 OpenAI DeepResearch 相媲美、并在多项权威基准测试中取得领先表现的全开源 Web Agent。
531 10
|
13天前
|
人工智能 运维 安全
|
12天前
|
人工智能 测试技术 API
智能体(AI Agent)搭建全攻略:从概念到实践的终极指南
在人工智能浪潮中,智能体(AI Agent)正成为变革性技术。它们具备自主决策、环境感知、任务执行等能力,广泛应用于日常任务与商业流程。本文详解智能体概念、架构及七步搭建指南,助你打造专属智能体,迎接智能自动化新时代。
|
4天前
|
弹性计算 Kubernetes jenkins
如何在 ECS/EKS 集群中有效使用 Jenkins
本文探讨了如何将 Jenkins 与 AWS ECS 和 EKS 集群集成,以构建高效、灵活且具备自动扩缩容能力的 CI/CD 流水线,提升软件交付效率并优化资源成本。
301 0
|
11天前
|
人工智能 异构计算
敬请锁定《C位面对面》,洞察通用计算如何在AI时代持续赋能企业创新,助力业务发展!
敬请锁定《C位面对面》,洞察通用计算如何在AI时代持续赋能企业创新,助力业务发展!
|
12天前
|
机器学习/深度学习 人工智能 自然语言处理
B站开源IndexTTS2,用极致表现力颠覆听觉体验
在语音合成技术不断演进的背景下,早期版本的IndexTTS虽然在多场景应用中展现出良好的表现,但在情感表达的细腻度与时长控制的精准性方面仍存在提升空间。为了解决这些问题,并进一步推动零样本语音合成在实际场景中的落地能力,B站语音团队对模型架构与训练策略进行了深度优化,推出了全新一代语音合成模型——IndexTTS2 。
807 23
|
4天前
|
缓存 供应链 监控
VVIC seller_search 排行榜搜索接口深度分析及 Python 实现
VVIC搜款网seller_search接口提供服装批发市场的商品及商家排行榜数据,涵盖热销榜、销量排名、类目趋势等,支持多维度筛选与数据分析,助力选品决策、竞品分析与市场预测,为服装供应链提供有力数据支撑。
|
4天前
|
缓存 监控 API
Amazon item_review 商品评论接口深度分析及 Python 实现
亚马逊商品评论接口(item_review)可获取用户评分、评论内容及时间等数据,支持多维度筛选与分页调用,结合Python实现情感分析、关键词提取与可视化,助力竞品分析、产品优化与市场决策。