[Highway]论文实现:Highway Networks

简介: [Highway]论文实现:Highway Networks

论文:Highway Networks

作者:Rupesh Kumar Srivastava, Klaus Greff, Jürgen Schmidhuber

时间:2015

有大量的理论和经验证据表明,神经网络的深度是其成功的关键因素。然而,随着深度的增加,网络训练变得更加困难,而深度网络的训练仍然是一个开放的问题。在这个扩展的摘要中,论文引入了一种新的体系结构,旨在简化对非常深的网络的基于梯度的训练。论文将具有这种架构的网络称为Highway Networks,因为它们允许信息不受阻碍地流过Highway Networks上的几层。该架构的特点是使用门控单元,以学习调节通过网络的信息流。具有数百层的Highway Networks可以通过随机梯度下降和各种激活函数进行直接训练,为研究极其深层和高效的体系结构提供了可能性。

一、完整代码

这里笔者从实现的角度瞎搭了一个模型,对mnist数据集进行训练,由于是瞎整的架构,不用管效果如何!

import tensorflow as tf
import pandas as pd
# 定义highway
class HighWay(tf.keras.layers.Layer):
    def __init__(self, kernel_size):
        super(HighWay, self).__init__()
        self.kernel_size = kernel_size
    def build(self, input_shape):
        self.h = tf.keras.layers.Conv2D(input_shape[-1], self.kernel_size, padding='same', activation='relu')
        self.t = tf.keras.layers.Conv2D(input_shape[-1], self.kernel_size, padding='same', activation='sigmoid')
    def call(self, inputs):
        h = self.h(inputs)
        t = self.t(inputs)
        return h*t + inputs*(1-t)
# 准备数据
mnist = tf.keras.datasets.mnist
(train_x,train_y),(test_x,test_y) = mnist.load_data()
# 建立模型
model = tf.keras.models.Sequential([
    tf.keras.layers.Input(shape=(28,28)),
    tf.keras.layers.Reshape((28,28,1)),
    HighWay(2),
    tf.keras.layers.MaxPool2D(),
    HighWay(2),
    tf.keras.layers.MaxPool2D(),
    HighWay(2),
    tf.keras.layers.MaxPool2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(300, activation='relu'),
    tf.keras.layers.Dropout(rate=0.5),
    tf.keras.layers.Dense(200, activation='relu'),
    tf.keras.layers.Dropout(rate=0.5),
    tf.keras.layers.Dense(10, activation='softmax'),
])
model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    optimizer='adam',
    metrics=['accuracy']
)
history = model.fit(x=train_x, y=train_y, epochs=10, verbose=1)
# 画图
pd.DataFrame(history.history).plot()

结果如下:

二、论文解读

    由于链式法则,对于传统网络,随着训练层数的递增,每一层所含有的信息会减少,为了避免出现这样的情况, 可以使用Highway Networks解决这一情况,注意:这里解决的不是梯度爆炸或者梯度消失的问题;

2.1 模型架构

深度神经网络的每一层基本可以看做一个非线性转化过程,非常多的这种非线性转化可以表达一些非常复杂的函数,这是神经网络的本质;我们现在对某一层进行分析:

y=H(x,WH)

这里xinputs 输入层, y是 outputs 是输出层, WH 是参数,这个函数表明inputs WH 进行一次非线性转换得到  outputs;这是传统的模型层结构;


可以从这个结构中看出,经过一次非线性转化后, x的信息发生了改变,变成了  y,为了降低信息损失,Highway采用了一种方式:即在 outputs 中加上 inputs, 公式如下:

y=H(x,WH)T(x,WT)+xC(x,WC).

其中  H,  TC表示的是三个非线性函数,W C W_C WC表示的是各个函数中的参数;值得注意的是:这里去掉  T,  C这两个函数即为ResNet模型


其中 T被称为transform gateC被称为carry gate,为了方便起见,这里令 C=1T;模型变为:

y = y=H(x,WH)T(x,WT)+x(1T(x,WC)).

从模型中可以看出,要保证这个等式成立,我们需要满足xy,  H,  T的维度尺寸相等,即此层结构不会影响输出结构;


