从零训练一个 ChatGPT:用 PyTorch 构建自己的 LLM 模型

简介: 本文介绍如何使用PyTorch从零构建类似ChatGPT的大型语言模型,涵盖Transformer架构、数据预处理、训练优化及文本生成全过程,助你掌握LLM核心原理与实现技术。(238字)

从零训练一个 ChatGPT:用 PyTorch 构建自己的 LLM 模型

大型语言模型(LLM)的发展正在重塑人工智能领域,从GPT系列到BERT,这些模型展现了惊人的语言理解和生成能力。本文将详细介绍如何使用PyTorch从零开始构建一个类似ChatGPT的大型语言模型,涵盖模型架构、训练策略、优化技术等关键环节。

模型架构设计

image.png

Transformer架构基础

Transformer模型是现代LLM的核心架构,基于自注意力机制实现并行化的序列建模。我们首先构建基础的Transformer组件:

基础注意力机制实现

import torch
import torch.nn as nn
import math

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)
        self.scale = math.sqrt(self.d_k)

    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)

        Q = self.W_q(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale

        if mask is not None:
            attention_scores = attention_scores.masked_fill(mask == 0, -1e9)

        attention_weights = torch.softmax(attention_scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        output = torch.matmul(attention_weights, V)
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        output = self.W_o(output)

        return output

解码器架构

ChatGPT基于GPT架构,使用纯解码器结构:

class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super(DecoderLayer, self).__init__()

        self.self_attention = MultiHeadAttention(d_model, num_heads, dropout)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout)
        )

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # 自注意力层
        attended = self.self_attention(x, x, x, mask)
        x = self.norm1(x + self.dropout(attended))

        # 前馈网络
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))

        return x

