DL之Attention:基于ClutteredMNIST手写数字图片数据集分别利用CNN_Init、ST_CNN算法(CNN+SpatialTransformer)实现多分类预测(二)

简介: DL之Attention:基于ClutteredMNIST手写数字图片数据集分别利用CNN_Init、ST_CNN算法(CNN+SpatialTransformer)实现多分类预测

image.png

image.png


核心代码

   #(2)、建立ST定位网络:尝试更多的conv层,并分别在X轴和y轴上做最大池化

   # localization net. TODO: try more conv layers, and do max pooling on X- and Y-axes respectively

   locnet = Sequential()

   # locnet.add(MaxPooling2D(pool_size=(2,2), input_shape=input_shape))

   # locnet.add(Convolution2D(32, (5, 5)))

   locnet.add(Convolution2D(32, (5, 5), input_shape=input_shape))

   locnet.add(Activation('relu'))

   # locnet.add(Dropout(0.2)) # 0.2

   locnet.add(MaxPooling2D(pool_size=(2,2)))

   locnet.add(Convolution2D(64, (5, 5)))

   locnet.add(Activation('relu'))

   # locnet.add(Dropout(0.2)) # 0.3

   locnet.add(Convolution2D(64, (3, 3)))

   locnet.add(Activation('relu'))

   locnet.add(MaxPooling2D(pool_size=(2,2)))

 

   locnet.add(Flatten())

   locnet.add(Dense(50))

   locnet.add(Activation('relu'))

   locnet.add(Dense(6, weights=weights))

   print(locnet.summary())

 

 

   #(3)、建立CNN网络

   model = Sequential()

   model.add(SpatialTransformer(localization_net=locnet,

                                output_size=(30,30), input_shape=input_shape))

   # model.add(Convolution2D(32, (3, 3), padding='same'))

   # model.add(Activation('relu'))

   # model.add(MaxPooling2D(pool_size=(2, 2)))

   # model.add(Convolution2D(64, (3, 3)))

   # model.add(Activation('relu'))

   # model.add(MaxPooling2D(pool_size=(2, 2)))

   # model.add(Dropout(0.5)) # 0.25

 

   # E: removed first 3 dropout layers

   model.add(Conv2D(32, kernel_size=(3, 3), activation='relu'))

   model.add(Dropout(0.5)) # 0.5

   model.add(Conv2D(64, (3, 3), activation='relu'))

   model.add(Dropout(0.5)) # 0.5

   model.add(MaxPooling2D(pool_size=(2, 2)))

   model.add(Conv2D(64, kernel_size=(3, 3),

                    activation='relu'))

   model.add(Dropout(0.5)) # 0.5

   model.add(MaxPooling2D(pool_size=(2, 2)))

   # model.add(Conv2D(64, (3, 3), activation='relu'))

   # model.add(Dropout(0.5))

   model.add(Flatten())

   model.add(Dense(256)) # 256

   model.add(Dropout(0.5)) # 0.5

   model.add(Activation('relu'))

   model.add(Dense(nb_classes))

   model.add(Activation('softmax'))


