从零开始训练推理模型:GRPO+Unsloth改造Qwen实战指南

本文涉及的产品
实时数仓Hologres,5000CU*H 100GB 3个月
实时计算 Flink 版,1000CU*H 3个月
智能开放搜索 OpenSearch行业算法版,1GB 20LCU 1个月
简介: 推理型大语言模型兴起,通过先思考再作答提升性能。本文介绍GRPO等强化学习算法,详解其原理并动手用Qwen2.5-3B训练推理模型,展示训练前后效果对比,揭示思维链生成的实现路径。

推理型大语言模型现在确实火了。这类模型的特点是会先对问题做充分思考,然后再给出答案,而不是直接回复。

虽然早期训练推理型 LLM 的方法多半被各家公司当作核心机密,但最近的DeepSeek-R1、DeepSeekMath、Kimi-k1.5 和 DAPO 这些项目都公开了相关流程。

这些方法让 LLM 在推理过程中生成更长的思维链(Chain-of-Thought,CoT)输出,推理效果因此得到提升。同时它们还引入了改进的强化学习算法,比如 GRPO 和 DAPO,这些算法是对 OpenAI 最初 PPO 方法的高效升级。

这篇文章会先介绍 GRPO(Group Relative Policy Optimization,组相对策略优化)的基本概念,这是目前训练推理型 LLM 最常用的强化学习算法之一。然后我们会动手写代码训练一个推理 LLM,在实践中理解整个流程。

RLHF 与 PPO 的简单回顾

想理解 GRPO,我们得先回到最初用来对齐 LLM 的强化学习算法。

这个算法叫 Proximal Policy Optimization(PPO,近端策略优化),用它将 LLM 对齐到人类偏好的过程叫做 Reinforcement Learning From Human Feedback(RLHF,基于人类反馈的强化学习)。

LLM 对齐与 RLHF 可视化

RLHF 主要包含三个步骤:

步骤 1:训练监督微调策略

从预训练的 LLM 开始,在包含提示与人工编写回答的数据集上做微调。这样得到的模型在 RL 术语中叫"监督策略",它能针对给定提示生成更符合人类偏好的回答。

步骤 2:训练奖励模型

为每个提示收集多个模型输出,让人工标注者对这些输出排序,判断哪个更好。然后用这些数据训练一个"奖励模型",它会对给定输出返回一个标量分数作为人类偏好的代理。

步骤 3:用奖励模型做强化学习

从步骤 1 的监督策略复制一份作为"训练策略",同时保留一份冻结副本叫"参考策略"。给训练策略输入提示,用奖励模型对输出打分,然后用 PPO 基于这个奖励继续微调训练策略。

在 LLM 的强化学习中,"状态"指模型到某个时刻已生成的所有 token(也就是"上下文"),"动作"是下一个要预测的 token。

PPO 训练 LLM 时会用一个"价值模型"(通常从奖励模型初始化)来估计从给定状态出发的未来总期望奖励,叫做"Value"。接着用这个 Value 计算"优势"(Advantage,使用 GAE),它衡量在某状态下采取某动作相对于训练策略期望行为的好坏程度。

PPO 更新训练策略时就用到这个 Advantage。同时价值模型也会在训练过程中不断更新,以便在每个训练步提供更好的未来总期望奖励估计。

PPO 可视化,其中 Q 为查询,O 为训练策略的输出,KL 为训练策略模型与参考模型之间的 KL 散度,R 为奖励,V 为价值,A 为优势。

从 PPO 到 GRPO

GRPO 最初由 DeepSeekMath 论文提出,现在广泛用于训练推理型 LLM。

GRPO 和 PPO 的主要区别是:GRPO 不用价值模型来估计 Advantage。

它通过对同一提示下模型生成的一组输出进行相对打分来计算 Advantage,这也就是 GRPO 中"相对"这词的来源。

PPO 关注的是某个输出是否比价值模型的期望更好。

GRPO 关注的是某个输出是否比同一提示下所有输出的平均水平更好,这个平均值就作为价值的基线或代理。

GRPO 可视化,其中 Q 为查询,O(1..G) 为训练策略的多条输出,KL 为训练策略模型与参考模型之间的 KL 散度,R(1..G) 为每条输出对应的奖励,A(1..G) 为每条输出对应的优势。

用 GRPO 训练推理 LLM

本文的的所有代码都可以在 Google Colaboratory 笔记本中完成,运行环境用的是免费层的 T4 GPU。

