大模型进阶微调篇(三):微调GPT2大模型实战

本文涉及的产品
模型在线服务 PAI-EAS,A10/V100等 500元 1个月
交互式建模 PAI-DSW,每月250计算时 3个月
模型训练 PAI-DLC,100CU*H 3个月
简介: 本文详细介绍了如何在普通个人电脑上微调GPT2大模型,包括环境配置、代码实现和技术要点。通过合理设置训练参数和优化代码,即使在无独显的设备上也能完成微调,耗时约14小时。文章还涵盖了GPT-2的简介、数据集处理、自定义进度条回调等内容,适合初学者参考。

在之前的两篇文章:基于人类反馈的强化学习RLHF原理、优点介绍 以定制化3B模型为例,各种微调方法对比-选LoRA还是PPO 介绍了一下微调相关的知识

在本文中,我带大家在一般设备上微调GPT2大模型(例如无GPU的ECS或者个人电脑) - qwen2.5-0.5B也可以,但时长需要80个小时左右对于集显太慢了。尽管大模型的训练通常需要强大的计算资源,但借助合理的配置和代码优化,可以在普通的个人电脑上对GPT2进行微调。
本次微调是在一台i5-9500 3GHz CPU,32GB内存,无独立显卡的计算资源上进行,训练时长大约为14个小时。本文将从环境配置、代码拆解、以及技术要点等方面进行详尽介绍。

GPT-2 简介

GPT-2 是由 OpenAI 开发的生成式预训练 Transformer(GPT)模型,是 GPT 系列模型的第二代版本。GPT-2 采用了 12 层 Transformer 编码器,拥有 1.5 亿参数。该模型基于大量文本数据进行无监督训练,具备了生成高质量自然语言的能力。GPT-2 适合多种下游任务,包括文本生成、对话系统和文本分类等。

环境配置与依赖安装

首先,需要安装以下主要依赖项(python3.8,建议用miniconda):

  • torch:用于深度学习模型的训练与推理。
  • transformers:用于加载 GPT-2 模型和处理相关任务。
  • datasets:用于加载和处理数据集。
  • tqdm:用于显示进度条。

通过以下命令来安装这些依赖:

pip install torch transformers datasets
AI 代码解读

代码拆解与说明

下面是完整的微调代码,将逐步对其进行拆解和说明。

1. 加载 GPT-2 模型和分词器

from transformers import GPT2Tokenizer, GPT2LMHeadModel

model_name = "openai-community/gpt2"
model = GPT2LMHeadModel.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
AI 代码解读

首先,使用 GPT2LMHeadModel 加载 GPT-2 预训练模型,并使用 GPT2Tokenizer 加载对应的分词器。这里将 pad_token 设置为 eos_token,以便模型在处理填充时不会引入额外的混淆。

2. 加载数据集

from datasets import load_dataset

dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
AI 代码解读

使用 datasets 库加载 Wikitext-2 数据集。这个数据集包含了丰富的自然语言文本,非常适合用于语言模型的微调。

3. 数据集分词

def tokenize_function(examples):
    tokens = tokenizer(examples["text"], padding="max_length", truncation=True, max_length=128)
    tokens["labels"] = tokens["input_ids"].copy() #这个地儿是为了兼容,不然会报错无labels呢
    return tokens

tokenized_datasets = dataset.map(tokenize_function, batched=True, remove_columns=["text"])
AI 代码解读

为了适应 GPT-2 的输入,定义了 tokenize_function 对数据集进行分词,并将 input_ids 复制为 labels,使得模型在训练时可以通过对比输入和标签进行自回归学习。

4. 设置训练参数

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./gpt2-finetuned",
    overwrite_output_dir=True,
    num_train_epochs=3,
    per_device_train_batch_size=8,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=50,
    save_steps=1000,
    save_total_limit=2,
    eval_strategy="steps",
    report_to="none"
)
AI 代码解读

