【MATLAB第49期】基于MATLAB的深度学习ResNet-18网络不平衡图像数据分类识别模型

简介: 【MATLAB第49期】基于MATLAB的深度学习ResNet-18网络不平衡图像数据分类识别模型

MATLAB第49期】基于MATLAB的深度学习ResNet-18网络不平衡图像数据分类识别模型


一、基本介绍


这篇文章展示了如何使用不平衡训练数据集对图像进行分类,其中每个类的图像数量在类之间不同。两种最流行的解决方案是down-sampling降采样和over-sampling过采样。


在降采样中,每个类别的图像数量减少到所有类别中的最小图像数量。降采样的实现很容易:只需使用splitEachLabel函数并指定类的最小数量,


另一方面,当执行过采样时,每个类别的图像数量增加。这两种策略对于不平衡的数据集都是有效的。然而,过采样需要更复杂的过程。


本篇文章采用过采样平衡数据

       Label        Count
_____________    _____
caesar_salad       13 
caprese_salad       8 
french_fries       91 
greek_salad        12 
hamburger         119 
hot_dog            16 
pizza             150 
sashimi            20 
sushi              62 


过采样结果:

Label        Count
_____________    _____
caesar_salad      150 
caprese_salad     150 
french_fries      150 
greek_salad       150 
hamburger         150 
hot_dog           150 
pizza             150 
sashimi           150 
sushi             150 

二、数据情况

食品图像数据集包含九类食物的978张照片(ceaser_salad、caprese_salad,french_fries、greek_saland、汉堡包、hot_dog、披萨、生鱼片和寿司)。数据集可在下列地址下载

https://www.mathworks.com/supportfiles/nnet/data/ExampleFoodImageDataset.zip

本文为了提高运行速度 ,选择80%训练, 10%验证,10%测试。


三、代码展示

1.导入数据

imds = imageDatastore('ExampleFoodImageDataset');

2.图像数据展示

numExample=16;
idx = randperm(numel(imds.Files),numExample);
for i=1:numExample
    I=readimage(imds,idx(i));
    I_tile{i}=insertText(I,[1,1],string(imds.Labels(idx(i))),'FontSize',20);
end
% use imtile function to tile out the example images
I_tile = imtile(I_tile);
figure()
imshow(I_tile);title('examples of the dataset')

3.数据集划分 (训练80%,验证10%,测试10%)

[imdsTrain, imdsValid,imdsTest]=splitEachLabel(imds,0.8,0.1,0.1);

4.选取最大样本数

PerClass是所有类中的最大样本数。

PerClass = max(numObservations);

5.平衡数据

randReplicateFiles是一个仅对文件进行混洗的支持功能。

要选择的图像数量由PerClass定义。从数据库中找到不同类别的图像目录,然后随机复制对应的图像至对应的数量,以平衡类中的图像数量。

files = splitapply(@(x){randReplicateFiles(x,desiredNumObservationsPerClass)},imdsTrain.Files,G);

6.构建网络

加载预先训练的模型,ResNet-18

net = resnet18;
inputSize = net.Layers(1).InputSize;
lgraph = layerGraph(net);
learnableLayer='fc1000';
classLayer='ClassificationLayer_predictions';

7.图像增强

定义图像增强器
pixelRange = [-30 30];
RotationRange = [-30 30];
scaleRange = [0.8 1.2];
imageAugmenter = imageDataAugmenter( ...
    'RandXReflection',true, ...
    'RandXTranslation',pixelRange, ...
    'RandYTranslation',pixelRange, ...
    'RandXScale',scaleRange, ...
    'RandYScale',scaleRange, ...
    'RandRotation',RotationRange ...
    ); 

8.设置网络参数

miniBatchSize = 64;
valFrequency = max(floor(numel(augimdsTest.Files)/miniBatchSize)*10,1);
options = trainingOptions('sgdm', ...
    'MiniBatchSize',miniBatchSize, ...
    'MaxEpochs',5, ...%30
    'InitialLearnRate',1e-2, ...%3e-4
    'Shuffle','every-epoch', ...
    'ValidationData',augimdsValid, ...
    'ValidationFrequency',valFrequency, ...
    'Verbose',false, ...
    'Plots','training-progress');

