CNN-样本不平衡

简介:

卷积神经网络(CNN)可以说是目前处理图像最有力的工具了。而在机器学习分类问题中,样本不平衡又是一个经常遇到的问题。最近在使用CNN进行图片分类时,发现CNN对训练集样本不平衡问题很敏感。在网上搜索了一下,发现[这篇文章
](http%3A//www.diva-portal.org/smash/get/diva2%3A811111/FULLTEXT01.pdf)对这个问题已经做了比较细致的探索。于是就把它简单整理了一下,相关的记录如下。一、实验数据与使用的网络
所谓样本不平衡,就是指在分类问题中,每一类对应的样本的个数不同,而且差别较大。这样的不平衡的样本往往使机器学习算法的表现变得比较差。那么在CNN中又有什么样的影响呢?作者选用了CIFAR-10作为数据源来生成不平衡的样本数据。
CIFAR-10是一个简单的图像分类数据集。共有10类(airplane,automobile,bird,cat,deer,dog, frog,horse,ship,truck),每一类含有5000张训练图片,1000张测试图片。
CIFAR-10样例如图:
训练时,选择的网络是[这里*
](http%3A//code.google.com/p/cuda-convnet/)的CIFAR-10训练网络和参数(来自Alex Krizhevsky)。这个网络含有3个卷积层,还有10个输出结点。
之所以不选用效果更好的CNN网络,是因为我们的目的是在实验时训练很多次进行比较,而不是获得多么好的性能。而这个CNN网络因为比较浅,训练速度比较快,比较符合我们的要求。
二、类别不平衡数据的生成
直接从原始CIFAR-10采样,通过控制每一类采样的个数,就可以产生类别不平衡的训练数据。如下表所示:这里的每一行就表示“一份”训练数据。而每个数字就表示这个类别占这“一份”训练数据的百分比。
Dist. 1:类别平衡,每一类都占用10%的数据。
Dist. 2、Dist. 3:一部分类别的数据比另一部分多。
Dist. 4、Dist 5:只有一类数据比较多。
Dist. 6、Dist 7:只有一类数据比较少。
Dist. 8: 数据个数呈线性分布。
Dist. 9:数据个数呈指数级分布。
Dist. 10、Dist. 11:交通工具对应的类别中的样本数都比动物的多
对每一份训练数据都进行训练,测试时用的测试集还是每类1000个的原始测试集,保持不变。
三、类别不平衡数据的训练结果
以上数据经过训练后,每一类对应的预测正确率如下:
第一列Total表示总的正确率,下面是每一类分别的正确率。
从实验结果中可以看出:
类别完全平衡时,结果最好。
类别“越不平衡”,效果越差。比如Dist. 3就比Dist. 2更不平衡,效果就更差。同样的对比还有Dist. 4和Dist. 5,Dist. 8和Dist. 9。其中Dist. 5和Dist. 9更是完全训练失败了。

四、过采样训练的结果
作者还实验了“过采样”(oversampling)这种平衡数据集的方法。这里的过采样方法是:对每一份数据集中比较少的类,直接复制其中的图片增大样本数量直至所有类别平衡。
再次训练,进行测试,结果为:
可以发现过采样的效果非常好,基本与平衡时候的表现一样了。
过采样前后效果对比,可以发现过采样效果非常好:
五、总结
CNN确实对训练样本中类别不平衡的问题很敏感。平衡的类别往往能获得最佳的表现,而不平衡的类别往往使模型的效果下降。如果训练样本不平衡,可以使用过采样平衡样本之后再训练。
这确实是一个“经验主义”的结论,但多少给我们平常训练CNN模型带来一些启发和帮助。

目录
相关文章
|
定位技术
ArcGIS地形起伏度+地形粗糙度+地表切割深度+高程变异系数提取
ArcGIS地形起伏度+地形粗糙度+地表切割深度+高程变异系数提取
16869 0
|
SQL 数据采集 关系型数据库
大数据采集和抽取怎么做?这篇文章终于说明白了!
数据是数据中台\数据平台核心中的核心,因此数据汇聚必然是数据中台/平台的入口,本文详细讲述采集模块的方方面面、采集框架的使用选型以及企业真实落地
大数据采集和抽取怎么做?这篇文章终于说明白了!
|
11月前
|
前端开发 Java 数据库连接
Java后端开发-使用springboot进行Mybatis连接数据库步骤
本文介绍了使用Java和IDEA进行数据库操作的详细步骤,涵盖从数据库准备到测试类编写及运行的全过程。主要内容包括: 1. **数据库准备**:创建数据库和表。 2. **查询数据库**:验证数据库是否可用。 3. **IDEA代码配置**:构建实体类并配置数据库连接。 4. **测试类编写**:编写并运行测试类以确保一切正常。
539 2
|
敏捷开发 测试技术 uml
UML 在敏捷开发中的应用与实践
【8月更文第23天】统一建模语言 (UML) 是一种广泛使用的图形化语言,用于描述软件系统的设计。它通过各种图表和符号来帮助开发团队理解系统的架构、行为和交互。而敏捷开发则是一种强调快速迭代、客户反馈和持续改进的软件开发方法论。这两种看似风格迥异的方法实际上可以很好地协同工作,以提高软件项目的效率和质量。
390 4
|
机器学习/深度学习 数据采集 数据可视化
使用Python实现深度学习模型:智能交通信号优化
使用Python实现深度学习模型:智能交通信号优化
630 9
|
存储 消息中间件 分布式计算
分布式系统详解--基础知识(概论)
分布式系统详解--基础知识(概论)
484 0
|
数据可视化
rpm 的降级安装命令是什么?
【6月更文挑战第13天】rpm 的降级安装命令是什么?
816 2
|
机器学习/深度学习 人工智能 自然语言处理
OpenAI 推出 GPT-4o,免费向所有人提供GPT-4级别的AI ,可以实时对音频、视觉和文本进行推理,附使用详细指南
GPT-4o不仅提供与GPT-4同等程度的模型能力,推理速度还更快,还能提供同时理解文本、图像、音频等内容的多模态能力,无论你是付费用户,还是免费用户,都能通过它体验GPT-4了
912 1
|
数据处理
InVEST模型的下载及入门操作(以InVEST3.13.0为例)
InVEST是一套免费的开源软件模型,是美国自然资本项目组开发的、用于评估生态系统服务功能量及其经济价值、支持生态系统管理和决策的一套模型系统,用于绘制和评估维持和实现人类生活的自然商品和服务。包括商品生产(如食物)、生命维持过程(如水净化)和充实生命的条件(如美丽、娱乐机会)以及选择的保护(如未来使用的遗传多样性)等模块。(翻译自模型官网)
2907 1