Densenet-Tensorflow


While looking for a DenseNet implementation, I came across network code with a clear and complete structure, so I am keeping a copy here.

https://github.com/taki0112/Densenet-Tensorflow

Densenet-Tensorflow

TensorFlow implementation of DenseNet using CIFAR-10 and MNIST

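  • The code that implements the DenseNet paper (Densely Connected Convolutional Networks) is Densenet.py
  • There is one slight difference from the paper: I used AdamOptimizer instead of SGD with Nesterov momentum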

If you want to see the original author's code or other implementations, please refer to this link

 

Requirements

  • Tensorflow 1.x
  • Python 3.x
  • tflearn (optional: install it if you want its ready-made global average pooling layer. However, I implemented global average pooling with tf.layers, so you don't need it)

Issue

  • I used tf.contrib.layers.batch_norm
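    import tensorflow as tf
    from tensorflow.contrib.framework import arg_scope
    from tensorflow.contrib.layers import batch_norm

    def Batch_Normalization(x, training, scope):
        # training: a boolean tensor that switches between batch statistics
        # (training) and moving averages (inference)
        with arg_scope([batch_norm],
                       scope=scope,
                       updates_collections=None,  # update the moving averages in place
                       decay=0.9,
                       center=True,
                       scale=True,
                       zero_debias_moving_mean=True) :
            # The first branch creates the variables, the second branch reuses them
            return tf.cond(training,
                           lambda : batch_norm(inputs=x, is_training=training, reuse=None),
                           lambda : batch_norm(inputs=x, is_training=training, reuse=True))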

 

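  • If you do not have enough GPU memory, edit the code as follows:
    with tf.Session() as sess :  # NO
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess :  # OK
    # allow_soft_placement=True lets TensorFlow place an op on another device (e.g. the CPU)
    # when it cannot be placed on the requested one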
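The Batch_Normalization wrapper above switches behaviour through `training`: during training, batch_norm normalizes with the statistics of the current batch and updates its moving averages with `decay=0.9`; at inference it normalizes with those moving averages instead. The plain-Python sketch below mirrors that logic for a single feature. It is an illustration, not a bit-exact reimplementation of tf.contrib: the class name `RunningBatchNorm` is hypothetical, `epsilon=0.001` is assumed to match the tf.contrib default, gamma and beta are fixed at their initial values (1 and 0), and the zero-debiasing enabled by `zero_debias_moving_mean=True` is omitted.

```python
import math


class RunningBatchNorm:
    """Hypothetical single-feature batch norm mirroring the train/inference switch.

    Simplifications (assumptions): gamma=1 and beta=0 are fixed, the batch
    variance is the population variance, and zero-debiasing is omitted.
    """

    def __init__(self, decay=0.9, epsilon=0.001):
        self.decay = decay
        self.epsilon = epsilon
        self.moving_mean = 0.0  # moving averages start at mean 0, variance 1
        self.moving_var = 1.0

    def __call__(self, batch, training):
        if training:
            # Normalize with the statistics of the current batch ...
            mean = sum(batch) / len(batch)
            var = sum((v - mean) ** 2 for v in batch) / len(batch)
            # ... and update the moving averages in place (like updates_collections=None)
            self.moving_mean = self.decay * self.moving_mean + (1 - self.decay) * mean
            self.moving_var = self.decay * self.moving_var + (1 - self.decay) * var
        else:
            # Inference: reuse the moving averages collected during training
            mean, var = self.moving_mean, self.moving_var
        return [(v - mean) / math.sqrt(var + self.epsilon) for v in batch]


bn = RunningBatchNorm()
train_out = bn([1.0, 2.0, 3.0, 4.0], training=True)
print(round(bn.moving_mean, 4), round(bn.moving_var, 4))  # 0.25 1.025
test_out = bn([2.5], training=False)  # a single example is fine at inference
```

The inference branch matters because normalizing a single example with its own statistics would always produce 0; the moving averages give a stable estimate instead.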

Idea

What is "Global Average Pooling"?

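    import tensorflow as tf

    def Global_Average_Pooling(x, stride=1) :
        # x is an NHWC tensor: [batch, height, width, channels]
        height = x.get_shape().as_list()[1]
        width = x.get_shape().as_list()[2]
        pool_size = [height, width]  # one window covering the whole feature map
        # The stride value does not matter: the window already spans the whole map
        return tf.layers.average_pooling2d(inputs=x, pool_size=pool_size, strides=stride)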
If you use tflearn, please refer to this link
    import tflearn

    def Global_Average_Pooling(x):
        return tflearn.layers.conv.global_avg_pool(x, name='Global_avg_pooling')
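Global average pooling collapses each channel's H×W feature map to a single number: its mean over all spatial positions. The sketch below shows that arithmetic in plain Python on nested lists (an assumption made so it runs without TensorFlow); `global_average_pool` is a hypothetical name, not a function from Densenet.py.

```python
def global_average_pool(feature_map):
    """Average each channel over all spatial positions.

    feature_map: nested lists shaped [height][width][channels] (one image,
    NHWC without the batch axis). Returns one mean per channel.
    """
    height = len(feature_map)
    width = len(feature_map[0])
    channels = len(feature_map[0][0])
    sums = [0.0] * channels
    for row in feature_map:
        for pixel in row:
            for c in range(channels):
                sums[c] += pixel[c]
    return [s / (height * width) for s in sums]


# A 2x2 map with 2 channels: channel 0 holds 1, 2, 3, 4 and channel 1 holds 10, 20, 30, 40
fmap = [[[1, 10], [2, 20]],
        [[3, 30], [4, 40]]]
print(global_average_pool(fmap))  # [2.5, 25.0]
```

In the TensorFlow version the result keeps the shape [batch, 1, 1, channels], so it is typically flattened before the final fully connected layer.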

 

What is "Dense Connectivity"?

[Figure: Dense connectivity]
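In a dense block, each layer takes the feature maps of the block input and all preceding layers as its input, so a block with L layers has L(L+1)/2 direct connections, as stated in the DenseNet paper. The sketch below enumerates those connections in plain Python; `dense_connections` is a hypothetical helper, and index 0 stands for the block input.

```python
def dense_connections(num_layers):
    """List the (source, target) pairs in a dense block.

    Index 0 is the block input; indices 1..num_layers are the layers.
    Layer l takes the outputs of every index 0..l-1 as its input.
    """
    edges = []
    for target in range(1, num_layers + 1):
        for source in range(target):
            edges.append((source, target))
    return edges


edges = dense_connections(4)
print(len(edges))  # 10, i.e. 4 * 5 // 2
print([s for s, t in edges if t == 3])  # [0, 1, 2]
```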

What is the "DenseNet Architecture"?

[Figure: DenseNet architecture]

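    # conv_layer, Max_Pooling, Batch_Normalization, Relu, Global_Average_Pooling and
    # Linear are helper functions defined in Densenet.py
    def Dense_net(self, input_x):
        # Stem: 7x7 convolution (stride 2, 2 * growth rate filters), then 3x3 max pooling
        x = conv_layer(input_x, filter=2 * self.filters, kernel=[7,7], stride=2, layer_name='conv0')
        x = Max_Pooling(x, pool_size=[3,3], stride=2)

        # Four dense blocks with 6, 12, 48 and 32 layers (the DenseNet-201 layout),
        # separated by transition layers
        x = self.dense_block(input_x=x, nb_layers=6, layer_name='dense_1')
        x = self.transition_layer(x, scope='trans_1')

        x = self.dense_block(input_x=x, nb_layers=12, layer_name='dense_2')
        x = self.transition_layer(x, scope='trans_2')

        x = self.dense_block(input_x=x, nb_layers=48, layer_name='dense_3')
        x = self.transition_layer(x, scope='trans_3')

        x = self.dense_block(input_x=x, nb_layers=32, layer_name='dense_final')

        # Classification head: BN -> ReLU -> global average pooling -> fully connected layer
        x = Batch_Normalization(x, training=self.training, scope='linear_batch')
        x = Relu(x)
        x = Global_Average_Pooling(x)
        x = Linear(x)

        return x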
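The channel count through Dense_net can be traced by hand: a dense block with n layers adds n × k channels to its input (k = self.filters, the growth rate), and each transition layer in this code projects back down to k channels with its 1×1 convolution. The sketch below does that bookkeeping in plain Python. It assumes each dense block returns the concatenation of its input and all bottleneck outputs, as in the DenseNet paper; `trace_channels` is a hypothetical name, and the growth rate of 32 in the usage line is an illustrative choice, not a value taken from Densenet.py.

```python
def trace_channels(growth_rate, block_layers=(6, 12, 48, 32)):
    """Return (stage name, output channels) pairs for the Dense_net layout.

    Assumes each dense block outputs the concatenation of its input and every
    bottleneck output, and that each transition layer outputs growth_rate channels.
    """
    channels = 2 * growth_rate  # conv0 uses 2 * self.filters filters
    trace = [("conv0", channels)]
    last = len(block_layers) - 1
    for i, n_layers in enumerate(block_layers):
        channels += n_layers * growth_rate  # each bottleneck layer adds k channels
        name = "dense_final" if i == last else f"dense_{i + 1}"
        trace.append((name, channels))
        if i < last:  # no transition layer after the final block
            channels = growth_rate  # transition_layer's 1x1 conv uses self.filters filters
            trace.append((f"trans_{i + 1}", channels))
    return trace


for name, c in trace_channels(32):
    print(name, c)  # last line: dense_final 1056
```

For comparison, the paper's compressed variants (DenseNet-C and DenseNet-BC) let each transition layer keep floor(θ·m) of its m input channels, with θ = 0.5 in the paper's experiments, so their counts differ from the fixed k used here.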

 

What is the "Dense Block"?

[Figure: Dense block]

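    def dense_block(self, input_x, nb_layers, layer_name):
        with tf.name_scope(layer_name):
            # Keep every feature map produced in this block, starting with the block input
            layers_concat = list()
            layers_concat.append(input_x)

            x = self.bottleneck_layer(input_x, scope=layer_name + '_bottleN_' + str(0))

            layers_concat.append(x)

            for i in range(nb_layers - 1):
                # Each new layer sees the concatenation of all previous feature maps
                x = Concatenation(layers_concat)
                x = self.bottleneck_layer(x, scope=layer_name + '_bottleN_' + str(i + 1))
                layers_concat.append(x)

            # The block output is the concatenation of the input and all layer outputs
            x = Concatenation(layers_concat)

            return x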
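To see how the concatenations make the channel count grow inside one block, the dense_block loop can be mimicked in plain Python by representing each feature map only by its number of channels, so concatenation becomes a sum. This is a sketch under that simplifying assumption; `dense_block_channels` is a hypothetical name, and each bottleneck layer is assumed to output k channels, as bottleneck_layer does with `filter=self.filters`.

```python
def dense_block_channels(input_channels, nb_layers, growth_rate):
    """Mimic dense_block with channel counts instead of tensors.

    Returns (inputs_seen, output_channels): the channel count fed to each
    bottleneck layer and the channel count of the final concatenation.
    """
    layers_concat = [input_channels]  # the block input is always kept
    inputs_seen = []
    for _ in range(nb_layers):
        x = sum(layers_concat)  # concatenation along the channel axis adds channels
        inputs_seen.append(x)
        layers_concat.append(growth_rate)  # each bottleneck outputs k channels
    return inputs_seen, sum(layers_concat)


seen, out = dense_block_channels(input_channels=64, nb_layers=6, growth_rate=32)
print(seen)  # [64, 96, 128, 160, 192, 224]
print(out)   # 256
```

The l-th layer therefore receives k0 + k × (l − 1) channels, where k0 is the block's input channel count, which is the growth pattern described in the DenseNet paper.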

 

What is the "Bottleneck Layer"?

    def bottleneck_layer(self, x, scope):
        # BN -> ReLU -> 1x1 conv (4k filters) -> BN -> ReLU -> 3x3 conv (k filters);
        # dropout_rate is a global setting defined elsewhere in Densenet.py
        with tf.name_scope(scope):
            x = Batch_Normalization(x, training=self.training, scope=scope+'_batch1')
            x = Relu(x)
            x = conv_layer(x, filter=4 * self.filters, kernel=[1,1], layer_name=scope+'_conv1')
            x = Drop_out(x, rate=dropout_rate, training=self.training)

            x = Batch_Normalization(x, training=self.training, scope=scope+'_batch2')
            x = Relu(x)
            x = conv_layer(x, filter=self.filters, kernel=[3,3], layer_name=scope+'_conv2')
            x = Drop_out(x, rate=dropout_rate, training=self.training)

            return x
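The 1×1 convolution with 4k filters reduces the number of input feature maps before the more expensive 3×3 convolution. A rough way to see the saving is to count convolution weights (ignoring biases and batch normalization parameters) for the bottleneck versus a single 3×3 convolution applied directly to the concatenated input. The function names below are hypothetical, and the input widths and k = 32 are illustrative values.

```python
def conv_weights(in_channels, out_channels, kernel_size):
    """Number of weights in a 2-D convolution with a square kernel (no bias)."""
    return in_channels * out_channels * kernel_size * kernel_size


def bottleneck_weights(in_channels, growth_rate):
    """1x1 conv to 4k channels followed by a 3x3 conv to k channels."""
    return (conv_weights(in_channels, 4 * growth_rate, 1)
            + conv_weights(4 * growth_rate, growth_rate, 3))


def plain_weights(in_channels, growth_rate):
    """A single 3x3 conv from the concatenated input straight to k channels."""
    return conv_weights(in_channels, growth_rate, 3)


k = 32
for c_in in (256, 1024):
    print(c_in, bottleneck_weights(c_in, k), plain_weights(c_in, k))
# 256 69632 73728
# 1024 167936 294912
```

The saving grows with the input width: at 256 input channels the two are close, but at 1024 the bottleneck needs roughly 57% of the weights of the plain 3×3 convolution.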

 

What is the "Transition Layer"?

    def transition_layer(self, x, scope):
        # BN -> ReLU -> 1x1 conv (self.filters filters) -> dropout -> 2x2 average pooling,
        # which halves the height and width between dense blocks
        with tf.name_scope(scope):
            x = Batch_Normalization(x, training=self.training, scope=scope+'_batch1')
            x = Relu(x)
            x = conv_layer(x, filter=self.filters, kernel=[1,1], layer_name=scope+'_conv1')
            x = Drop_out(x, rate=dropout_rate, training=self.training)
            x = Average_pooling(x, pool_size=[2,2], stride=2)

            return x
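The 2×2 average pooling with stride 2 is what halves the spatial resolution between dense blocks. A plain-Python sketch of that pooling on a single-channel map follows; `avg_pool_2x2` is a hypothetical helper, and it assumes even height and width so that every window is complete.

```python
def avg_pool_2x2(feature_map):
    """2x2 average pooling with stride 2 on a single-channel [height][width] map.

    Assumes height and width are even, so every 2x2 window is complete.
    """
    pooled = []
    for i in range(0, len(feature_map), 2):
        row = []
        for j in range(0, len(feature_map[0]), 2):
            window = (feature_map[i][j] + feature_map[i][j + 1]
                      + feature_map[i + 1][j] + feature_map[i + 1][j + 1])
            row.append(window / 4)
        pooled.append(row)
    return pooled


fmap = [[1, 2, 3, 4],
        [5, 6, 7, 8],
        [9, 10, 11, 12],
        [13, 14, 15, 16]]
print(avg_pool_2x2(fmap))  # [[3.5, 5.5], [11.5, 13.5]]
```

A 4×4 input becomes 2×2, so, for example, a 56×56 feature map leaves the transition layer at 28×28.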

 

Structure comparison (CNN, ResNet, DenseNet)

[Figure: Structure comparison of CNN, ResNet and DenseNet]

Results

  • (MNIST) The highest test accuracy is 99.2% (this result does not use dropout)
  • The number of layers in each dense block is fixed to 4
    for i in range(self.nb_blocks) :
        # original : 6 -> 12 -> 48

        x = self.dense_block(input_x=x, nb_layers=4, layer_name='dense_'+str(i))
        x = self.transition_layer(x, scope='trans_'+str(i))

 

CIFAR-10

[Figure: CIFAR-10 results]

CIFAR-100

[Figure: CIFAR-100 results]

ImageNet

[Figure: ImageNet results]

Author

Junho Kim