9.训练网络

net = trainNetwork(augimdsTrain,lgraph,options);

10.分类评估

[YPred,probs] = classify(net,augimdsTest);
accuracy = mean(YPred == imdsTest.Labels)
YValidation = imdsTest.Labels;
YTrue=imdsTest.Labels;
figure;cm=confusionchart(YTrue,YPred);

%当我运行这个代码时,主要的错误分类是生鱼片和寿司,

%它们看起来很相似。请尝试使用此代码进行过度采样,并希望它对您的工作有所帮助。

四、运行效果

五、代码获取

后台私信回复“49期”,即可获取下载链接。

相关文章
|
2月前
|
机器学习/深度学习 算法 机器人
【水下图像增强融合算法】基于融合的水下图像与视频增强研究(Matlab代码实现)
【水下图像增强融合算法】基于融合的水下图像与视频增强研究(Matlab代码实现)
245 0
|
3月前
|
机器学习/深度学习 编解码 并行计算
【改进引导滤波器】各向异性引导滤波器,利用加权平均来实现最大扩散,同时保持图像中的强边缘,实现强各向异性滤波,同时保持原始引导滤波器的低低计算成本(Matlab代码实现)
【改进引导滤波器】各向异性引导滤波器,利用加权平均来实现最大扩散,同时保持图像中的强边缘,实现强各向异性滤波,同时保持原始引导滤波器的低低计算成本(Matlab代码实现)
213 8
|
2月前
|
算法 定位技术 计算机视觉
【水下图像增强】基于波长补偿与去雾的水下图像增强研究(Matlab代码实现)
【水下图像增强】基于波长补偿与去雾的水下图像增强研究(Matlab代码实现)
121 0
|
2月前
|
算法 机器人 计算机视觉
【图像处理】水下图像增强的颜色平衡与融合技术研究(Matlab代码实现)
【图像处理】水下图像增强的颜色平衡与融合技术研究(Matlab代码实现)
109 0
|
2月前
|
机器学习/深度学习 算法 自动驾驶
基于导向滤波的暗通道去雾算法在灰度与彩色图像可见度复原中的研究(Matlab代码实现)
基于导向滤波的暗通道去雾算法在灰度与彩色图像可见度复原中的研究(Matlab代码实现)
172 8
|
2月前
|
机器学习/深度学习 数据采集 人工智能
深度学习实战指南:从神经网络基础到模型优化的完整攻略
🌟 蒋星熠Jaxonic,AI探索者。深耕深度学习,从神经网络到Transformer,用代码践行智能革命。分享实战经验,助你构建CV、NLP模型,共赴二进制星辰大海。
|
3月前
|
机器学习/深度学习 传感器 算法
【无人车路径跟踪】基于神经网络的数据驱动迭代学习控制(ILC)算法,用于具有未知模型和重复任务的非线性单输入单输出(SISO)离散时间系统的无人车的路径跟踪(Matlab代码实现)
【无人车路径跟踪】基于神经网络的数据驱动迭代学习控制(ILC)算法,用于具有未知模型和重复任务的非线性单输入单输出(SISO)离散时间系统的无人车的路径跟踪(Matlab代码实现)
230 2
|
3月前
|
机器学习/深度学习 算法 安全
【图像处理】使用四树分割和直方图移动的可逆图像数据隐藏(Matlab代码实现)
【图像处理】使用四树分割和直方图移动的可逆图像数据隐藏(Matlab代码实现)
178 2
|
3月前
|
机器学习/深度学习 并行计算 算法
【CPOBP-NSWOA】基于豪冠猪优化BP神经网络模型的多目标鲸鱼寻优算法研究(Matlab代码实现)
【CPOBP-NSWOA】基于豪冠猪优化BP神经网络模型的多目标鲸鱼寻优算法研究(Matlab代码实现)
|
11月前
|
机器学习/深度学习 运维 安全
深度学习在安全事件检测中的应用:守护数字世界的利器
深度学习在安全事件检测中的应用:守护数字世界的利器
432 22