这里定义了模型训练的参数。主要的参数有:

  • output_dir:模型和检查点保存的目录。
  • overwrite_output_dir:是否覆盖已有的输出目录。
  • num_train_epochs:训练轮数,设置为3,表示模型将在整个数据集上迭代3次。
  • per_device_train_batch_size:每个设备上的训练批次大小为8,这会影响显存的占用和训练速度。
  • warmup_steps:学习率预热的步数,设置为500,用于在训练初期逐步提高学习率,避免模型收敛过快。
  • weight_decay:权重衰减系数,用于防止过拟合,设置为0.01。
  • logging_dir:日志保存的目录,用于保存 TensorBoard 等日志信息。
  • logging_steps:每50步记录一次日志,方便跟踪训练的进度和损失变化。
  • save_steps:每1000步保存一次模型,用于断点续训或选择最佳的训练结果。
  • save_total_limit:保存模型的最大数量,超过此数量后会自动删除旧的模型检查点,保持存储空间。
  • eval_strategy:设置为在训练过程中按步评估模型,以便及时了解模型的性能。
  • report_to:设置为 "none",表示不将训练日志报告给外部工具,如 TensorBoard。

5. 自定义进度条回调

为了更好地监控训练过程,定义了一个自定义进度条回调 ProgressCallback,使用 tqdm 显示训练进度:

from transformers import Trainer, TrainerCallback
from tqdm import tqdm

class ProgressCallback(TrainerCallback):
    def __init__(self, total_steps):
        self.progress_bar = tqdm(total=total_steps, desc="Training Progress")

    def on_step_end(self, args, state, control, **kwargs):
        self.progress_bar.update(1)

    def on_train_end(self, args, state, control, **kwargs):
        self.progress_bar.close()
AI 代码解读

6. 创建 Trainer

total_steps = len(tokenized_datasets) // training_args.per_device_train_batch_size * training_args.num_train_epochs

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets,
    tokenizer=tokenizer,
    callbacks=[ProgressCallback(total_steps)]
)
AI 代码解读

使用 Trainer 类来管理训练过程,并添加自定义的进度条回调,以便可以实时查看训练进度。

7. 开始微调

trainer.train()
AI 代码解读

启动微调过程,Trainer 会自动处理训练循环,包括数据加载、前向传播、反向传播和梯度更新等。

8. 保存模型和评估

model.save_pretrained("./gpt2-finetuned")
tokenizer.save_pretrained("./gpt2-finetuned")

# 评估模型
eval_results = trainer.evaluate()
print(f"Perplexity: {torch.exp(torch.tensor(eval_results['eval_loss']))}")
AI 代码解读

训练完成后,将模型和分词器保存到指定目录。随后,通过 trainer.evaluate() 对模型进行评估,并计算困惑度(Perplexity),以评估模型生成文本的质量。

微调相关技术要点

  1. 微调原理:GPT-2 的微调基于预训练权重,使用下游任务数据继续训练,使模型在特定领域上具有更好的表现。
  2. 自回归训练:GPT-2 使用自回归的方式,即通过预测下一个词的方式学习文本的生成。
  3. 训练资源:由于 GPT-2 模型参数较多,训练过程通常较为耗时。在个人电脑上进行微调时,建议尽量使用较小的数据集,并适当减少训练轮数和批次大小。

总结

通过本文的实战演示,我成功地在一台普通的个人电脑上对 GPT-2 进行了微调。虽然设备性能有限,但借助优化的训练参数和合理的代码结构,依然能够完成大模型的微调任务。希望这篇文章能帮助你更好地理解和实践大模型的微调技术,让 AI 模型训练变得更加触手可及。

