大模型进阶微调篇(三):微调GPT2大模型实战

本文涉及的产品
交互式建模 PAI-DSW,每月250计算时 3个月
模型在线服务 PAI-EAS,A10/V100等 500元 1个月
模型训练 PAI-DLC,100CU*H 3个月
简介: 本文详细介绍了如何在普通个人电脑上微调GPT2大模型,包括环境配置、代码实现和技术要点。通过合理设置训练参数和优化代码,即使在无独显的设备上也能完成微调,耗时约14小时。文章还涵盖了GPT-2的简介、数据集处理、自定义进度条回调等内容,适合初学者参考。

在之前的两篇文章:基于人类反馈的强化学习RLHF原理、优点介绍 以定制化3B模型为例,各种微调方法对比-选LoRA还是PPO 介绍了一下微调相关的知识

在本文中,我带大家在一般设备上微调GPT2大模型(例如无GPU的ECS或者个人电脑) - qwen2.5-0.5B也可以,但时长需要80个小时左右对于集显太慢了。尽管大模型的训练通常需要强大的计算资源,但借助合理的配置和代码优化,可以在普通的个人电脑上对GPT2进行微调。
本次微调是在一台i5-9500 3GHz CPU,32GB内存,无独立显卡的计算资源上进行,训练时长大约为14个小时。本文将从环境配置、代码拆解、以及技术要点等方面进行详尽介绍。

GPT-2 简介

GPT-2 是由 OpenAI 开发的生成式预训练 Transformer(GPT)模型,是 GPT 系列模型的第二代版本。GPT-2 采用了 12 层 Transformer 编码器,拥有 1.5 亿参数。该模型基于大量文本数据进行无监督训练,具备了生成高质量自然语言的能力。GPT-2 适合多种下游任务,包括文本生成、对话系统和文本分类等。

环境配置与依赖安装

首先,需要安装以下主要依赖项(python3.8,建议用miniconda):

  • torch:用于深度学习模型的训练与推理。
  • transformers:用于加载 GPT-2 模型和处理相关任务。
  • datasets:用于加载和处理数据集。
  • tqdm:用于显示进度条。

通过以下命令来安装这些依赖:

pip install torch transformers datasets

代码拆解与说明

下面是完整的微调代码,将逐步对其进行拆解和说明。

1. 加载 GPT-2 模型和分词器

from transformers import GPT2Tokenizer, GPT2LMHeadModel

model_name = "openai-community/gpt2"
model = GPT2LMHeadModel.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

首先,使用 GPT2LMHeadModel 加载 GPT-2 预训练模型,并使用 GPT2Tokenizer 加载对应的分词器。这里将 pad_token 设置为 eos_token,以便模型在处理填充时不会引入额外的混淆。

2. 加载数据集

from datasets import load_dataset

dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")

使用 datasets 库加载 Wikitext-2 数据集。这个数据集包含了丰富的自然语言文本,非常适合用于语言模型的微调。

3. 数据集分词

def tokenize_function(examples):
    tokens = tokenizer(examples["text"], padding="max_length", truncation=True, max_length=128)
    tokens["labels"] = tokens["input_ids"].copy() #这个地儿是为了兼容,不然会报错无labels呢
    return tokens

tokenized_datasets = dataset.map(tokenize_function, batched=True, remove_columns=["text"])

为了适应 GPT-2 的输入,定义了 tokenize_function 对数据集进行分词,并将 input_ids 复制为 labels,使得模型在训练时可以通过对比输入和标签进行自回归学习。

4. 设置训练参数

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./gpt2-finetuned",
    overwrite_output_dir=True,
    num_train_epochs=3,
    per_device_train_batch_size=8,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=50,
    save_steps=1000,
    save_total_limit=2,
    eval_strategy="steps",
    report_to="none"
)