class GPTModel(nn.Module):
    def __init__(self, vocab_size, d_model=512, num_layers=6, num_heads=8, d_ff=2048, max_seq_len=512, dropout=0.1):
        super(GPTModel, self).__init__()

        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = self.create_positional_encoding(max_seq_len, d_model)

        self.layers = nn.ModuleList([
            DecoderLayer(d_model, num_heads, d_ff, dropout) 
            for _ in range(num_layers)
        ])

        self.fc_out = nn.Linear(d_model, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def create_positional_encoding(self, max_len, d_model):
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()

        div_term = torch.exp(torch.arange(0, d_model, 2).float() * 
                           -(math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        return pe.unsqueeze(0)

    def generate_square_subsequent_mask(self, sz):
        mask = torch.triu(torch.ones(sz, sz), diagonal=1)
        mask = mask.masked_fill(mask == 1, float('-inf'))
        return mask

    def forward(self, src, src_mask=None):
        seq_len = src.size(1)

        if src_mask is None:
            src_mask = self.generate_square_subsequent_mask(seq_len).to(src.device)

        x = self.embedding(src) * math.sqrt(self.d_model)
        x = x + self.pos_encoding[:, :seq_len, :].to(src.device)
        x = self.dropout(x)

        for layer in self.layers:
            x = layer(x, src_mask)

        output = self.fc_out(x)
        return output

数据预处理与分词

文本清洗和预处理

import re
import unicodedata

def clean_text(text):
    # 转换为小写
    text = text.lower()

    # 移除特殊字符
    text = re.sub(r'[^\w\s]', ' ', text)

    # 标准化Unicode字符
    text = unicodedata.normalize('NFKD', text)

    # 移除多余空格
    text = re.sub(r'\s+', ' ', text).strip()

    return text

简单词汇表构建

class Vocabulary:
    def __init__(self):
        self.word2idx = {
   '<PAD>': 0, '<UNK>': 1, '<START>': 2, '<END>': 3}
        self.idx2word = {
   0: '<PAD>', 1: '<UNK>', 2: '<START>', 3: '<END>'}
        self.vocab_size = 4

    def build_vocab(self, texts):
        word_count = {
   }

        for text in texts:
            for word in text.split():
                if word not in word_count:
                    word_count[word] = 0
                word_count[word] += 1

        # 过滤低频词
        for word, count in word_count.items():
            if count >= 2:  # 最小频率阈值
                self.word2idx[word] = self.vocab_size
                self.idx2word[self.vocab_size] = word
                self.vocab_size += 1

数据集类

在之前先给大家分享一个免费获取数据集的网站-魔塔

class TextDataset(torch.utils.data.Dataset):
    def __init__(self, texts, vocab, max_length=128):
        self.texts = texts
        self.vocab = vocab
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        tokens = text.split()[:self.max_length-2]  # 预留开始和结束标记

        # 添加开始和结束标记
        tokens = ['<START>'] + tokens + ['<END>']

        # 转换为索引
        indices = [self.vocab.word2idx.get(token, self.vocab.word2idx['<UNK>']) 
                  for token in tokens]

        # 填充到固定长度
        if len(indices) < self.max_length:
            indices.extend([self.vocab.word2idx['<PAD>']] * (self.max_length - len(indices)))
        else:
            indices = indices[:self.max_length]

        # 创建输入和目标
        input_ids = torch.tensor(indices[:-1], dtype=torch.long)
        target_ids = torch.tensor(indices[1:], dtype=torch.long)

        return input_ids, target_ids

训练策略

损失函数和优化器配置

def setup_training(model, learning_rate=1e-4):
    criterion = nn.CrossEntropyLoss(ignore_index=0)  # 忽略PAD标记
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)

    # 学习率调度器
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=10, T_mult=2, eta_min=1e-6
    )

    return criterion, optimizer, scheduler

训练循环

def train_epoch(model, dataloader, criterion, optimizer, device):
    model.train()
    total_loss = 0

    for batch_idx, (src, tgt) in enumerate(dataloader):
        src, tgt = src.to(device), tgt.to(device)

        optimizer.zero_grad()

        # 前向传播
        output = model(src)

        # 计算损失
        loss = criterion(output.view(-1, output.size(-1)), tgt.view(-1))

        # 反向传播
        loss.backward()

        # 梯度裁剪
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()

        total_loss += loss.item()

        if batch_idx % 100 == 0:
            print(f'Batch {batch_idx}, Loss: {loss.item():.4f}')

    return total_loss / len(dataloader)

模型评估

def evaluate_model(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    correct_predictions = 0
    total_predictions = 0

    with torch.no_grad():
        for src, tgt in dataloader:
            src, tgt = src.to(device), tgt.to(device)

            output = model(src)
            loss = criterion(output.view(-1, output.size(-1)), tgt.view(-1))

            total_loss += loss.item()

            # 计算准确率
            predictions = output.argmax(dim=-1)
            mask = tgt != 0  # 排除PAD标记
            correct_predictions += ((predictions == tgt) * mask).sum().item()
            total_predictions += mask.sum().item()

    avg_loss = total_loss / len(dataloader)
    accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0

    return avg_loss, accuracy

推理和文本生成

文本生成函数

def generate_text(model, start_text, vocab, max_length=100, temperature=1.0, device='cpu'):
    model.eval()

    # 预处理起始文本
    tokens = start_text.split()
    input_ids = [vocab.word2idx.get(token, vocab.word2idx['<UNK>']) for token in tokens]
    input_ids = torch.tensor([input_ids], dtype=torch.long).to(device)

    generated = input_ids.tolist()[0]

    with torch.no_grad():
        for _ in range(max_length):
            # 获取模型输出
            output = model(torch.tensor([generated], dtype=torch.long).to(device))
            next_token_logits = output[0, -1, :] / temperature

            # 应用softmax获取概率
            probabilities = torch.softmax(next_token_logits, dim=-1)

            # 采样下一个token
            next_token = torch.multinomial(probabilities, 1).item()

            # 如果生成结束标记,则停止
            if next_token == vocab.word2idx['<END>']:
                break

            generated.append(next_token)

    # 转换回文本
    generated_text = []
    for token_id in generated:
        word = vocab.idx2word.get(token_id, '<UNK>')
        if word not in ['<START>', '<END>', '<PAD>']:
            generated_text.append(word)

    return ' '.join(generated_text)

交互式生成

def interactive_generation(model, vocab, device='cpu'):
    print("开始交互式文本生成,输入'quit'退出")

    while True:
        prompt = input("请输入提示文本: ")
        if prompt.lower() == 'quit':
            break

        generated_text = generate_text(model, prompt, vocab, max_length=50, device=device)
        print(f"生成文本: {generated_text}\n")

高级优化技术

梯度累积

def train_with_gradient_accumulation(model, dataloader, criterion, optimizer, device, accumulation_steps=4):
    model.train()
    total_loss = 0

    optimizer.zero_grad()

    for batch_idx, (src, tgt) in enumerate(dataloader):
        src, tgt = src.to(device), tgt.to(device)

        output = model(src)
        loss = criterion(output.view(-1, output.size(-1)), tgt.view(-1))
        loss = loss / accumulation_steps  # 归一化损失

        loss.backward()

        if (batch_idx + 1) % accumulation_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            optimizer.zero_grad()

        total_loss += loss.item() * accumulation_steps

    return total_loss / len(dataloader)

混合精度训练

from torch.cuda.amp import autocast, GradScaler

def train_with_amp(model, dataloader, criterion, optimizer, device):
    model.train()
    total_loss = 0
    scaler = GradScaler()

    for batch_idx, (src, tgt) in enumerate(dataloader):
        src, tgt = src.to(device), tgt.to(device)

        optimizer.zero_grad()

        with autocast():
            output = model(src)
            loss = criterion(output.view(-1, output.size(-1)), tgt.view(-1))

        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()

    return total_loss / len(dataloader)

模型微调策略

参数高效微调

class LoRALayer(nn.Module):
    def __init__(self, d_model, rank=8):
        super(LoRALayer, self).__init__()
        self.A = nn.Parameter(torch.randn(d_model, rank) * 0.01)
        self.B = nn.Parameter(torch.zeros(rank, d_model))

    def forward(self, x):
        return torch.matmul(torch.matmul(x, self.A), self.B)

模型保存和加载

def save_model(model, optimizer, epoch, loss, filepath):
    torch.save({
   
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }, filepath)

def load_model(model, optimizer, filepath):
    checkpoint = torch.load(filepath)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    return model, optimizer, epoch, loss

实际应用示例

完整训练流程

def main():
    # 配置参数
    d_model = 256
    num_layers = 4
    num_heads = 8
    d_ff = 512
    vocab_size = 10000
    max_seq_len = 128
    batch_size = 32
    epochs = 10

    # 检查GPU可用性
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"使用设备: {device}")

    # 示例数据(实际应用中应使用真实数据集)
    sample_texts = [
        "the quick brown fox jumps over the lazy dog",
        "machine learning is a subset of artificial intelligence",
        "deep learning models require large amounts of data",
        # 更多训练数据...
    ]

    # 构建词汇表
    vocab = Vocabulary()
    vocab.build_vocab(sample_texts)

    # 创建数据集
    dataset = TextDataset(sample_texts, vocab, max_seq_len)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # 初始化模型
    model = GPTModel(
        vocab_size=vocab.vocab_size,
        d_model=d_model,
        num_layers=num_layers,
        num_heads=num_heads,
        d_ff=d_ff,
        max_seq_len=max_seq_len
    ).to(device)

    # 设置训练参数
    criterion, optimizer, scheduler = setup_training(model)

    # 训练循环
    for epoch in range(epochs):
        print(f"Epoch {epoch+1}/{epochs}")

        train_loss = train_epoch(model, dataloader, criterion, optimizer, device)
        print(f"Train Loss: {train_loss:.4f}")

        scheduler.step()

        # 保存检查点
        if (epoch + 1) % 5 == 0:
            save_model(model, optimizer, epoch, train_loss, f'checkpoint_epoch_{epoch+1}.pth')

    print("训练完成!")

    # 文本生成示例
    generated_text = generate_text(model, "machine learning", vocab, max_length=30, device=device)
    print(f"生成示例: {generated_text}")

if __name__ == "__main__":
    main()

性能优化建议

内存优化策略

  1. 梯度累积:在内存受限时使用梯度累积
  2. 混合精度训练:使用FP16减少内存占用
  3. 模型并行:将模型分布在多个GPU上
  4. 激活检查点:减少中间激活的内存占用

计算优化策略

  1. 分布式训练:使用多GPU加速训练
  2. 数据并行:将数据分发到多个设备
  3. 流水线并行:将模型层分布在不同设备上

总结

构建一个类似ChatGPT的大型语言模型是一个复杂的工程,涉及模型架构设计、数据预处理、训练策略、优化技术等多个方面。通过PyTorch框架,我们可以从零开始实现这些组件,并逐步优化模型性能。

实际应用中,还需要考虑:

  • 大规模数据集的处理
  • 分布式训练的实现
  • 模型压缩和量化
  • 推理优化技术
  • 安全性和伦理考量

随着技术的不断发展,LLM的训练和部署将变得更加高效和便捷,为各种应用场景提供强大的语言理解能力。



关于作者



🌟 我是suxiaoxiang,一位热爱技术的开发者

💡 专注于Java生态和前沿技术分享

🚀 持续输出高质量技术内容



如果这篇文章对你有帮助,请支持一下:




👍 点赞


收藏


👀 关注



您的支持是我持续创作的动力!感谢每一位读者的关注与认可!


目录
相关文章
|
20天前
|
机器学习/深度学习 人工智能 搜索推荐
数据中台的进化之路:从“管数据”到“懂业务”
数据中台的进化之路:从“管数据”到“懂业务”
140 3
|
20天前
|
运维 自然语言处理 监控
AIOps 实战:我用 LLM 辅助分析线上告警
本文分享AIOps实战中利用大型语言模型(LLM)智能分析线上告警的实践经验,解决告警洪流、关联性分析难等问题。通过语义理解与上下文感知,LLM实现告警分类、优先级排序与根因定位,显著提升运维效率与准确率,助力系统稳定运行。
129 5
|
2月前
|
人工智能 文字识别 并行计算
为什么别人用 DevPod 秒启 DeepSeek-OCR,你还在装环境?
DevPod 60秒极速启动,一键运行DeepSeek OCR大模型。告别环境配置难题,云端开箱即用,支持GPU加速、VSCode/Jupyter交互开发,重塑AI原生高效工作流。
623 35
|
20天前
|
SQL 人工智能 自然语言处理
Spring Boot + GPT:我做了一个能自己写 SQL 的后端系统
本文介绍如何基于Spring Boot与GPT(或国产大模型如通义千问、DeepSeek)构建智能后端系统,实现自然语言自动生成SQL。系统采用分层架构,集成AI语义理解、SQL安全验证与执行功能,提升开发效率并降低数据查询门槛,兼具安全性与可扩展性。
143 7
|
缓存 网络协议 算法
Netty的基础入门(上)
Netty的基础入门(上)
523 1
|
20天前
|
数据采集 人工智能 搜索推荐
别再“调教”ChatGPT了!用Qwen2.5打造24小时在线数字分身
在AI时代,专属“数字分身”正从科幻走向现实。依托Qwen2.5-14B大模型、LoRA微调技术及LLaMA-Factory Online平台,仅需四步即可打造会说话、懂风格、能办事的个性化AI助手,让每个人拥有自己的“贾维斯”。
331 153
|
19天前
|
人工智能 运维 安全
当Java遇见AI:无需Python,构建企业级RAG智能应用实战
本文深入探讨Java在RAG(检索增强生成)智能应用中的实战应用,打破“AI等于Python”的固有认知。依托Spring生态、高性能向量计算与企业级安全监控,结合文档预处理、混合检索、重排序与多LLM集成,构建高并发、可运维的生产级系统。展示如何用Java实现从文本分割、向量化到智能生成的全流程,助力企业高效落地AI能力,兼具性能、安全与可扩展性。
175 1
|
2月前
|
机器学习/深度学习 存储 缓存
115_LLM基础模型架构设计:从Transformer到稀疏注意力
大型语言模型(LLM)的架构设计是其性能的核心决定因素。从2017年Transformer架构的提出,到如今的稀疏注意力和混合专家模型,LLM架构经历了快速的演进。本文将全面探讨LLM基础架构的设计原理,深入分析Transformer的核心机制,详细介绍稀疏注意力、MoE等创新架构,并展望未来架构发展方向。通过数学推导和实践案例,为构建高效、强大的LLM提供全面指导。
|
消息中间件 存储 资源调度
订单超时处理的几种方案及分析
描述业务常见的订单超时处理的几种方案及分析
33021 19
订单超时处理的几种方案及分析