PyTorch深度学习实战 |基于Alexnet网络预训练模型完成训练花分类任务实战

本文涉及的产品
RDS DuckDB + QuickBI 企业套餐,8核32GB + QuickBI 专业版
简介: 本文介绍了使用AlexNet模型进行花卉图像分类的实战过程。首先讲解了数据集的准备方法,包括5类花卉数据(雏菊、蒲公英等)的8:2训练集/验证集划分。详细解析了AlexNet的网络结构(5个卷积层+3个全连接层)及其创新点,如ReLU激活函数和Dropout正则化。提供了完整的PyTorch实现代码,包括模型定义、数据增强和训练流程。实验结果表明,50轮训练后验证集准确率可达80%。文章还介绍了使用预训练模型进行迁移学习的方法,通过修改分类器层并微调参数,可以显著提升训练效率和分类效果。整个项目从数据准备到

 使用的数据集

花分类数据集:

百度云链接下载: https://pan.baidu.com/s/1QLCTA4sXnQAw_yvxPj9szg

提取码:58p0

下载好之后,解压到flower_data文件夹下,此时flower_data\flower_photos下就是放的我们的数据

集,我们看一下原始的数据是什么样子的:

分类类别:共包含 5 类花卉,对应 5 个文件夹: daisy(雏菊) dandelion(蒲公英) roses(玫

瑰) sunflowers(向日葵) tulips(郁金香)

image.gif

跑过一些项目的应该都有印象,比如YOLO等,他们的数据集的放置是有要求的一般情况下都是分

成两个,一个是train文件夹,train文件夹下是各种分类的文件夹(每个文件夹的名字是类报名)。

另外一个是val文件夹,val文件夹下是各种分类的文件夹(每个文件夹的名字是类报名)。一般是

按照8:2的比例去分这两个数据集的。这里的话可以用AI写代码整理,但是别忘记了检查一下。

训练集的路径:D:\vscode\shenduxvexishizhan\CNN\flower_data\train

验证集的路径是:D:\vscode\shenduxvexishizhan\CNN\flower_data\val


Alexnet

AlexNet创新点

(1)AlexNet首次成功使用了8层深度网络(5个卷积层 + 3个全连接层),比之前的网络深得多

(2)首次在深层网络中大规模且成功地使用了(ReLU)作为激活函数,取代了传统的 Sigmoid

或 Tanh 函数。

(3)引入GPU 加速训练 (GPU Acceleration),Dropout 正则化 (Dropout Regularization),数据增

强 (Data 局部响应归一化(LRN)和重叠池化(Overlapping Pooling)Augmentation)

网络结构

层类型 具体参数 输入尺寸 输出尺寸 作用
卷积层 1 Conv2d (3→48, 11×11, 步长 4, 填充 2) [3,224,224] [48,55,55] 提取基础纹理特征
池化层 1 MaxPool2d (3×3, 步长 2) [48,55,55] [48,27,27] 降维 + 增强鲁棒性
卷积层 2 Conv2d (48→128, 5×5, 填充 2) [48,27,27] [128,27,27] 提取更复杂特征
池化层 2 MaxPool2d (3×3, 步长 2) [128,27,27] [128,13,13] 继续降维
卷积层 3 Conv2d (128→192, 3×3, 填充 1) [128,13,13] [192,13,13] 特征细化
卷积层 4 Conv2d (192→192, 3×3, 填充 1) [192,13,13] [192,13,13] 特征细化
卷积层 5 Conv2d (192→128, 3×3, 填充 1) [192,13,13] [128,13,13] 特征压缩
池化层 3 MaxPool2d (3×3, 步长 2) [128,13,13] [128,6,6] 最终降维
全连接层 1 Linear(128×6×6 → 2048) 4608 2048 特征映射到高维空间
全连接层 2 Linear(2048 → 2048) 2048 2048 特征变换
全连接层 3 Linear(2048 → num_classes) 2048 num_classes 最终分类

image.gif

model.py