这里定义了模型训练的参数。主要的参数有:

  • output_dir:模型和检查点保存的目录。
  • overwrite_output_dir:是否覆盖已有的输出目录。
  • num_train_epochs:训练轮数,设置为3,表示模型将在整个数据集上迭代3次。
  • per_device_train_batch_size:每个设备上的训练批次大小为8,这会影响显存的占用和训练速度。
  • warmup_steps:学习率预热的步数,设置为500,用于在训练初期逐步提高学习率,避免模型收敛过快。
  • weight_decay:权重衰减系数,用于防止过拟合,设置为0.01。
  • logging_dir:日志保存的目录,用于保存 TensorBoard 等日志信息。
  • logging_steps:每50步记录一次日志,方便跟踪训练的进度和损失变化。
  • save_steps:每1000步保存一次模型,用于断点续训或选择最佳的训练结果。
  • save_total_limit:保存模型的最大数量,超过此数量后会自动删除旧的模型检查点,保持存储空间。
  • eval_strategy:设置为在训练过程中按步评估模型,以便及时了解模型的性能。
  • report_to:设置为 "none",表示不将训练日志报告给外部工具,如 TensorBoard。

5. 自定义进度条回调

为了更好地监控训练过程,定义了一个自定义进度条回调 ProgressCallback,使用 tqdm 显示训练进度:

from transformers import Trainer, TrainerCallback
from tqdm import tqdm

class ProgressCallback(TrainerCallback):
    def __init__(self, total_steps):
        self.progress_bar = tqdm(total=total_steps, desc="Training Progress")

    def on_step_end(self, args, state, control, **kwargs):
        self.progress_bar.update(1)

    def on_train_end(self, args, state, control, **kwargs):
        self.progress_bar.close()

6. 创建 Trainer

total_steps = len(tokenized_datasets) // training_args.per_device_train_batch_size * training_args.num_train_epochs

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets,
    tokenizer=tokenizer,
    callbacks=[ProgressCallback(total_steps)]
)

使用 Trainer 类来管理训练过程,并添加自定义的进度条回调,以便可以实时查看训练进度。

7. 开始微调

trainer.train()

启动微调过程,Trainer 会自动处理训练循环,包括数据加载、前向传播、反向传播和梯度更新等。

8. 保存模型和评估

model.save_pretrained("./gpt2-finetuned")
tokenizer.save_pretrained("./gpt2-finetuned")

# 评估模型
eval_results = trainer.evaluate()
print(f"Perplexity: {torch.exp(torch.tensor(eval_results['eval_loss']))}")

训练完成后,将模型和分词器保存到指定目录。随后,通过 trainer.evaluate() 对模型进行评估,并计算困惑度(Perplexity),以评估模型生成文本的质量。

微调相关技术要点

  1. 微调原理:GPT-2 的微调基于预训练权重,使用下游任务数据继续训练,使模型在特定领域上具有更好的表现。
  2. 自回归训练:GPT-2 使用自回归的方式,即通过预测下一个词的方式学习文本的生成。
  3. 训练资源:由于 GPT-2 模型参数较多,训练过程通常较为耗时。在个人电脑上进行微调时,建议尽量使用较小的数据集,并适当减少训练轮数和批次大小。

总结

通过本文的实战演示,我成功地在一台普通的个人电脑上对 GPT-2 进行了微调。虽然设备性能有限,但借助优化的训练参数和合理的代码结构,依然能够完成大模型的微调任务。希望这篇文章能帮助你更好地理解和实践大模型的微调技术,让 AI 模型训练变得更加触手可及。

