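✅ About the author: a MATLAB simulation developer with a passion for research, refining mindset and technical skills together; message me privately for MATLAB project collaboration.
🍎 Homepage: Matlab科研工作室
🍊 Motto: the investigation of things extends knowledge.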
⛄ Introduction
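In image search applications, searching by image content is more intuitive and faster than established approaches that search by time, location, or keywords. Using image classification to enable content-based image search has therefore become an increasingly active research area. Broadly, classification algorithms in machine learning fall into two groups. The first group relies on hand-crafted features, for example k-nearest neighbors (k-NN) and support vector machines (SVMs). The second group consists of deep learning methods built on convolutional neural networks (CNNs), which learn features directly from the data instead of requiring them to be designed by hand.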
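The first group can be illustrated with a toy example. This sketch is not from the original post and rests on several assumptions:

- The "images" are synthetic 8x8 RGB arrays.
- The hand-crafted feature is each image's mean colour per channel.
- A query is labelled with the class of its nearest training sample, that is, k-NN with k = 1.

It uses plain MATLAB only.

```matlab
% Toy hand-crafted-feature classifier (1-nearest neighbour).
% Assumptions: synthetic 8x8 RGB images; feature = mean colour per channel.
rng(0);                                          % reproducible noise
makeImg = @(rgb) min(max(repmat(reshape(rgb,1,1,3),8,8) + 0.05*randn(8,8,3), 0), 1);

% Training set: two reddish and two yellowish images with known labels
trainImgs   = {makeImg([0.9 0.1 0.1]), makeImg([0.8 0.2 0.1]), ...
               makeImg([0.9 0.9 0.1]), makeImg([0.8 0.8 0.2])};
trainLabels = {'rose', 'rose', 'sunflower', 'sunflower'};

% Hand-crafted feature: mean value of each colour channel, as a 1x3 row
featureOf = @(img) squeeze(mean(mean(img, 1), 2))';
trainFeat = cell2mat(cellfun(featureOf, trainImgs, 'UniformOutput', false)');  % 4x3

% Label a new image with the class of its closest training feature
queryFeat = featureOf(makeImg([0.85 0.15 0.1]));
dists = sqrt(sum((trainFeat - queryFeat).^2, 2));   % Euclidean distance to each sample
[~, nearest] = min(dists);
predictedLabel = trainLabels{nearest};
fprintf('Predicted label: %s\n', predictedLabel);
```

The feature here is designed by hand. A CNN such as the one below learns its features from the pixels instead.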
⛄ Code (excerpt)
%% Flower Classifier using a CNN and data augmentation
%% PART 1: Baseline Classifier
%% Create image data store
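% One subfolder per class under 'Flowers'; the folder names become the labels
imds = imageDatastore('Flowers', ...
    'IncludeSubfolders',true,'FileExtensions','.jpg','LabelSource','foldernames');
% Resize every image as it is read so it matches the 227x227 network input
% used below (assumption: the raw photos are RGB but not already 227x227)
imds.ReadFcn = @(f) imresize(imread(f),[227 227]);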
% Count number of images per label and save the number of classes
labelCount = countEachLabel(imds);
numClasses = height(labelCount);
%% Create training and validation sets
[imdsTrainingSet, imdsValidationSet] = splitEachLabel(imds, 0.7, 'randomize');
%% Build a simple CNN
imageSize = [227 227 3];
% Specify the convolutional neural network architecture.
layers = [
    imageInputLayer(imageSize)
    convolution2dLayer(3,8,'Padding','same')
    batchNormalizationLayer
    reluLayer
    maxPooling2dLayer(2,'Stride',2)
    convolution2dLayer(3,16,'Padding','same')
    batchNormalizationLayer
    reluLayer
    maxPooling2dLayer(2,'Stride',2)
    convolution2dLayer(3,32,'Padding','same')
    batchNormalizationLayer
    reluLayer
    fullyConnectedLayer(numClasses)   % one output per flower class
    softmaxLayer
    classificationLayer];
%% Specify training options
options = trainingOptions('sgdm', ...
    'MiniBatchSize',10, ...
    'MaxEpochs',6, ...
    'InitialLearnRate',1e-4, ...
    'Shuffle','every-epoch', ...
    'ValidationData',imdsValidationSet, ...
    'ValidationFrequency',3, ...
    'Verbose',false, ...
    'Plots','training-progress');
%% Train the network
net1 = trainNetwork(imdsTrainingSet,layers,options);
%% Report accuracy of baseline classifier on validation set
YPred = classify(net1,imdsValidationSet);
YValidation = imdsValidationSet.Labels;
imdsAccuracy = sum(YPred == YValidation)/numel(YValidation);
%% Plot confusion matrix
figure, plotconfusion(YValidation,YPred)
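Working out the feature-map sizes by hand shows why the final fully connected layer holds most of the baseline model's parameters. This sketch assumes:

- a 227x227x3 input;
- 3x3 convolutions with `'same'` padding and stride 1;
- 2x2 max pooling with stride 2 and no padding (the `maxPooling2dLayer` default);
- four flower classes.

The variable names are illustrative only.

```matlab
% Feature-map size arithmetic for the baseline CNN (illustrative sketch).
inputSize = [227 227];
poolOut = @(hw) floor((hw - 2) / 2) + 1;   % 2x2 max pooling, stride 2, no padding

s1 = inputSize;          % conv 3x3, 8 filters, 'same' padding keeps 227x227
s2 = poolOut(s1);        % after the first pooling layer
s3 = poolOut(s2);        % after the second pooling layer (the third conv keeps this size)
numFeatures = prod(s3) * 32;              % 32 channels leave the last convolution
numClassesExample = 4;                    % assumption: four flower classes
fcParams = numFeatures * numClassesExample + numClassesExample;   % weights + biases

fprintf('After pooling: %dx%d, then %dx%d\n', s2(1), s2(2), s3(1), s3(2));
fprintf('Final feature map: %dx%dx32 = %d values\n', s3(1), s3(2), numFeatures);
fprintf('Fully connected layer parameters: %d\n', fcParams);
```

By comparison, the three convolutional layers together hold only a few thousand weights.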
%% PART 2: Baseline Classifier with Data Augmentation
%% Create augmented image data store
% Specify data augmentation options and values/ranges
imageAugmenter = imageDataAugmenter( ...
    'RandRotation',[-20,20], ...
    'RandXTranslation',[-5 5], ...
    'RandYTranslation',[-5 5]);
% Apply transformations (using randomly picked values) and build augmented
% data store
augImds = augmentedImageDatastore(imageSize,imdsTrainingSet, ...
    'DataAugmentation',imageAugmenter);
% (OPTIONAL) Preview augmentation results
batchedData = preview(augImds);
figure, imshow(imtile(batchedData.input))
%% Train the network.
net2 = trainNetwork(augImds,layers,options);
%% Report accuracy of baseline classifier with image data augmentation
YPred = classify(net2,imdsValidationSet);
YValidation = imdsValidationSet.Labels;
augImdsAccuracy = sum(YPred == YValidation)/numel(YValidation);
%% Plot confusion matrix
figure, plotconfusion(YValidation,YPred)
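To see what one random translation does to a training image, it can be sketched by hand. This is a simplified, hypothetical illustration, not how `imageDataAugmenter` is implemented. It draws integer offsets from [-5, 5] pixels, shifts the image, and fills the uncovered border with zeros. The random rotation is left out.

```matlab
% Hypothetical sketch of random-translation augmentation with zero fill.
img = reshape(1:36, 6, 6);                 % small test "image" (all pixels nonzero)
maxShift = 5;
dx = randi([-maxShift maxShift]);          % horizontal shift, positive = right
dy = randi([-maxShift maxShift]);          % vertical shift, positive = down

[h, w, ~] = size(img);
shifted = zeros(size(img), class(img));    % exposed border stays 0
srcRows = max(1, 1 - dy):min(h, h - dy);   % source rows that land inside the frame
srcCols = max(1, 1 - dx):min(w, w - dx);   % source columns that land inside the frame
shifted(srcRows + dy, srcCols + dx, :) = img(srcRows, srcCols, :);
```

Each epoch draws fresh offsets, so the network rarely sees exactly the same pixels twice. That is the intended regularizing effect of Part 2.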
%% PART 3: Transfer Learning without Data Augmentation
%% Load pretrained AlexNet
net = alexnet;
%% Replace final layers
layersTransfer = net.Layers(1:end-3);
layers = [
    layersTransfer
    fullyConnectedLayer(numClasses,'WeightLearnRateFactor',20,'BiasLearnRateFactor',20)
    softmaxLayer
    classificationLayer];
%% Train network
options = trainingOptions('sgdm', ...
    'MiniBatchSize',10, ...
    'MaxEpochs',6, ...
    'InitialLearnRate',1e-4, ...
    'Shuffle','every-epoch', ...
    'ValidationData',imdsValidationSet, ...
    'ValidationFrequency',3, ...
    'Verbose',false, ...
    'Plots','training-progress');
netTransfer1 = trainNetwork(imdsTrainingSet,layers,options);
%% Compute accuracy and plot confusion matrix
YPred = classify(netTransfer1,imdsValidationSet);
YValidation = imdsValidationSet.Labels;
netTransfer1BaselineAccuracy = sum(YPred == YValidation)/numel(YValidation);
figure, plotconfusion(YValidation,YPred)
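In the transfer-learning layers above, `WeightLearnRateFactor` and `BiasLearnRateFactor` multiply the global `InitialLearnRate` for the new fully connected layer only. As a worked example, assuming the transferred AlexNet layers keep a learning-rate factor of 1:

$$
\alpha_{\text{new FC}} = 20 \times 10^{-4} = 2 \times 10^{-3}, \qquad \alpha_{\text{transferred}} = 1 \times 10^{-4}
$$

With `'sgdm'`, each parameter vector $\theta$ is updated using its layer's rate $\alpha$, the loss $E$ and the momentum $\gamma$ (0.9 by default in `trainingOptions`). Here $\ell$ is the iteration number:

$$
\theta_{\ell+1} = \theta_\ell - \alpha \nabla E(\theta_\ell) + \gamma (\theta_\ell - \theta_{\ell-1})
$$

As a result, the new classification layer takes gradient steps 20 times larger than the pretrained layers. It can adapt to the flower classes quickly while the transferred AlexNet features change only slightly.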
%% PART 4: Transfer Learning with Data Augmentation
%% Train network
netTransfer2 = trainNetwork(augImds,layers,options);
%% Compute accuracy and plot confusion matrix
YPred = classify(netTransfer2,imdsValidationSet);
YValidation = imdsValidationSet.Labels;
netTransfer2BaselineAccuracy = sum(YPred == YValidation)/numel(YValidation);
figure, plotconfusion(YValidation,YPred)
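Every part scores its network with the same two steps: an accuracy ratio and a confusion matrix. The sketch below reproduces both on made-up toy labels using only base MATLAB; the label vectors are invented for illustration. In `counts`, rows are true classes and columns are predicted classes. `plotconfusion` draws predicted (output) classes on its rows and true (target) classes on its columns, so its grid is the transpose of `counts`.

```matlab
% Accuracy and confusion counts from categorical labels (toy data, illustration only).
YTrue    = categorical({'daisy'; 'daisy'; 'rose'; 'rose'; 'tulip'; 'tulip'});
YPredToy = categorical({'daisy'; 'rose';  'rose'; 'rose'; 'tulip'; 'daisy'});

classNames = categories(YTrue);
accuracy = mean(YPredToy == YTrue);        % equivalent to sum(...)/numel(...)

% counts(i,j): number of samples of true class i that were predicted as class j
[~, trueIdx] = ismember(cellstr(YTrue), classNames);
[~, predIdx] = ismember(cellstr(YPredToy), classNames);
n = numel(classNames);
counts = accumarray([trueIdx predIdx], 1, [n n]);

fprintf('Accuracy: %.4f\n', accuracy);
disp(counts);
```

The diagonal of `counts` holds the correct predictions, so `trace(counts)/sum(counts(:))` equals the accuracy.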
⛄ Results