实战图像softmax分类模型

本文涉及的产品
模型训练 PAI-DLC,5000CU*H 3个月
交互式建模 PAI-DSW,每月250计算时 3个月
模型在线服务 PAI-EAS,A10/V100等 500元 1个月
简介: 本文是学习softmax图像分类模型的总结,主要分享softmax图像分类模型的技术原理,以及用代码实现验证,供大家参考。

本文是学习softmax图像分类模型的总结,主要分享softmax图像分类模型的技术原理,以及用代码实现验证,供大家参考。
一、图像分类问题
在日常生活中,分类问题很常见,比如下图中的动物是猫,而不是狗。人是比较很容易知道,但是要计算机知道这是猫,就需要我们训练一个图像分类模型,输入这张图片,识别结果为猫。
image.png
二、问题分析
1、任务建模
我们的目标就是训练一个图像分类模型,输入一张图片,输出一个类别。
首先先介绍一下one-hot编码,one-hot编码时一个向量,向量长度和类别一样多, 类别对应的位置设置为1,其他所有位置设置为0。比如我们需要分类的总类别数为3(即猫、狗和鸭),那么标签y=[1,0,0]表示猫,y=[0,1,0]表示狗,y=[0,0,1]鸭。
我们用线性回归模型来实现图像分类问题,那么整个任务可以拆解为如下流程:
image.png
在整个流程中,主要与前期线性回归模型不同的地方有三处:
(1)输入是一张图片,需要把图片转为一维行向量,然后作为输入。
(2)线性回归模式是一个多输出模型,即一个样本输入,输出有多个(输出个数与类别总数相等);
(3)需要把多个输出转换为对应的标签类别。
下面重点说明如何把多个输出转换为对应的标签类别,比如还是之前(猫、狗、鸭)分类问题,假设一个样本经过线性回归模型之后,得到三个输出分别为Out(1)=2,Out(2)=4,Out(3)=6,则输出向量为(2,4,6)。因为我们使用的是one-hot编码,每个类别真实标签向量的分量都是0-1之间的数值,为使输出标签向量的值变换到0-1之间,在分类问题中常用softmax函数来进行处理:
image.png
上述输出向量为(2,4,6)经过softmax变换之后,得到的输出向量为(0.0159, 0.1173, 0.8668),该向量表示图片是猫的概率为0.0159,是狗的概率为0.1173,是鸭的概率为0.8668,我们取向量中的最大值作为分类结果,即输出向量(0.0159, 0.1173, 0.8668)的分类结果为鸭。
2、损失函数
在线性回归模型中,我们用均方误差作为损失函数,但是在分类问题中,一般使用交叉熵来作为损失函数,交叉熵函数用来衡量两个概率的区别,其定义如下:
image.png
由上述分类任务建模分析可知,预测值和真实值表示某个类别的概率,所以每个样本预测值与真实值之间的损失函数为:
image.png
因为真实值y中,只有一个分量为1,其他都为0,上述损失函数可以化简为
image.png
比如还是上述例子:假设真实值y=(0,0,1),预测值(0.0159, 0.1173, 0.8668),
image.png
因此,在训练多输出线性回归模型时,我们希望寻找一组参数(W,b),使得L(W,b)在所有训练样本上的损失均值越小越好。
3、模型评估
在分类问题中,我们希望整个模型的分类准确度越高越好,分类准确度为正确预测数量与总预测数量之比。
三、代码验证
整个代码验证过程包括如下主要流程:
image.png
1、获取数据
我们选取Fashion‐MNIST图像分类数据集来进行验证,Fashion‐MNIST由10个类别[分别为t‐shirt(T恤)、trouser(裤子)、pullover(套衫)、dress(连衣裙)、coat(外套)、sandal(凉鞋)、shirt(衬衫)、sneaker(运动鞋)、bag(包)和ankle boot(短靴))]的图像组成,每个类别由训练数据集(train dataset)中的6000张图像和测试数据集(test dataset)中的1000张图像组成。因此,训练集和测试集分别包含60000和10000张图像。
每张图像为灰度图像,通道数为1,,图像的高度和宽度均为28像素。
image.png
我们可以查看train_iter和test_iter中的数据。
image.png
2、定义模型
由前面的分析可知,整个分类模型分为两个层,首先要把图像转为一维向量,然后在输入到线性回归模型中。Fashion‐MNIST数据集中的每个样本都是28×28的图像,将其展平转换为784的向量。类别为10,所以模型输出维度为10。
net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))
然后初始化模型参数
image.png
3、定义损失函数
在pytorch中有已定义好的交叉熵损失函数可以直接使用。
image.png
4、定义优化算法
我们采用随机梯度下降法,来迭代更新权重参数,可直接使用pytorch中已定义好的函数。
image.png
5、定义分类准确度
分类准确度为正确预测数量与总预测数量之比。
image.png
image.png
image.png
6、训练
image.png
运行得到结果
image.png
7、预测
将训练得到的模型在测试集进行预测推理
image.png
结果如下所示:
image.png
至此,softmax分类模型完毕。