相关实践学习
使用PAI-EAS一键部署ChatGLM及LangChain应用
本场景中主要介绍如何使用模型在线服务(PAI-EAS)部署ChatGLM的AI-Web应用以及启动WebUI进行模型推理,并通过LangChain集成自己的业务数据。
机器学习概览及常见算法
机器学习(Machine Learning, ML)是人工智能的核心,专门研究计算机怎样模拟或实现人类的学习行为,以获取新的知识或技能,重新组织已有的知识结构使之不断改善自身的性能,它是使计算机具有智能的根本途径,其应用遍及人工智能的各个领域。 本课程将带你入门机器学习,掌握机器学习的概念和常用的算法。
相关文章
|
2月前
|
SQL 数据采集 自然语言处理
NL2SQL之DB-GPT-Hub<详解篇>:text2sql任务的微调框架和基准对比
NL2SQL之DB-GPT-Hub<详解篇>:text2sql任务的微调框架和基准对比
|
5天前
|
人工智能 API Windows
免费部署本地AI大语言模型聊天系统:Chatbox AI + 马斯克grok2.0大模型(简单5步实现,免费且比GPT4.0更好用)
本文介绍了如何部署本地AI大语言模型聊天系统,使用Chatbox AI客户端应用和Grok-beta大模型。通过获取API密钥、下载并安装Chatbox AI、配置模型,最终实现高效、智能的聊天体验。Grok 2大模型由马斯克X-AI发布,支持超长文本上下文理解,免费且易于使用。
33 0
|
3月前
|
人工智能 自然语言处理 算法
魔搭上新啦! 智源千万级指令微调数据集Infinity-Instruct,Llama3.1仅微调即可接近GPT-4
智源研究院在今年6月推出了千万级指令微调数据集Infinity Instruct。Infinity Instruct在 Huggingface等平台发布后,快速到达了Huggingface Dataset的Trending第一
魔搭上新啦! 智源千万级指令微调数据集Infinity-Instruct,Llama3.1仅微调即可接近GPT-4
|
2月前
|
机器学习/深度学习 测试技术
ACL杰出论文奖:GPT-4V暴露致命缺陷?JHU等发布首个多模态ToM 测试集,全面提升大模型心智能力
【10月更文挑战第6天】约翰斯·霍普金斯大学等机构提出了一项荣获ACL杰出论文奖的研究,旨在解决大模型在心智理论(ToM)上的不足。他们发布了首个MMToM-QA多模态ToM测试集,并提出BIP-ALM方法,从多模态数据中提取统一表示,结合语言模型进行贝叶斯逆规划,显著提升了模型的ToM能力。这一成果为机器与人类自然交互提供了新思路,尽管仍面临一些局限性和技术挑战。论文详情见:https://arxiv.org/abs/2401.08743。
51 6
|
3月前
|
数据采集 自然语言处理 监控
大模型微调使GPT3成为了可以聊天发布指令的ChatGPT
正是通过微调大模型使得GPT3成为了可以聊天发布指令的ChatGPT。聊天大模型在通用大模型的基础上加一层微调就实现人人能用的大模型,使得通用大模型的能力被更多人使用和了解。
62 4
大模型微调使GPT3成为了可以聊天发布指令的ChatGPT
|
2月前
|
开发工具 git
LLM-03 大模型 15分钟 FineTuning 微调 GPT2 模型 finetuning GPT微调实战 仅需6GB显存 单卡微调 数据 10MB数据集微调
LLM-03 大模型 15分钟 FineTuning 微调 GPT2 模型 finetuning GPT微调实战 仅需6GB显存 单卡微调 数据 10MB数据集微调
73 0
|
5月前
|
存储 SQL 数据库
Python 金融编程第二版(GPT 重译)(四)(4)
Python 金融编程第二版(GPT 重译)(四)
53 3
|
5月前
|
存储 NoSQL 索引
Python 金融编程第二版(GPT 重译)(一)(4)
Python 金融编程第二版(GPT 重译)(一)
64 2
|
5月前
|
存储 机器学习/深度学习 关系型数据库
Python 金融编程第二版(GPT 重译)(四)(5)
Python 金融编程第二版(GPT 重译)(四)
38 2
|
5月前
|
存储 SQL 数据可视化
Python 金融编程第二版(GPT 重译)(四)(1)
Python 金融编程第二版(GPT 重译)(四)
51 2
下一篇
DataWorks