基础模型选择 Qwen2.5–3B-Instruct(指令微调版)。

我们用 Unsloth——一个开源 Python 库和平台,专门用来优化和加速 LLM 微调。Unsloth 的好处是你只需定义奖励和训练配置,它会在内部管理参考策略与训练策略以及所有 GPU 操作。这大大简化了 GRPO 训练流水线。

下面和 Unsloth 相关的函数参数基本都不言自明,如果第一次接触可以查阅 Unsloth 文档。

安装依赖

 import os  
os.environ["UNSLOTH_VLLM_STANDBY"] = "1" # 获取额外 30% 的上下文长度  

# 安装依赖  
!pip install unsloth_zoo  
!pip install — upgrade unsloth vllm==0.9.2 numpy torchvision bitsandbytes xformers  
!pip install triton==3.2.0  
!pip install transformers==4.55.4  
 !pip install — no-deps trl==0.22.2

加载模型和分词器,这里加载 Qwen2.5–3B-Instruct 模型及其分词器。

 from unsloth import FastLanguageModel  
import torch  

# 上下文长度  
max_seq_length = 1024  

# 加载模型与分词器  
model, tokenizer = FastLanguageModel.from_pretrained(  
    model_name = "unsloth/Qwen2.5-3B-Instruct",  
    max_seq_length = max_seq_length,  
    load_in_4bit = True, # 启用 4-bit 量化  
    fast_inference = True, # 启用 vLLM 快速推理  
    max_lora_rank = 8,  
    gpu_memory_utilization = 0.9,  
 )

用 LoRA 做参数高效微调,由于算力资源有限,我们不会训练 LLM 的全部参数,而是用 LoRA(低秩适配)来提升训练效率。

 # 使用 LoRA 进行参数高效微调  
model = FastLanguageModel.get_peft_model(  
    model,  
    r = 8,   
    # 需要微调的模块  
    target_modules = [  
        "q_proj", "k_proj", "v_proj", "o_proj",  
        "gate_proj", "up_proj", "down_proj",  
    ],   
    lora_alpha = 8,  
    use_gradient_checkpointing = "unsloth",  
    random_state = 1234,  
 )

用著名的 GSM8K 数据集(小学到初中难度的数学文字题集合)来训练模型的推理能力。下面对数据集中的题目做格式化处理,以便用于训练。

 import re  
from datasets import load_dataset, Dataset  

# 系统提示词  
SYSTEM_PROMPT = """  
Respond in the following format:  
<reasoning>  
...  
</reasoning>  
<answer>  
...  
</answer>  
"""  

# 包裹推理与答案的模板  
XML_COT_FORMAT = """\  
<reasoning>  
{reasoning}  
</reasoning>  
<answer>  
{answer}  
</answer>  
"""  

# 从模型输出中抽取 <answer>...</answer> 内文本的函数  
def extract_xml_answer(text):  
    if "<answer>" not in text or "</answer>" not in text:  
        return ""  
    return text.split("<answer>", 1)[-1].split("</answer>", 1)[0].strip()  

# 从 GSM8K 标签中抽取正确答案,标签形如 '... #### final_answer'  
def extract_hash_answer(text):  
    return text.split("####")[-1].strip() if "####" in text else None  

# 加载 GSM8K 数据集并格式化为对话式提示的函数  
def get_gsm8k_dataset(split = "train"):  
    data = load_dataset("openai/gsm8k", "main")[split]  
    return data.map(  
        lambda x: {  
            "prompt": [  
                {"role": "system", "content": SYSTEM_PROMPT},  
                {"role": "user", "content": x["question"]},  
            ],  
            "answer": extract_hash_answer(x["answer"]),  
        }  
    )  

 dataset = get_gsm8k_dataset()

下面定义用来评估推理模型训练效果的奖励函数。

 # 奖励函数:检查从补全中抽取的答案  
# 是否与给定的真实答案完全一致。  
# 一致则返回 2.0,否则返回 0.0。  
def correctness_reward_func(prompts, completions, answer, **kwargs):  
    responses = [completion[0]['content'] for completion in completions]  
    q = prompts[0][-1]['content']  
    extracted_responses = [extract_xml_answer(r) for r in responses]  
    print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")  
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]  

# 奖励函数:检查抽取的回答是否为整数。  
# 若为数字则返回 0.5,否则返回 0.0。  
def int_reward_func(completions, **kwargs):  
    responses = [completion[0]['content'] for completion in completions]  
    extracted_responses = [extract_xml_answer(r) for r in responses]  
    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]  

