深度残差收缩网络(6):代码实现

简介:

深度残差收缩网络其实是一种通用的特征学习方法,是深度残差网络ResNet、注意力机制和软阈值化的集成,可以用于图像分类。本文采用TensorFlow 1.0和TFLearn 0.3.2,编写了图像分类的程序,采用的图像数据为CIFAR-10。CIFAR-10是一个非常常用的图像数据集,包含10个类别的图像。可以在这个网址找到具体介绍:https://www.cs.toronto.edu/~kriz/cifar.html

2

参照ResNet代码(https://github.com/tflearn/tflearn/blob/master/examples/images/residual_network_cifar10.py),所编写的深度残差收缩网络的代码如下:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Dec 23 21:23:09 2019
 
M. Zhao, S. Zhong, X. Fu, B. Tang, M. Pecht, Deep Residual Shrinkage Networks for Fault Diagnosis,
IEEE Transactions on Industrial Informatics, 2019, DOI: 10.1109/TII.2019.2943898
 
@author: super_9527
"""
  
from __future__ import division, print_function, absolute_import
  
import tflearn
import numpy as np
import tensorflow as tf
from tflearn.layers.conv import conv_2d
  
# Data loading
from tflearn.datasets import cifar10
(X, Y), (testX, testY) = cifar10.load_data()
  
# Add noise
X = X + np.random.random((50000, 32, 32, 3))*0.1
testX = testX + np.random.random((10000, 32, 32, 3))*0.1
  
# Transform labels to one-hot format
Y = tflearn.data_utils.to_categorical(Y,10)
testY = tflearn.data_utils.to_categorical(testY,10)
  
def residual_shrinkage_block(incoming, nb_blocks, out_channels, downsample=False,
                   downsample_strides=2, activation='relu', batch_norm=True,
                   bias=True, weights_init='variance_scaling',
                   bias_init='zeros', regularizer='L2', weight_decay=0.0001,
                   trainable=True, restore=True, reuse=False, scope=None,
                   name="ResidualBlock"):
      
    # residual shrinkage blocks with channel-wise thresholds
  
    residual = incoming
    in_channels = incoming.get_shape().as_list()[-1]
  
    # Variable Scope fix for older TF
    try:
        vscope = tf.variable_scope(scope, default_name=name, values=[incoming],
                                   reuse=reuse)
    except Exception:
        vscope = tf.variable_op_scope([incoming], scope, name, reuse=reuse)
  
    with vscope as scope:
        name = scope.name #TODO
  
        for i in range(nb_blocks):
  
            identity = residual
  
            if not downsample:
                downsample_strides = 1
  
            if batch_norm:
                residual = tflearn.batch_normalization(residual)
            residual = tflearn.activation(residual, activation)
            residual = conv_2d(residual, out_channels, 3,
                             downsample_strides, 'same', 'linear',
                             bias, weights_init, bias_init,
                             regularizer, weight_decay, trainable,
                             restore)
  
            if batch_norm:
                residual = tflearn.batch_normalization(residual)
            residual = tflearn.activation(residual, activation)
            residual = conv_2d(residual, out_channels, 3, 1, 'same',
                             'linear', bias, weights_init,
                             bias_init, regularizer, weight_decay,
                             trainable, restore)
              
            # get thresholds and apply thresholding
            abs_mean = tf.reduce_mean(tf.reduce_mean(tf.abs(residual),axis=2,keep_dims=True),axis=1,keep_dims=True)
            scales = tflearn.fully_connected(abs_mean, out_channels//4, activation='linear',regularizer='L2',weight_decay=0.0001,weights_init='variance_scaling')
            scales = tflearn.batch_normalization(scales)
            scales = tflearn.activation(scales, 'relu')
            scales = tflearn.fully_connected(scales, out_channels, activation='linear',regularizer='L2',weight_decay=0.0001,weights_init='variance_scaling')
            scales = tf.expand_dims(tf.expand_dims(scales,axis=1),axis=1)
            thres = tf.multiply(abs_mean,tflearn.activations.sigmoid(scales))
            # soft thresholding
            residual = tf.multiply(tf.sign(residual), tf.maximum(tf.abs(residual)-thres,0))
              
  
            # Downsampling
            if downsample_strides > 1:
                identity = tflearn.avg_pool_2d(identity, 1,
                                               downsample_strides)
  
            # Projection to new dimension
            if in_channels != out_channels:
                if (out_channels - in_channels) % 2 == 0:
                    ch = (out_channels - in_channels)//2
                    identity = tf.pad(identity,
                                      [[0, 0], [0, 0], [0, 0], [ch, ch]])
                else:
                    ch = (out_channels - in_channels)//2
                    identity = tf.pad(identity,
                                      [[0, 0], [0, 0], [0, 0], [ch, ch+1]])
                in_channels = out_channels
  
            residual = residual + identity
  
    return residual
  
  
# Real-time data preprocessing
img_prep = tflearn.ImagePreprocessing()
img_prep.add_featurewise_zero_center(per_channel=True)
  
# Real-time data augmentation
img_aug = tflearn.ImageAugmentation()
img_aug.add_random_flip_leftright()
img_aug.add_random_crop([32, 32], padding=4)
  
# Building Deep Residual Shrinkage Network
net = tflearn.input_data(shape=[None, 32, 32, 3],
                         data_preprocessing=img_prep,
                         data_augmentation=img_aug)
net = tflearn.conv_2d(net, 16, 3, regularizer='L2', weight_decay=0.0001)
net = residual_shrinkage_block(net, 1, 16)
net = residual_shrinkage_block(net, 1, 32, downsample=True)
net = residual_shrinkage_block(net, 1, 32, downsample=True)
net = tflearn.batch_normalization(net)
net = tflearn.activation(net, 'relu')
net = tflearn.global_avg_pool(net)
# Regression
net = tflearn.fully_connected(net, 10, activation='softmax')
mom = tflearn.Momentum(0.1, lr_decay=0.1, decay_step=20000, staircase=True)
net = tflearn.regression(net, optimizer=mom, loss='categorical_crossentropy')
# Training
model = tflearn.DNN(net, checkpoint_path='model_cifar10',
                    max_checkpoints=10, tensorboard_verbose=0,
                    clip_gradients=0.)
  
model.fit(X, Y, n_epoch=100, snapshot_epoch=False, snapshot_step=500,
          show_metric=True, batch_size=100, shuffle=True, run_id='model_cifar10')
  
training_acc = model.evaluate(X, Y)[0]
validation_acc = model.evaluate(testX, testY)[0]

上面的代码构建了一个小型的深度残差收缩网络,只含有3个基本残差收缩模块,其他的超参数也未进行优化。如果为了追求更高的准确率的话,可以适当增加深度,增加训练迭代次数,以及适当调整超参数。

转载网址:
深度残差收缩网络:(一)背景知识 https://www.cnblogs.com/yc-9527/p/11598844.html
深度残差收缩网络:(二)整体思路 https://www.cnblogs.com/yc-9527/p/11601322.html
深度残差收缩网络:(三)网络结构 https://www.cnblogs.com/yc-9527/p/11603320.html
深度残差收缩网络:(四)注意力机制下的阈值设置 https://www.cnblogs.com/yc-9527/p/11604082.html
深度残差收缩网络:(五)实验验证 https://www.cnblogs.com/yc-9527/p/11610073.html
深度残差收缩网络:(六)代码实现 https://www.cnblogs.com/yc-9527/p/12091581.html
论文网址:
M. Zhao, S. Zhong, X. Fu, B. Tang, and M. Pecht, “Deep Residual Shrinkage Networks for Fault Diagnosis,” IEEE Transactions on Industrial Informatics, 2019, DOI: 10.1109/TII.2019.2943898
https://ieeexplore.ieee.org/document/8850096

相关文章
用MASM32按Time Protocol(RFC868)协议编写网络对时程序中的一些有用的函数代码
用MASM32按Time Protocol(RFC868)协议编写网络对时程序中的一些有用的函数代码
|
2月前
|
机器学习/深度学习 存储 算法
回声状态网络(Echo State Networks,ESN)详细原理讲解及Python代码实现
本文详细介绍了回声状态网络(Echo State Networks, ESN)的基本概念、优点、缺点、储层计算范式,并提供了ESN的Python代码实现,包括不考虑和考虑超参数的两种ESN实现方式,以及使用ESN进行时间序列预测的示例。
80 4
回声状态网络(Echo State Networks,ESN)详细原理讲解及Python代码实现
|
5天前
|
安全 C#
某网络硬盘网站被植入传播Trojan.DL.Inject.xz等的代码
某网络硬盘网站被植入传播Trojan.DL.Inject.xz等的代码
完成切换网络+修改网络连接图标提示的代码框架
完成切换网络+修改网络连接图标提示的代码框架
|
2月前
|
安全 网络安全 开发者
探索Python中的装饰器:简化代码,增强功能网络安全与信息安全:从漏洞到防护
【8月更文挑战第30天】本文通过深入浅出的方式介绍了Python中装饰器的概念、用法和高级应用。我们将从基础的装饰器定义开始,逐步深入到如何利用装饰器来改进代码结构,最后探讨其在Web框架中的应用。适合有一定Python基础的开发者阅读,旨在帮助读者更好地理解并运用装饰器来优化他们的代码。
|
2月前
|
数据采集 量子技术 双11
【2023 年第十三届 MathorCup 高校数学建模挑战赛】C 题 电商物流网络包裹应急调运与结构优化问题 建模方案及代码实现
本文提供了2023年第十三届MathorCup高校数学建模挑战赛C题的详细建模方案及代码实现,针对电商物流网络中的包裹应急调运与结构优化问题,提出了包括时间序列分析在内的多种数学模型,并探讨了物流网络的鲁棒性。
36 2
【2023 年第十三届 MathorCup 高校数学建模挑战赛】C 题 电商物流网络包裹应急调运与结构优化问题 建模方案及代码实现
|
2月前
|
机器学习/深度学习 PyTorch 算法框架/工具
PyTorch代码实现神经网络
这段代码示例展示了如何在PyTorch中构建一个基础的卷积神经网络(CNN)。该网络包括两个卷积层,分别用于提取图像特征,每个卷积层后跟一个池化层以降低空间维度;之后是三个全连接层,用于分类输出。此结构适用于图像识别任务,并可根据具体应用调整参数与层数。
|
2月前
|
达摩院 供应链 JavaScript
网络流问题--仓储物流调度【数学规划的应用(含代码)】阿里达摩院MindOpt
本文通过使用MindOpt工具优化仓储物流调度问题,旨在提高物流效率并降低成本。首先,通过考虑供需匹配、运输时间与距离、车辆容量、仓库储存能力等因素构建案例场景。接着,利用数学规划方法,包括线性规划和网络流问题,来建立模型。在网络流问题中,通过定义节点(资源)和边(资源间的关系),确保流量守恒和容量限制条件下找到最优解。文中还详细介绍了MindOpt Studio云建模平台和MindOpt APL建模语言的应用,并通过实例展示了如何声明集合、参数、变量、目标函数及约束条件,并最终解析了求解结果。通过这些步骤,实现了在满足各仓库需求的同时最小化运输成本的目标。
|
2月前
|
开发者 图形学 API
从零起步,深度揭秘:运用Unity引擎及网络编程技术,一步步搭建属于你的实时多人在线对战游戏平台——详尽指南与实战代码解析,带你轻松掌握网络化游戏开发的核心要领与最佳实践路径
【8月更文挑战第31天】构建实时多人对战平台是技术与创意的结合。本文使用成熟的Unity游戏开发引擎,从零开始指导读者搭建简单的实时对战平台。内容涵盖网络架构设计、Unity网络API应用及客户端与服务器通信。首先,创建新项目并选择适合多人游戏的模板,使用推荐的网络传输层。接着,定义基本玩法,如2D多人射击游戏,创建角色预制件并添加Rigidbody2D组件。然后,引入网络身份组件以同步对象状态。通过示例代码展示玩家控制逻辑,包括移动和发射子弹功能。最后,设置服务器端逻辑,处理客户端连接和断开。本文帮助读者掌握构建Unity多人对战平台的核心知识,为进一步开发打下基础。
66 0
|
2月前
|
安全 开发者 数据安全/隐私保护
Xamarin 的安全性考虑与最佳实践:从数据加密到网络防护,全面解析构建安全移动应用的六大核心技术要点与实战代码示例
【8月更文挑战第31天】Xamarin 的安全性考虑与最佳实践对于构建安全可靠的跨平台移动应用至关重要。本文探讨了 Xamarin 开发中的关键安全因素,如数据加密、网络通信安全、权限管理等,并提供了 AES 加密算法的代码示例。
36 0
下一篇
无影云桌面