【MATLAB第49期】基于MATLAB的深度学习ResNet-18网络不平衡图像数据分类识别模型

简介: 【MATLAB第49期】基于MATLAB的深度学习ResNet-18网络不平衡图像数据分类识别模型

MATLAB第49期】基于MATLAB的深度学习ResNet-18网络不平衡图像数据分类识别模型


一、基本介绍


这篇文章展示了如何使用不平衡训练数据集对图像进行分类,其中每个类的图像数量在类之间不同。两种最流行的解决方案是down-sampling降采样和over-sampling过采样。


在降采样中,每个类别的图像数量减少到所有类别中的最小图像数量。降采样的实现很容易:只需使用splitEachLabel函数并指定类的最小数量,


另一方面,当执行过采样时,每个类别的图像数量增加。这两种策略对于不平衡的数据集都是有效的。然而,过采样需要更复杂的过程。


本篇文章采用过采样平衡数据

       Label        Count
_____________    _____
caesar_salad       13 
caprese_salad       8 
french_fries       91 
greek_salad        12 
hamburger         119 
hot_dog            16 
pizza             150 
sashimi            20 
sushi              62 


过采样结果:

Label        Count
_____________    _____
caesar_salad      150 
caprese_salad     150 
french_fries      150 
greek_salad       150 
hamburger         150 
hot_dog           150 
pizza             150 
sashimi           150 
sushi             150 

二、数据情况

食品图像数据集包含九类食物的978张照片(ceaser_salad、caprese_salad,french_fries、greek_saland、汉堡包、hot_dog、披萨、生鱼片和寿司)。数据集可在下列地址下载

https://www.mathworks.com/supportfiles/nnet/data/ExampleFoodImageDataset.zip

本文为了提高运行速度 ,选择80%训练, 10%验证,10%测试。


三、代码展示

1.导入数据

imds = imageDatastore('ExampleFoodImageDataset');

2.图像数据展示

numExample=16;
idx = randperm(numel(imds.Files),numExample);
for i=1:numExample
    I=readimage(imds,idx(i));
    I_tile{i}=insertText(I,[1,1],string(imds.Labels(idx(i))),'FontSize',20);
end
% use imtile function to tile out the example images
I_tile = imtile(I_tile);
figure()
imshow(I_tile);title('examples of the dataset')

3.数据集划分 (训练80%,验证10%,测试10%)

[imdsTrain, imdsValid,imdsTest]=splitEachLabel(imds,0.8,0.1,0.1);

4.选取最大样本数

PerClass是所有类中的最大样本数。

PerClass = max(numObservations);

5.平衡数据

randReplicateFiles是一个仅对文件进行混洗的支持功能。

要选择的图像数量由PerClass定义。从数据库中找到不同类别的图像目录,然后随机复制对应的图像至对应的数量,以平衡类中的图像数量。

files = splitapply(@(x){randReplicateFiles(x,desiredNumObservationsPerClass)},imdsTrain.Files,G);

6.构建网络

加载预先训练的模型,ResNet-18

net = resnet18;
inputSize = net.Layers(1).InputSize;
lgraph = layerGraph(net);
learnableLayer='fc1000';
classLayer='ClassificationLayer_predictions';

7.图像增强

定义图像增强器
pixelRange = [-30 30];
RotationRange = [-30 30];
scaleRange = [0.8 1.2];
imageAugmenter = imageDataAugmenter( ...
    'RandXReflection',true, ...
    'RandXTranslation',pixelRange, ...
    'RandYTranslation',pixelRange, ...
    'RandXScale',scaleRange, ...
    'RandYScale',scaleRange, ...
    'RandRotation',RotationRange ...
    ); 

8.设置网络参数

miniBatchSize = 64;
valFrequency = max(floor(numel(augimdsTest.Files)/miniBatchSize)*10,1);
options = trainingOptions('sgdm', ...
    'MiniBatchSize',miniBatchSize, ...
    'MaxEpochs',5, ...%30
    'InitialLearnRate',1e-2, ...%3e-4
    'Shuffle','every-epoch', ...
    'ValidationData',augimdsValid, ...
    'ValidationFrequency',valFrequency, ...
    'Verbose',false, ...
    'Plots','training-progress');

9.训练网络

net = trainNetwork(augimdsTrain,lgraph,options);

10.分类评估

[YPred,probs] = classify(net,augimdsTest);
accuracy = mean(YPred == imdsTest.Labels)
YValidation = imdsTest.Labels;
YTrue=imdsTest.Labels;
figure;cm=confusionchart(YTrue,YPred);

%当我运行这个代码时,主要的错误分类是生鱼片和寿司,

%它们看起来很相似。请尝试使用此代码进行过度采样,并希望它对您的工作有所帮助。