# 奖励函数:强约束的 XML 格式检查,  
# 要求响应必须严格匹配以下结构:  
# <reasoning>\n...\n</reasoning>\n<answer>\n...\n</answer>\n  
# 格式正确返回 0.5,否则返回 0.0。  
def strict_format_reward_func(completions, **kwargs):  
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"  
    responses = [completion[0]["content"] for completion in completions]  
    matches = [re.match(pattern, r) for r in responses]  
    return [0.5 if match else 0.0 for match in matches]  

# 奖励函数:较宽松的 XML 格式检查:  
# 响应需包含 <reasoning>...</reasoning> 与 <answer>...</answer>,  
# 但允许空格与换行的灵活性。  
# 匹配返回 0.5,否则返回 0.0。  
def soft_format_reward_func(completions, **kwargs):  
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"  
    responses = [completion[0]["content"] for completion in completions]  
    matches = [re.match(pattern, r) for r in responses]  
    return [0.5 if match else 0.0 for match in matches]  

# 辅助函数:统计并为 XML 标签计分  
def count_xml(text):  
    count = 0.0  
    if text.count("<reasoning>\n") == 1:  
        count += 0.125  
    if text.count("\n</reasoning>\n") == 1:  
        count += 0.125  
    if text.count("\n<answer>\n") == 1:  
        count += 0.125  
        count -= len(text.split("\n</answer>\n")[-1])*0.001  
    if text.count("\n</answer>") == 1:  
        count += 0.125  
        count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001  
    return count  

# 奖励函数:将 count_xml 应用于补全输出。  
# 根据标签正确性进行 XML 结构计分,并对尾部冗余内容施加惩罚。  
def xmlcount_reward_func(completions, **kwargs):  
    contents = [completion[0]["content"] for completion in completions]  
     return [count_xml(c) for c in contents]

定义 GRPO 训练器的参数。

 from trl import GRPOConfig, GRPOTrainer  

# 训练参数  
training_args = GRPOConfig(  
    use_vllm = True, # 使用 vLLM 进行快速推理  
    learning_rate = 5e-6,  
    adam_beta1 = 0.9,  
    adam_beta2 = 0.99,  
    weight_decay = 0.1,  
    warmup_ratio = 0.1,  
    lr_scheduler_type = "cosine",  
    optim = "adamw_8bit",  
    logging_steps = 1,  
    per_device_train_batch_size = 4,  
    gradient_accumulation_steps = 1,   
    num_generations = 4,   
    max_prompt_length = 256,  
    max_completion_length = 200,  
    max_steps = 250,  
    save_steps = 250,  
    max_grad_norm = 0.1,  
    report_to = "none",  
    output_dir = "outputs",  
)


# GRPO 训练器   
trainer = GRPOTrainer(  
    model = model,  
    processing_class = tokenizer,  
    reward_funcs = [  
        xmlcount_reward_func,  
        soft_format_reward_func,  
        strict_format_reward_func,  
        int_reward_func,  
        correctness_reward_func,  
    ],  
    args = training_args,  
    train_dataset = dataset,  
 )

用下面的命令启动训练。

 # 开始训练   
 trainer.train()

强化学习模型工作原理是探索解空间,所以训练通常比较慢。LLM 可能需要数百步才能学会更好的推理,这意味着你需要等几个小时才能得到不错的结果。

上图为部分训练日志

保存模型有很多方式,我们主要关心的是保存 LoRA 适配器。

 # 保存 LoRA 适配器  
 model.save_lora("grpo_saved_lora")

最后对比一下训练前后模型的输出。

from vllm import SamplingParams  

# 训练前的模型推理  
query = "How many r's are in strawberry?"  

text = tokenizer.apply_chat_template([  
    {"role" : "user", "content" : query},  
], tokenize = False, add_generation_prompt = True)  

sampling_params = SamplingParams(  
    temperature = 0.8,  
    top_p = 0.95,  
    max_tokens = 1024,  
)  

output = model.fast_generate(  
    [text],  
    sampling_params = sampling_params,  
    lora_request = None,  
)[0].outputs[0].text  

print(output)

训练前模型的输出:

There are 2 r's in the word "strawberry."

接下来试试经过 GRPO 训练的模型:

# 训练后的模型推理  
text = tokenizer.apply_chat_template([  
    {"role" : "system", "content" : SYSTEM_PROMPT},  
    {"role" : "user", "content" : query},  
], tokenize = False, add_generation_prompt = True)  

