【MATLAB第49期】基于MATLAB的深度学习ResNet-18网络不平衡图像数据分类识别模型

简介: 【MATLAB第49期】基于MATLAB的深度学习ResNet-18网络不平衡图像数据分类识别模型

MATLAB第49期】基于MATLAB的深度学习ResNet-18网络不平衡图像数据分类识别模型


一、基本介绍


这篇文章展示了如何使用不平衡训练数据集对图像进行分类,其中每个类的图像数量在类之间不同。两种最流行的解决方案是down-sampling降采样和over-sampling过采样。


在降采样中,每个类别的图像数量减少到所有类别中的最小图像数量。降采样的实现很容易:只需使用splitEachLabel函数并指定类的最小数量,


另一方面,当执行过采样时,每个类别的图像数量增加。这两种策略对于不平衡的数据集都是有效的。然而,过采样需要更复杂的过程。


本篇文章采用过采样平衡数据

       Label        Count
_____________    _____
caesar_salad       13 
caprese_salad       8 
french_fries       91 
greek_salad        12 
hamburger         119 
hot_dog            16 
pizza             150 
sashimi            20 
sushi              62 


过采样结果:

Label        Count
_____________    _____
caesar_salad      150 
caprese_salad     150 
french_fries      150 
greek_salad       150 
hamburger         150 
hot_dog           150 
pizza             150 
sashimi           150 
sushi             150 

二、数据情况

食品图像数据集包含九类食物的978张照片(ceaser_salad、caprese_salad,french_fries、greek_saland、汉堡包、hot_dog、披萨、生鱼片和寿司)。数据集可在下列地址下载

https://www.mathworks.com/supportfiles/nnet/data/ExampleFoodImageDataset.zip

本文为了提高运行速度 ,选择80%训练, 10%验证,10%测试。


三、代码展示

1.导入数据

imds = imageDatastore('ExampleFoodImageDataset');

2.图像数据展示

numExample=16;
idx = randperm(numel(imds.Files),numExample);
for i=1:numExample
    I=readimage(imds,idx(i));
    I_tile{i}=insertText(I,[1,1],string(imds.Labels(idx(i))),'FontSize',20);
end
% use imtile function to tile out the example images
I_tile = imtile(I_tile);
figure()
imshow(I_tile);title('examples of the dataset')

3.数据集划分 (训练80%,验证10%,测试10%)

[imdsTrain, imdsValid,imdsTest]=splitEachLabel(imds,0.8,0.1,0.1);

4.选取最大样本数

PerClass是所有类中的最大样本数。

PerClass = max(numObservations);

5.平衡数据

randReplicateFiles是一个仅对文件进行混洗的支持功能。

要选择的图像数量由PerClass定义。从数据库中找到不同类别的图像目录,然后随机复制对应的图像至对应的数量,以平衡类中的图像数量。

files = splitapply(@(x){randReplicateFiles(x,desiredNumObservationsPerClass)},imdsTrain.Files,G);

6.构建网络

加载预先训练的模型,ResNet-18

net = resnet18;
inputSize = net.Layers(1).InputSize;
lgraph = layerGraph(net);
learnableLayer='fc1000';
classLayer='ClassificationLayer_predictions';

7.图像增强

定义图像增强器
pixelRange = [-30 30];
RotationRange = [-30 30];
scaleRange = [0.8 1.2];
imageAugmenter = imageDataAugmenter( ...
    'RandXReflection',true, ...
    'RandXTranslation',pixelRange, ...
    'RandYTranslation',pixelRange, ...
    'RandXScale',scaleRange, ...
    'RandYScale',scaleRange, ...
    'RandRotation',RotationRange ...
    ); 

8.设置网络参数

miniBatchSize = 64;
valFrequency = max(floor(numel(augimdsTest.Files)/miniBatchSize)*10,1);
options = trainingOptions('sgdm', ...
    'MiniBatchSize',miniBatchSize, ...
    'MaxEpochs',5, ...%30
    'InitialLearnRate',1e-2, ...%3e-4
    'Shuffle','every-epoch', ...
    'ValidationData',augimdsValid, ...
    'ValidationFrequency',valFrequency, ...
    'Verbose',false, ...
    'Plots','training-progress');

9.训练网络

net = trainNetwork(augimdsTrain,lgraph,options);

10.分类评估

[YPred,probs] = classify(net,augimdsTest);
accuracy = mean(YPred == imdsTest.Labels)
YValidation = imdsTest.Labels;
YTrue=imdsTest.Labels;
figure;cm=confusionchart(YTrue,YPred);

%当我运行这个代码时,主要的错误分类是生鱼片和寿司,

%它们看起来很相似。请尝试使用此代码进行过度采样,并希望它对您的工作有所帮助。