同时,令 T的非线性为sigmoid,这样能映射到 [01]这个区间内,保证 T,  C都产生一个大于0的数;

2.2 模型效果

三、过程实现

这里以卷积网络为例子,生成Highway Layer的代码如下:

class HighWay(tf.keras.layers.Layer):
    def __init__(self, kernel_size):
        super(HighWay, self).__init__()
        self.kernel_size = kernel_size
    def build(self, input_shape):
        self.h = tf.keras.layers.Conv2D(input_shape[-1], self.kernel_size, padding='same', activation='relu')
        self.t = tf.keras.layers.Conv2D(input_shape[-1], self.kernel_size, padding='same', activation='sigmoid')
    def call(self, inputs):
        h = self.h(inputs)
        t = self.t(inputs)
        return h*t + inputs*(1-t)

四、整体总结

没什么好总结的,感觉ResNet是抄袭Highway;


目录
相关文章
|
NoSQL 关系型数据库 Go
更新Navicat Premium 16.2 之 如何使用Navicat连接Redis的新手教程《更新Navicat Premium 16.2并连接Redis:高效管理数据库和键值存储》
更新Navicat Premium 16.2 之 如何使用Navicat连接Redis的新手教程《更新Navicat Premium 16.2并连接Redis:高效管理数据库和键值存储》
1777 0
更新Navicat Premium 16.2 之 如何使用Navicat连接Redis的新手教程《更新Navicat Premium 16.2并连接Redis:高效管理数据库和键值存储》
|
网络协议 Dubbo Java
一文搞懂NIO、AIO、BIO的核心区别(建议收藏)
本文详细解析了NIO、AIO、BIO的核心区别,NIO的三个核心概念,以及NIO在Java框架中的应用等。关注【mikechen的互联网架构】,10年+BAT架构经验倾囊相授。
一文搞懂NIO、AIO、BIO的核心区别(建议收藏)
|
存储 算法 PyTorch
pytorch 给定概率分布的张量,如何利用这个概率进行重复\不重复采样?
在 PyTorch 中,可以使用 torch.distributions.Categorical 来基于给定的概率分布进行采样。
1555 0
|
自然语言处理 算法 物联网
如何训练一个大模型:LoRA篇
如何训练一个大模型:LoRA篇
3402 1
|
前端开发 Java
Java HotSpot(TM) 64-Bit Server VM warning
Java HotSpot(TM) 64-Bit Server VM warning
5232 1
|
机器学习/深度学习 运维 搜索推荐
机器学习中准确率、精确率、召回率、误报率、漏报率、F1-Score、AP&mAP、AUC、MAE、MAPE、MSE、RMSE、R-Squared等指标的定义和说明
在机器学习和深度学习用于异常检测(Anomaly detection)、电子商务(E-commerce)、信息检索(Information retrieval, IR)等领域任务(Task)中,有很多的指标来判断机器学习和深度学习效果的好坏。这些指标有相互权衡的,有相互背向的,所以往往需要根据实际的任务和场景来选择衡量指标。本篇博文对这些指标进行一个梳理。
机器学习中准确率、精确率、召回率、误报率、漏报率、F1-Score、AP&mAP、AUC、MAE、MAPE、MSE、RMSE、R-Squared等指标的定义和说明
|
JSON 小程序 前端开发
微信小程序--》小程序全局配置和详解下拉刷新和上拉触底页面事件
⚓经过web前端开发的学习,相信大家对于前端开发有了一定深入的了解,今天我开设了微信小程序,主要想从移动端开发方向进一步发展,而对于我来说写移动端博文的第一站就是小程序开发,希望看到我文章的朋友能对你有所帮助。
1359 0
微信小程序--》小程序全局配置和详解下拉刷新和上拉触底页面事件
|
网络协议 SDN 数据安全/隐私保护
|
Kubernetes 安全 调度
k8s教程(pod篇)-亲和性与互斥性调度
k8s教程(pod篇)-亲和性与互斥性调度
915 0

热门文章

最新文章