sampling_params = SamplingParams(  
    temperature = 0.8,  
    top_p = 0.95,  
    max_tokens = 1024,  
)  

output = model.fast_generate(  
    text,  
    sampling_params = sampling_params,  
    lora_request = model.load_lora("grpo_saved_lora"),  
)[0].outputs[0].text  

print(output)

训练后模型的输出:

<reasoning>  
To find out how many times the letter 'r' appears in the word "strawberry", we can go through the word character by character and count each occurrence of 'r'. In "strawberry", the letter 'r' appears 3 times: once in the beginning, once in the middle, and once at the end of the word.  
</reasoning>  
<answer>  
3  
</answer>

效果相当不错!可以看到模型现在会在回答问题前先进行推理,并且给出了正确答案。

下面是使用 GRPO 训练 Qwen 2.5(3B)训练过程的概览示意图:

本文的完整代码:

https://avoid.overfit.cn/post/1506330de8e349eab552ec1000417a27
作者:Dr. Ashish Bamania

目录
相关文章
|
23天前
|
负载均衡 测试技术 调度
大模型分布式推理:张量并行与流水线并行技术
本文深入探讨大语言模型分布式推理的核心技术——张量并行与流水线并行。通过分析单GPU内存限制下的模型部署挑战,详细解析张量并行的矩阵分片策略、流水线并行的阶段划分机制,以及二者的混合并行架构。文章包含完整的分布式推理框架实现、通信优化策略和性能调优指南,为千亿参数大模型的分布式部署提供全面解决方案。
368 4
|
1月前
|
机器学习/深度学习 缓存 监控
大模型推理优化技术:KV缓存机制详解
本文深入探讨了大语言模型推理过程中的关键技术——KV缓存(Key-Value Cache)机制。通过对Transformer自注意力机制的分析,阐述了KV缓存的工作原理、实现方式及其对推理性能的显著优化效果。文章包含具体的代码实现和性能对比数据,为开发者理解和应用这一关键技术提供实践指导。
683 8
|
28天前
|
人工智能 搜索推荐 程序员
当AI学会“跨界思考”:多模态模型如何重塑人工智能
当AI学会“跨界思考”:多模态模型如何重塑人工智能
235 120
|
1月前
|
机器学习/深度学习 缓存 自然语言处理
【万字长文】大模型训练推理和性能优化算法总结和实践
我们是阿里云公共云 AI 汽车行业大模型技术团队,致力于通过专业的全栈 AI 技术推动 AI 的落地应用。
1090 38
【万字长文】大模型训练推理和性能优化算法总结和实践
|
1月前
|
存储 监控 算法
1688 图片搜索逆向实战:CLIP 多模态融合与特征向量落地方案
本文分享基于CLIP模型与逆向工程实现1688图片搜同款的实战方案。通过抓包分析破解接口签名,结合CLIP多模态特征提取与Faiss向量检索,提升搜索准确率至91%,单次响应低于80ms,日均选品效率提升4倍,全程合规可复现。
|
1月前
|
机器学习/深度学习 存储 并行计算
大模型推理加速技术:FlashAttention原理与实现
本文深入解析大语言模型推理加速的核心技术——FlashAttention。通过分析传统注意力机制的计算瓶颈,详细阐述FlashAttention的IO感知算法设计、前向反向传播实现,以及其在GPU内存层次结构中的优化策略。文章包含完整的CUDA实现示例、性能基准测试和实际部署指南,为开发者提供高效注意力计算的全套解决方案。
266 10
|
27天前
|
缓存 物联网 PyTorch
使用TensorRT LLM构建和运行Qwen模型
本文档介绍如何在单GPU和单节点多GPU上使用TensorRT LLM构建和运行Qwen模型,涵盖模型转换、引擎构建、量化推理及LoRA微调等操作,并提供详细的代码示例与支持矩阵。
286 2
|
1月前
|
机器学习/深度学习 存储 缓存
大模型推理加速技术:PagedAttention原理与实现
本文深入解析大语言模型推理中的革命性技术——PagedAttention,该技术是vLLM推理引擎的核心创新。通过将操作系统中的虚拟内存分页概念引入注意力机制,PagedAttention有效解决了KV缓存的内存碎片问题,实现了近乎零浪费的KV缓存管理。文章详细阐述其原理、内存管理机制、实现细节,并提供完整的代码示例和性能分析。
196 1

热门文章

最新文章