参考资料

1、《动手学深度学习》第二版,地址: zh.d2l.ai/index.html

目录
相关文章
|
人工智能 数据可视化 数据处理
快速在 PaddleLabel 标注的花朵分类数据集上展示如何应用 PaddleX 训练 MobileNetV3_ssld 网络
快速在 PaddleLabel 标注的花朵分类数据集上展示如何应用 PaddleX 训练 MobileNetV3_ssld 网络
778 0
快速在 PaddleLabel 标注的花朵分类数据集上展示如何应用 PaddleX 训练 MobileNetV3_ssld 网络
|
6月前
|
机器学习/深度学习 存储 数据可视化
MambaOut:状态空间模型并不适合图像的分类任务
该论文研究了Mamba架构(含状态空间模型SSM)在视觉任务(图像分类、目标检测、语义分割)中的必要性。实验表明,Mamba在这些任务中效果不如传统卷积和注意力模型。论文提出,SSM更适合长序列和自回归任务,而非视觉任务。MambaOut(不带SSM的门控CNN块)在图像分类上优于视觉Mamba,但在检测和分割任务中略逊一筹,暗示SSM在这类任务中可能仍有价值。研究还探讨了Mamba在处理长序列任务时的效率和局部信息整合能力。尽管整体表现一般,但论文为优化不同视觉任务的模型架构提供了新视角。
109 2
|
6月前
|
机器学习/深度学习 算法 TensorFlow
【视频】神经网络正则化方法防过拟合和R语言CNN分类手写数字图像数据MNIST|数据分享
【视频】神经网络正则化方法防过拟合和R语言CNN分类手写数字图像数据MNIST|数据分享
|
6月前
|
SQL 数据可视化 数据挖掘
R语言线性分类判别LDA和二次分类判别QDA实例
R语言线性分类判别LDA和二次分类判别QDA实例
|
6月前
|
机器学习/深度学习 数据采集 PyTorch
PyTorch使用神经网络进行手写数字识别实战(附源码,包括损失图像和准确率图像)
PyTorch使用神经网络进行手写数字识别实战(附源码,包括损失图像和准确率图像)
149 0
|
机器学习/深度学习 算法 数据可视化
基于线性SVM的CIFAR-10图像集分类
基于线性SVM的CIFAR-10图像集分类
746 0
基于线性SVM的CIFAR-10图像集分类
|
机器学习/深度学习 文字识别 监控
使用 HOG 功能和多类 SVM 分类器对数字进行分类
使用 HOG 功能和多类 SVM 分类器对数字进行分类。
142 0
|
机器学习/深度学习 人工智能 数据可视化
【Pytorch神经网络实战案例】22 基于Cora数据集实现图注意力神经网络GAT的论文分类
有一个记录论文信息的数据集,数据集里面含有每一篇论文的关键词以及分类信息,同时还有论文间互相引用的信息。搭建AI模型,对数据集中的论文信息进行分析,使模型学习已有论文的分类特征,以便预测出未知分类的论文类别。
470 0
|
机器学习/深度学习 人工智能 PyTorch
【Pytorch神经网络理论篇】 30 图片分类模型:Inception模型
原始的Inception模型采用多分支结构(见图1-1),它将1×1卷积、3×3卷积最大池化堆叠在一起。这种结构既可以增加网络的宽度,又可以增强网络对不同尺寸的适应性。
367 0
|
机器学习/深度学习 人工智能 算法
【Pytorch神经网络理论篇】 31 图片分类模型:ResNet模型+DenseNet模型+EffcientNet模型
在深度学习领域中,模型越深意味着拟合能力越强,出现过拟合问题是正常的,训练误差越来越大却是不正常的。
361 0

热门文章

最新文章