import torch.nn as nn
import torch
class AlexNet(nn.Module):
    def __init__(self, num_classes=1000, init_weights=False):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),  # input[3, 224, 224]  output[48, 55, 55]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[48, 27, 27]
            nn.Conv2d(48, 128, kernel_size=5, padding=2),           # output[128, 27, 27]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 13, 13]
            nn.Conv2d(128, 192, kernel_size=3, padding=1),          # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=3, padding=1),          # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 128, kernel_size=3, padding=1),          # output[128, 13, 13]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 6, 6]
        )
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(128 * 6 * 6, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, num_classes),
        )
        if init_weights:
            self._initialize_weights()
    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, start_dim=1)
        x = self.classifier(x)
        return x
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)
# 测试模型前向传播
if __name__ == "__main__":
    # 创建模型实例
    model = AlexNet(num_classes=1000, init_weights=True)
    # 生成随机输入(batch_size=4, 3通道, 224×224)
    input_tensor = torch.randn(4, 3, 224, 224)
    # 前向传播
    output = model(input_tensor)
    # 打印输出形状
    print(f"输入形状: {input_tensor.shape}")
    print(f"输出形状: {output.shape}")  # 应输出 [4, 1000]
    # 打印模型参数量
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"总参数量: {total_params/1e6:.2f}M")
    print(f"可训练参数量: {trainable_params/1e6:.2f}M")

image.gif

模型的输入是【B,3,224,224】代表B张图片,输出是【B,5】代表B个样本,每个类别的概

率。在实际的项目中,这是最简单,也是最核心的部分。简单是因为,所有神经网络的本质都是为

了提取特征的,所有我们很多时候不需要知道,其是怎么实现的,只需要知道,网络的输入和输出

就行。最核心是因为,往往特征提取的好坏直接决定了训练效果的好坏。

image.gif

图解处理和图像增强:

在验证的时候是不需要数据增强的

data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([transforms.Resize((224, 224)),  # cannot 224, must (224, 224)
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

image.gif

train.py

import os
import sys
import json
import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from tqdm import tqdm
from model import AlexNet
def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))
    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([transforms.Resize((224, 224)),  # cannot 224, must (224, 224)
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}
    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
    image_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set path
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)
    # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    # write dict into json file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)
    batch_size = 32
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)
    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=4, shuffle=False,
                                                  num_workers=nw)
    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))
    # test_data_iter = iter(validate_loader)
    # test_image, test_label = test_data_iter.next()
    #
    # def imshow(img):
    #     img = img / 2 + 0.5  # unnormalize
    #     npimg = img.numpy()
    #     plt.imshow(np.transpose(npimg, (1, 2, 0)))
    #     plt.show()
    #
    # print(' '.join('%5s' % cla_dict[test_label[j].item()] for j in range(4)))
    # imshow(utils.make_grid(test_image))
    net = AlexNet(num_classes=5, init_weights=True)
    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    # pata = list(net.parameters())
    optimizer = optim.Adam(net.parameters(), lr=0.0002)
    epochs = 10
    save_path = './AlexNet.pth'
    best_acc = 0.0
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # train
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            outputs = net(images.to(device))
            loss = loss_function(outputs, labels.to(device))
            loss.backward()
            optimizer.step()
            # print statistics
            running_loss += loss.item()
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)
        # validate
        net.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)
    print('Finished Training')
if __name__ == '__main__':
    main()

image.gif

训练结果:

训练50代的效果如下:验证集的准确度大致可以稳定在0.8左右

image.gif


预训练模型完成训练

加载预训练权重: 导入 ImageNet 上训练好的 AlexNet 权重。

修改分类器: 由于你的任务是 5 类花卉分类,需要替换 AlexNet 原本的 1000 类输出层,以匹配你的 num_classes=5。

设置学习率/冻结层: 通常对预训练模型使用更小的学习率,或者冻结特征提取层(features)的参数,只训练分类器(classifier)的参数。

import os
import sys
import json
import torch
import torch.nn as nn
from torchvision import transforms, datasets, models # 导入 models
import torch.optim as optim
from tqdm import tqdm
from model import AlexNet # 假设这是你自定义的 AlexNet 类
def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))
    # ... (数据加载和处理部分保持不变) ...
    # ------------------ 【修改开始】模型加载与设置 ------------------
    # 1. 实例化 AlexNet 模型
    # 如果使用你自定义的 AlexNet 类,这里加载预训练权重(需要文件支持)
    net = AlexNet(num_classes=1000, init_weights=False) # 实例化1000类模型
    
    # 假设你的预训练权重是 'alexnet_imagenet.pth'
    # weights_path = "./alexnet_imagenet.pth"
    # assert os.path.exists(weights_path), f"Pretrained weights file: '{weights_path}' not found."
    # net.load_state_dict(torch.load(weights_path), strict=False) # strict=False 如果你的类和权重不完全匹配
    # 或者:使用官方预训练模型(更简单)
    net = models.alexnet(weights=models.AlexNet_Weights.IMAGENET1K_V1)
    # 2. 替换分类器 (迁移学习的关键)
    in_features = net.classifier[6].in_features # 官方 AlexNet 的最后一个 Linear 层是第 6 个模块 (索引从 0 开始)
    # 替换为 5 个类别的输出
    net.classifier[6] = nn.Linear(in_features, 5) 
    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    
    # 3. 设置微调学习率(通常更小)
    # 只训练分类器参数,使用更高的学习率:
    # optimizer = optim.Adam(net.classifier.parameters(), lr=0.001) 
    
    # 或者,微调所有参数,使用更小的学习率:
    optimizer = optim.Adam(net.parameters(), lr=0.00005) # 降低学习率进行微调
    # ------------------ 【修改结束】模型加载与设置 ------------------
    epochs = 50
    save_path = './AlexNet.pth'
    best_acc = 0.0
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # train
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            outputs = net(images.to(device))
            loss = loss_function(outputs, labels.to(device))
            loss.backward()
            optimizer.step()
            # print statistics
            running_loss += loss.item()
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)
        # validate
        net.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)
    print('Finished Training')