四、运行效果

五、代码获取

后台私信回复“49期”,即可获取下载链接。

相关文章
|
2天前
|
算法
MATLAB|【免费】融合正余弦和柯西变异的麻雀优化算法SCSSA-CNN-BiLSTM双向长短期记忆网络预测模型
这段内容介绍了一个使用改进的麻雀搜索算法优化CNN-BiLSTM模型进行多输入单输出预测的程序。程序通过融合正余弦和柯西变异提升算法性能,主要优化学习率、正则化参数及BiLSTM的隐层神经元数量。它利用一段简单的风速数据进行演示,对比了改进算法与粒子群、灰狼算法的优化效果。代码包括数据导入、预处理和模型构建部分,并展示了优化前后的效果。建议使用高版本MATLAB运行。
|
4天前
|
机器学习/深度学习 人工智能 计算机视觉
深度学习之ResNet家族
ResNet是深度学习中的标志性架构,由何恺明在2016年提出,解决了深度网络训练的难题。ResNet通过残差块使得网络能有效学习,即使层数极深。后续发展包括ResNetV2,优化了信息传递和激活函数顺序;Wide Residual Networks侧重增加网络宽度而非深度;ResNeXt引入基数概念,通过多路径学习增强表示能力;Stochastic Depth通过随机丢弃层加速训练并提升泛化;DenseNet采用密集连接,增加信息交互;DPN结合ResNet和DenseNet优点;ResNeSt则综合了注意力机制、多路学习等。这些演变不断推动深度学习网络性能的提升。5月更文挑战第7天
29 7
|
4天前
|
存储 算法 数据可视化
基于harris角点和RANSAC算法的图像拼接matlab仿真
本文介绍了使用MATLAB2022a进行图像拼接的流程,涉及Harris角点检测和RANSAC算法。Harris角点检测寻找图像中局部曲率变化显著的点,RANSAC则用于排除噪声和异常点,找到最佳匹配。核心程序包括自定义的Harris角点计算函数,RANSAC参数设置,以及匹配点的可视化和仿射变换矩阵计算,最终生成全景图像。
|
4天前
|
算法 数据安全/隐私保护
matlab程序,傅里叶变换,频域数据,补零与不补零傅里叶变换
地震波格式转换、时程转换、峰值调整、规范反应谱、计算反应谱、计算持时、生成人工波、时频域转换、数据滤波、基线校正、Arias截波、傅里叶变换、耐震时程曲线、脉冲波合成与提取、三联反应谱、地震动参数、延性反应谱、地震波缩尺、功率谱密度
|
4天前
|
数据安全/隐私保护
matlab 曲线光滑,去毛刺,去离群值,数据滤波,高通滤波,低通滤波,带通滤波,带阻滤波
地震波格式转换、时程转换、峰值调整、规范反应谱、计算反应谱、计算持时、生成人工波、时频域转换、数据滤波、基线校正、Arias截波、傅里叶变换、耐震时程曲线、脉冲波合成与提取、三联反应谱、地震动参数、延性反应谱、地震波缩尺、功率谱密度
|
4天前
|
数据安全/隐私保护
时域与频域数据互相转换,傅里叶变换与逆傅里叶变换,matlab程序,时域转频域
地震波格式转换、时程转换、峰值调整、规范反应谱、计算反应谱、计算持时、生成人工波、时频域转换、数据滤波、基线校正、Arias截波、傅里叶变换、耐震时程曲线、脉冲波合成与提取、三联反应谱、地震动参数、延性反应谱、地震波缩尺、功率谱密度
|
4天前
|
机器学习/深度学习 并行计算 算法
MATLAB|【免费】概率神经网络的分类预测--基于PNN的变压器故障诊断
MATLAB|【免费】概率神经网络的分类预测--基于PNN的变压器故障诊断
|
4天前
|
机器学习/深度学习 编解码 监控
探索MATLAB在计算机视觉与深度学习领域的实战应用
探索MATLAB在计算机视觉与深度学习领域的实战应用
27 7
|
4天前
|
计算机视觉
MATLAB用Lasso回归拟合高维数据和交叉验证
MATLAB用Lasso回归拟合高维数据和交叉验证
|
4天前
|
机器学习/深度学习 存储 算法
m基于Yolov2深度学习网络的螺丝检测系统matlab仿真,带GUI界面
MATLAB 2022a中展示了YOLOv2算法的螺丝检测仿真结果,该系统基于深度学习的YOLOv2网络,有效检测和定位图像中的螺丝。YOLOv2通过批标准化、高分辨率分类器等优化实现速度和精度提升。核心代码部分涉及设置训练和测试数据,调整图像大小,加载预训练模型,构建YOLOv2网络并进行训练,最终保存检测器模型。
25 3