相关实践学习
使用PAI-EAS一键部署ChatGLM及LangChain应用
本场景中主要介绍如何使用模型在线服务(PAI-EAS)部署ChatGLM的AI-Web应用以及启动WebUI进行模型推理,并通过LangChain集成自己的业务数据。
机器学习概览及常见算法
机器学习(Machine Learning, ML)是人工智能的核心,专门研究计算机怎样模拟或实现人类的学习行为,以获取新的知识或技能,重新组织已有的知识结构使之不断改善自身的性能,它是使计算机具有智能的根本途径,其应用遍及人工智能的各个领域。 本课程将带你入门机器学习,掌握机器学习的概念和常用的算法。
目录
打赏
0
7
6
1
58
分享
相关文章
智能体(AI Agent)开发实战之【LangChain】(二)结合大模型基于RAG实现本地知识库问答
智能体(AI Agent)开发实战之【LangChain】(二)结合大模型基于RAG实现本地知识库问答
魔塔社区-微调Qwen3-1.7B大模型实战
这是一篇关于模型微调实战的教程,主要步骤如下:1. 使用魔塔社区提供的GPU环境;2. 处理 delicate_medical_r1_data 数据集生成训练和验证文件;3. 加载Modelscope上的Qwen3-1.7B模型;4. 注册并使用Swanlab工具配置API;5. 按顺序执行完整代码完成微调设置;6. 展示训练过程。完整代码与实验记录分别托管于魔塔社区和SwanLab平台,方便复现与学习。
242 1
通义大模型与现有企业系统集成实战《CRM案例分析与安全最佳实践》
本文档详细介绍了基于通义大模型的CRM系统集成架构设计与优化实践。涵盖混合部署架构演进(新增向量缓存、双通道同步)、性能基准测试对比、客户意图分析模块、商机预测系统等核心功能实现。同时,深入探讨了安全防护体系、三级缓存架构、请求批处理优化及故障处理机制,并展示了实时客户画像生成和动态提示词工程。通过实施,显著提升客服响应速度(425%)、商机识别准确率(37%)及客户满意度(15%)。最后,规划了技术演进路线图,从单点集成迈向自主优化阶段,推动业务效率与价值持续增长。
基于通义大模型的智能客服系统构建实战:从模型微调到API部署
本文详细解析了基于通义大模型的智能客服系统构建全流程,涵盖数据准备、模型微调、性能优化及API部署等关键环节。通过实战案例与代码演示,展示了如何针对客服场景优化训练数据、高效微调大模型、解决部署中的延迟与并发问题,以及构建完整的API服务与监控体系。文章还探讨了性能优化进阶技术,如模型量化压缩和缓存策略,并提供了安全与合规实践建议。最终总结显示,微调后模型意图识别准确率提升14.3%,QPS从12.3提升至86.7,延迟降低74%。
187 13
Kaggle金牌方案复现:CGO-Transformer-GRU多模态融合预测实战
本文详细介绍了在2023年Kaggle "Global Multimodal Demand Forecasting Challenge"中夺冠的**CGO-Transformer-GRU**方案。该方案通过融合协方差引导优化(CGO)、注意力机制和时序建模技术,解决了多模态数据预测中的核心挑战,包括异构数据对齐、模态动态变化及长短期依赖建模。方案创新性地提出了动态门控机制、混合架构和梯度平衡算法,并在公开数据集TMU-MDFD上取得了RMSE 7.83的优异成绩,领先亚军12.6%。
Java 生态大模型应用开发全流程实战案例与技术路径终极对决
在Java生态中开发大模型应用,Spring AI、LangChain4j和JBoltAI是三大主流框架。本文从架构设计、核心功能、开发体验、性能扩展性、生态社区等维度对比三者特点,并结合实例分析选型建议。Spring AI适合已有Spring技术栈团队,LangChain4j灵活性强适用于学术研究,JBoltAI提供开箱即用的企业级解决方案,助力传统系统快速AI化改造。开发者可根据业务场景和技术背景选择最适合的框架。
108 2
JBoltAI 框架完整实操案例 在 Java 生态中快速构建大模型应用全流程实战指南
本案例基于JBoltAI框架,展示如何快速构建Java生态中的大模型应用——智能客服系统。系统面向电商平台,具备自动回答常见问题、意图识别、多轮对话理解及复杂问题转接人工等功能。采用Spring Boot+JBoltAI架构,集成向量数据库与大模型(如文心一言或通义千问)。内容涵盖需求分析、环境搭建、代码实现(知识库管理、核心服务、REST API)、前端界面开发及部署测试全流程,助你高效掌握大模型应用开发。
112 5
智能体(AI Agent)开发实战之【LangChain】(一)接入大模型输出结果
LangChain 是一个开源框架,专为构建与大语言模型(LLMs)相关的应用设计。通过集成多个 API、数据源和工具,助力开发者高效构建智能应用。本文介绍了 LangChain 的环境准备(如安装 LangChain、OpenAI 及国内 DeepSeek 等库)、代码实现(以国内开源大模型 Qwen 为例,展示接入及输出结果的全流程),以及核心参数配置说明。LangChain 的灵活性和强大功能使其成为开发对话式智能应用的理想选择。
小米又放大招!MiMo-VL 多模态大模型开源,魔搭推理微调全面解读来了!
今天,小米开源发布两款 7B 规模视觉-语言模型 MiMo-VL-7B-SFT 和 MiMo-VL-7B-RL。
292 9

热门文章

最新文章

AI助理

你好,我是AI助理

可以解答问题、推荐解决方案等