相关文章
|
25天前
|
机器学习/深度学习 算法 数据安全/隐私保护
基于贝叶斯优化CNN-LSTM网络的数据分类识别算法matlab仿真
本项目展示了基于贝叶斯优化(BO)的CNN-LSTM网络在数据分类中的应用。通过MATLAB 2022a实现,优化前后效果对比明显。核心代码附带中文注释和操作视频,涵盖BO、CNN、LSTM理论,特别是BO优化CNN-LSTM网络的batchsize和学习率,显著提升模型性能。
|
1月前
|
算法 搜索推荐 Java
java 后端 使用 Graphics2D 制作海报,画echarts图,带工具类,各种细节:如头像切割成圆形,文字换行算法(完美实验success),解决画上文字、图片后不清晰问题
这篇文章介绍了如何使用Java后端技术,结合Graphics2D和Echarts等工具,生成包含个性化信息和图表的海报,并提供了详细的代码实现和GitHub项目链接。
105 0
java 后端 使用 Graphics2D 制作海报,画echarts图,带工具类,各种细节:如头像切割成圆形,文字换行算法(完美实验success),解决画上文字、图片后不清晰问题
|
2月前
|
机器学习/深度学习 人工智能 算法
【新闻文本分类识别系统】Python+卷积神经网络算法+人工智能+深度学习+计算机毕设项目+Django网页界面平台
文本分类识别系统。本系统使用Python作为主要开发语言,首先收集了10种中文文本数据集("体育类", "财经类", "房产类", "家居类", "教育类", "科技类", "时尚类", "时政类", "游戏类", "娱乐类"),然后基于TensorFlow搭建CNN卷积神经网络算法模型。通过对数据集进行多轮迭代训练,最后得到一个识别精度较高的模型,并保存为本地的h5格式。然后使用Django开发Web网页端操作界面,实现用户上传一段文本识别其所属的类别。
91 1
【新闻文本分类识别系统】Python+卷积神经网络算法+人工智能+深度学习+计算机毕设项目+Django网页界面平台
|
1月前
|
存储 缓存 分布式计算
数据结构与算法学习一:学习前的准备,数据结构的分类,数据结构与算法的关系,实际编程中遇到的问题,几个经典算法问题
这篇文章是关于数据结构与算法的学习指南,涵盖了数据结构的分类、数据结构与算法的关系、实际编程中遇到的问题以及几个经典的算法面试题。
30 0
数据结构与算法学习一:学习前的准备,数据结构的分类,数据结构与算法的关系,实际编程中遇到的问题,几个经典算法问题
|
29天前
|
机器学习/深度学习 算法 数据安全/隐私保护
基于贝叶斯优化卷积神经网络(Bayes-CNN)的多因子数据分类识别算法matlab仿真
本项目展示了贝叶斯优化在CNN中的应用,包括优化过程、训练与识别效果对比,以及标准CNN的识别结果。使用Matlab2022a开发,提供完整代码及视频教程。贝叶斯优化通过构建代理模型指导超参数优化,显著提升模型性能,适用于复杂数据分类任务。
|
1月前
|
移动开发 算法 前端开发
前端常用算法全解:特征梳理、复杂度比较、分类解读与示例展示
前端常用算法全解:特征梳理、复杂度比较、分类解读与示例展示
21 0
|
1月前
|
算法 Java Linux
java制作海报一:java使用Graphics2D 在图片上写字,文字换行算法详解
这篇文章介绍了如何在Java中使用Graphics2D在图片上绘制文字,并实现自动换行的功能。
98 0
|
2月前
|
机器学习/深度学习 算法 数据挖掘
决策树算法大揭秘:Python让你秒懂分支逻辑,精准分类不再难
【9月更文挑战第12天】决策树算法作为机器学习领域的一颗明珠,凭借其直观易懂和强大的解释能力,在分类与回归任务中表现出色。相比传统统计方法,决策树通过简单的分支逻辑实现了数据的精准分类。本文将借助Python和scikit-learn库,以鸢尾花数据集为例,展示如何使用决策树进行分类,并探讨其优势与局限。通过构建一系列条件判断,决策树不仅模拟了人类决策过程,还确保了结果的可追溯性和可解释性。无论您是新手还是专家,都能轻松上手,享受机器学习的乐趣。
48 9
|
3月前
|
数据采集 机器学习/深度学习 算法
【python】python客户信息审计风险决策树算法分类预测(源码+数据集+论文)【独一无二】
【python】python客户信息审计风险决策树算法分类预测(源码+数据集+论文)【独一无二】
|
3月前
|
算法 5G Windows
OFDM系统中的信号检测算法分类和详解
参考文献 [1]周健, 张冬. MIMO-OFDM系统中的信号检测算法(I)[J]. 南京工程学院学报(自然科学版), 2010. [2]王华龙.MIMO-OFDM系统传统信号检测算法[J].科技创新与应用,2016(23):63.
77 4

热门文章

最新文章