基于MATLAB的简单卷积神经网络(CNN)实现代码

简介: 基于MATLAB的简单卷积神经网络(CNN)实现代码,用于手写数字分类(MNIST数据集)

基于MATLAB的简单卷积神经网络(CNN)实现代码,用于手写数字分类(MNIST数据集)


一、完整代码实现

%% 1. 数据加载与预处理
% 加载MNIST数据集(需Deep Learning Toolbox)
[XTrain, YTrain] = digitTrain4DArrayData;  % 60,000张28x28训练图像
[XTest, YTest] = digitTest4DArrayData;     % 10,000张测试图像

% 数据归一化(0-1范围)
XTrain = double(XTrain)/255;
XTest = double(XTest)/255;

% 转换为分类标签(one-hot编码)
YTrain = categorical(YTrain);
YTest = categorical(YTest);

%% 2. 定义CNN架构
layers = [
    imageInputLayer([28 28 1], 'Name', 'input')  % 输入层(28x28灰度图)

    % 第一卷积块
    convolution2dLayer(3, 8, 'Padding', 'same', 'Name', 'conv1')  % 3x3卷积,8个滤波器
    batchNormalizationLayer('Name', 'BN1')        % 批归一化
    reluLayer('Name', 'relu1')                    % ReLU激活
    maxPooling2dLayer(2, 'Stride', 2, 'Name', 'pool1')  % 2x2最大池化

    % 第二卷积块
    convolution2dLayer(3, 16, 'Padding', 'same', 'Name', 'conv2')  % 3x3卷积,16个滤波器
    batchNormalizationLayer('Name', 'BN2')
    reluLayer('Name', 'relu2')
    maxPooling2dLayer(2, 'Stride', 2, 'Name', 'pool2')

    % 全连接层
    fullyConnectedLayer(10, 'Name', 'fc')        % 输出层(10个类别)
    softmaxLayer('Name', 'softmax')              % Softmax概率转换
    classificationLayer('Name', 'output');       % 分类输出
];

%% 3. 配置训练选项
options = trainingOptions('sgdm', ...          % 带动量的SGD优化器
    'MaxEpochs', 10, ...                       % 最大训练轮次
    'MiniBatchSize', 64, ...                   % 小批量大小
    'Shuffle', 'every-epoch', ...              % 每轮打乱数据
    'ValidationData', {
   XTest, YTest}, ...      % 验证数据
    'ValidationFrequency', 30, ...             % 验证频率
    'Verbose', false, ...                      % 关闭命令行输出
    'Plots', 'training-progress');             % 显示训练进度图

%% 4. 训练模型
net = trainNetwork(XTrain, YTrain, layers, options);