if __name__ == '__main__':
    main()

image.gif

训练效果:

我们应该可以直观的感受到,这种方式的训练速度更快,效果更好

image.gif


目录
相关文章
|
16天前
|
人工智能 自然语言处理 文字识别
阿里云百炼Qwen3.7-Max简介:能力、优势、支持订阅计划参考
Qwen3.7-Max是阿里云百炼面向智能体时代推出的新一代旗舰模型,对标GPT-5.5、Claude Opus 4.7等闭源旗舰。该模型支持百万级token上下文窗口,具备顶级推理能力、多模态搜索与视觉理解增强、流式输出低延迟响应等核心优势,覆盖编程、办公、长周期自主执行等复杂场景。同时支持OpenAI接口兼容,便于系统快速迁移。用户可通过Token Plan团队或节省计划等订阅方式灵活调用,适合企业级高要求场景使用。
5871 30
阿里云百炼Qwen3.7-Max简介:能力、优势、支持订阅计划参考
|
1天前
|
数据采集 人工智能 前端开发
让 Coding Agent 从黑盒到透明:阿里云 Agent 观测审计数据采集实践
AI Agent 规模化落地带来执行黑盒、行为难追溯、成本难度量三大难题。阿里云基于 OTel 标准,面向 Coding Agent、个人通用助理和框架型 Agent,推出 LoongSuite Pilot、插件及探针等无侵入采集方案,让 Agent 实现可看见、可分析、可审计、可治理。
561 134
|
10天前
|
存储 定位技术 数据库
CodeGraph 如何让 Claude Code减少 7 成工具调用?
CodeGraph 为 Coding Agent 提供本地代码知识图谱,把函数、类、调用链和框架路由提前整理成“项目地图”,减少盲目搜索和文件读取。它不是新 Agent,而是上下文基础设施,让 Agent 更快找到正确代码路径,平均减少 7 成工具调用。
1177 2
|
8天前
|
人工智能 安全 定位技术
CodeGraph深度解析 让Claude Code工具调用直降七成的核心原理与实操教程
如今以Claude Code为代表的AI编程智能体已经成为开发者日常编码、项目重构、漏洞修复的必备工具。但在长期使用过程中,几乎所有开发者都会遇到同一个明显痛点:AI虽然具备强大的代码生成与分析能力,却常常陷入盲目探索的循环中。
959 1
|
17天前
|
人工智能 自然语言处理 供应链
|
8天前
|
人工智能 弹性计算 安全
阿里云618活动时间、活动入口、优惠活动详细解读
2026年阿里云618创新加速季已全面开启,作为年度力度最大的云产品促销活动,本次大促覆盖轻量应用服务器、ECS云服务器、GPU云服务器、数据库、AI算力、安全服务、CDN等全品类产品,推出5亿元算力补贴、新用户限时秒杀、普惠满减、企业专享、免费试用、云大使返佣等多重福利,个人开发者、中小企业、AI团队均可享受专属低价。本文将系统梳理2026年阿里云618活动的完整时间节点、官方参与入口、各类优惠细则、使用规则、热门产品推荐及实操代码,帮助用户精准参与、高效省钱,以最低成本完成上云部署。
764 4
|
8天前
|
运维
欢迎报名|2026 Agentic AICon—智能体基础设施与AgentOps专场,邀您参会
欢迎报名|2026 Agentic AICon—智能体基础设施与AgentOps专场,邀您参会
1432 0