四、运行效果

五、代码获取

后台私信回复“49期”,即可获取下载链接。

相关文章
|
13天前
|
算法 数据安全/隐私保护 计算机视觉
基于Retinex算法的图像去雾matlab仿真
本项目展示了基于Retinex算法的图像去雾技术。完整程序运行效果无水印,使用Matlab2022a开发。核心代码包含详细中文注释和操作步骤视频。Retinex理论由Edwin Land提出,旨在分离图像的光照和反射分量,增强图像对比度、颜色和细节,尤其在雾天条件下表现优异,有效解决图像去雾问题。
|
2天前
|
机器学习/深度学习 搜索推荐 PyTorch
基于昇腾用PyTorch实现传统CTR模型WideDeep网络
本文介绍了如何在昇腾平台上使用PyTorch实现经典的WideDeep网络模型,以处理推荐系统中的点击率(CTR)预测问题。
140 65
|
24天前
|
机器学习/深度学习 数据采集 算法
基于GA遗传优化的CNN-GRU-SAM网络时间序列回归预测算法matlab仿真
本项目基于MATLAB2022a实现时间序列预测,采用CNN-GRU-SAM网络结构。卷积层提取局部特征,GRU层处理长期依赖,自注意力机制捕捉全局特征。完整代码含中文注释和操作视频,运行效果无水印展示。算法通过数据归一化、种群初始化、适应度计算、个体更新等步骤优化网络参数,最终输出预测结果。适用于金融市场、气象预报等领域。
基于GA遗传优化的CNN-GRU-SAM网络时间序列回归预测算法matlab仿真
|
14天前
|
机器学习/深度学习 监控 算法
基于yolov4深度学习网络的排队人数统计系统matlab仿真,带GUI界面
本项目基于YOLOv4深度学习网络,利用MATLAB 2022a实现排队人数统计的算法仿真。通过先进的计算机视觉技术,系统能自动、准确地检测和统计监控画面中的人数,适用于银行、车站等场景,优化资源分配和服务管理。核心程序包含多个回调函数,用于处理用户输入及界面交互,确保系统的高效运行。仿真结果无水印,操作步骤详见配套视频。
44 18
|
10天前
|
前端开发 小程序 Java
uniapp-网络数据请求全教程
这篇文档介绍了如何在uni-app项目中使用第三方包发起网络请求
29 3
|
20天前
|
机器学习/深度学习 算法 计算机视觉
基于CNN卷积神经网络的金融数据预测matlab仿真,对比BP,RBF,LSTM
本项目基于MATLAB2022A,利用CNN卷积神经网络对金融数据进行预测,并与BP、RBF和LSTM网络对比。核心程序通过处理历史价格数据,训练并测试各模型,展示预测结果及误差分析。CNN通过卷积层捕捉局部特征,BP网络学习非线性映射,RBF网络进行局部逼近,LSTM解决长序列预测中的梯度问题。实验结果表明各模型在金融数据预测中的表现差异。
|
27天前
|
算法 人机交互 数据安全/隐私保护
基于图像形态学处理和凸包分析法的指尖检测matlab仿真
本项目基于Matlab2022a实现手势识别中的指尖检测算法。测试样本展示无水印运行效果,完整代码含中文注释及操作视频。算法通过图像形态学处理和凸包检测(如Graham扫描法)来确定指尖位置,但对背景复杂度敏感,需调整参数PARA1和PARA2以优化不同手型的检测精度。
|
29天前
|
机器学习/深度学习 算法
基于遗传优化的双BP神经网络金融序列预测算法matlab仿真
本项目基于遗传优化的双BP神经网络实现金融序列预测,使用MATLAB2022A进行仿真。算法通过两个初始学习率不同的BP神经网络(e1, e2)协同工作,结合遗传算法优化,提高预测精度。实验展示了三个算法的误差对比结果,验证了该方法的有效性。
|
26天前
|
传感器 算法
基于GA遗传优化的WSN网络最优节点部署算法matlab仿真
本项目基于遗传算法(GA)优化无线传感器网络(WSN)的节点部署,旨在通过最少的节点数量实现最大覆盖。使用MATLAB2022A进行仿真,展示了不同初始节点数量(15、25、40)下的优化结果。核心程序实现了最佳解获取、节点部署绘制及适应度变化曲线展示。遗传算法通过初始化、选择、交叉和变异步骤,逐步优化节点位置配置,最终达到最优覆盖率。
|
7天前
|
机器学习/深度学习 运维 安全
深度学习在安全事件检测中的应用:守护数字世界的利器
深度学习在安全事件检测中的应用:守护数字世界的利器
56 22

热门文章

最新文章