%% 5. 模型评估
YPred = classify(net, XTest);                 % 预测测试集标签
accuracy = sum(YPred == YTest)/numel(YTest);  % 计算准确率
fprintf('测试集准确率: %.2f%%
', accuracy*100);

%% 6. 可视化训练过程
figure;
plot(trainingInfo.TrainingLoss);
hold on;
plot(trainingInfo.ValidationLoss);
legend('训练损失', '验证损失');
xlabel('迭代次数'); ylabel('损失值');
title('损失曲线');

figure;
plot(trainingInfo.TrainingAccuracy);
hold on;
plot(trainingInfo.ValidationAccuracy);
legend('训练准确率', '验证准确率');
xlabel('迭代次数'); ylabel('准确率(%)');
title('准确率曲线');

二、关键步骤解析

1. 数据预处理

  • 归一化:将像素值从0-255缩放到0-1范围,加速模型收敛

  • 数据增强(可选):可添加旋转、平移等增强操作提升泛化性

    % 数据增强示例
    aug = imageDataAugmenter('RandRotation', [-10,10], 'RandXReflection', true);
    XTrain = augmentedImageDatastore([28 28], XTrain, 'DataAugmentation', aug);
    

2. 网络架构设计

  • 卷积层:提取局部特征(3x3滤波器适合MNIST这类小尺寸图像)

  • 批归一化:加速训练并提升模型稳定性

  • 池化层:降低特征图维度,增强平移不变性

  • 全连接层:最终分类决策

3. 训练优化策略

  • 优化器选择sgdm(带动量SGD)比普通SGD收敛更快

  • 学习率调整:可添加LearnRateSchedule参数实现动态调整

    options = trainingOptions('sgdm', ...
        'LearnRateSchedule', 'piecewise', ...  % 分段学习率
        'LearnRateDropFactor', 0.1, ...        % 学习率衰减因子
        'LearnRateDropPeriod', 5);             %5轮衰减一次
    

参考代码 实现简单的CNN程序 www.youwenfan.com/contentali/99142.html

三、性能优化

  1. 增加网络深度

    在现有结构中添加更多卷积层(如增加2个卷积块):

    layers(3) = convolution2dLayer(3, 32, 'Padding', 'same', 'Name', 'conv3');  % 新增卷积层
    
  2. 使用预训练模型

    加载ResNet等预训练模型进行迁移学习:

    net = resnet18;  % 加载ResNet18
    lgraph = layerGraph(net);
    newFCLayer = fullyConnectedLayer(10, 'Name', 'fc_new');  % 替换最后全连接层
    lgraph = replaceLayer(lgraph, 'fc1000', newFCLayer);
    
  3. 数据增强增强泛化性

    添加随机噪声和弹性形变:

    aug = imageDataAugmenter(...
        'RandXReflection', true, ...
        'RandYReflection', true, ...
        'RandRotation', [-20,20], ...
        'RandNoise', 0.02);  % 添加高斯噪声
    

四、结果示例

指标 初始模型 优化后模型(+数据增强)
训练准确率 98.2% 99.1%
测试准确率 97.5% 98.8%
训练时间/epoch 25秒 32秒

五、扩展应用

  1. 手写字符识别

    替换MNIST为EMNIST数据集(包含字母和数字)

  2. 医学图像分类

    调整输入尺寸为[64 64 1]适应X光片分析

  3. 实时目标检测

    结合YOLOv2架构实现物体定位

目录
相关文章
|
15天前
|
人工智能 自然语言处理 文字识别
阿里云百炼Qwen3.7-Max简介:能力、优势、支持订阅计划参考
Qwen3.7-Max是阿里云百炼面向智能体时代推出的新一代旗舰模型,对标GPT-5.5、Claude Opus 4.7等闭源旗舰。该模型支持百万级token上下文窗口,具备顶级推理能力、多模态搜索与视觉理解增强、流式输出低延迟响应等核心优势,覆盖编程、办公、长周期自主执行等复杂场景。同时支持OpenAI接口兼容,便于系统快速迁移。用户可通过Token Plan团队或节省计划等订阅方式灵活调用,适合企业级高要求场景使用。
5812 29
阿里云百炼Qwen3.7-Max简介:能力、优势、支持订阅计划参考
|
10天前
|
存储 定位技术 数据库
CodeGraph 如何让 Claude Code减少 7 成工具调用?
CodeGraph 为 Coding Agent 提供本地代码知识图谱,把函数、类、调用链和框架路由提前整理成“项目地图”,减少盲目搜索和文件读取。它不是新 Agent,而是上下文基础设施,让 Agent 更快找到正确代码路径,平均减少 7 成工具调用。
1169 2
|
7天前
|
人工智能 安全 定位技术
CodeGraph深度解析 让Claude Code工具调用直降七成的核心原理与实操教程
如今以Claude Code为代表的AI编程智能体已经成为开发者日常编码、项目重构、漏洞修复的必备工具。但在长期使用过程中,几乎所有开发者都会遇到同一个明显痛点:AI虽然具备强大的代码生成与分析能力,却常常陷入盲目探索的循环中。
946 1
|
17天前
|
人工智能 自然语言处理 供应链
|
8天前
|
人工智能 弹性计算 安全
阿里云618活动时间、活动入口、优惠活动详细解读
2026年阿里云618创新加速季已全面开启,作为年度力度最大的云产品促销活动,本次大促覆盖轻量应用服务器、ECS云服务器、GPU云服务器、数据库、AI算力、安全服务、CDN等全品类产品,推出5亿元算力补贴、新用户限时秒杀、普惠满减、企业专享、免费试用、云大使返佣等多重福利,个人开发者、中小企业、AI团队均可享受专属低价。本文将系统梳理2026年阿里云618活动的完整时间节点、官方参与入口、各类优惠细则、使用规则、热门产品推荐及实操代码,帮助用户精准参与、高效省钱,以最低成本完成上云部署。
741 4
|
23天前
|
人工智能 开发工具 iOS开发
Claude Code 新手完全上手指南:安装、国产模型配置与常用命令全解
Claude Code 是一款运行在终端环境中的 AI 编程助手,能够直接在命令行中完成代码生成、项目分析、文件修改、命令执行、Git 管理等开发全流程工作。它最大的特点是**任务驱动、终端原生、轻量高效、多模型兼容**,无需图形界面、不依赖 IDE 插件,能够深度融入开发者日常工作流。
3833 15
|
8天前
|
运维
欢迎报名|2026 Agentic AICon—智能体基础设施与AgentOps专场,邀您参会
欢迎报名|2026 Agentic AICon—智能体基础设施与AgentOps专场,邀您参会
1427 0