ResNext架构解析:深度神经网络的聚合残差变换

本文涉及的产品
全局流量管理 GTM,标准版 1个月
公共DNS(含HTTPDNS解析),每月1000万次HTTP解析
云解析 DNS,旗舰版 1个月
简介: 我们提出了一种用于图像分类的简单、高度模块化的网络架构。我们的网络是通过重复一个构建块来构建的,该构建块聚合了一组具有相同拓扑的转换。我们简单的设计产生了一个同构的多分支架构,只需设置几个超参数。这个策略**暴露了一个新的维度,我们称之为“基数”(转换集的大小)**,作为除了深度和宽度维度之外的重要因素。在 ImageNet-1K 数据集上,我们凭经验表明,即使在保持复杂性的限制条件下,增加基数也能够提高分类精度。此外,当我们增加容量时,增加基数比更深或更宽更有效。我们的模型名为 ResNeXt,是我们进入 ILSVRC 2016 分类任务的基础,在该任务中我们获得了第二名。我们在 Image

1、简介

  我们提出了一种用于图像分类的简单、高度模块化的网络架构。我们的网络是通过重复一个构建块来构建的,该构建块聚合了一组具有相同拓扑的转换。我们简单的设计产生了一个同构的多分支架构,只需设置几个超参数。这个策略暴露了一个新的维度,我们称之为“基数”(转换集的大小),作为除了深度和宽度维度之外的重要因素。在 ImageNet-1K 数据集上,我们凭经验表明,即使在保持复杂性的限制条件下,增加基数也能够提高分类精度。此外,当我们增加容量时,增加基数比更深或更宽更有效。我们的模型名为 ResNeXt,是我们进入 ILSVRC 2016 分类任务的基础,在该任务中我们获得了第二名。我们在 ImageNet-5K 集和 COCO 检测集上进一步研究 ResNeXt,也显示出比 ResNet 更好的结果。代码和模型可在线公开获得1。

  在本文中,我们提出了一个简单的架构,它采用了 VGG/ResNets 的重复层策略,同时以一种简单、可扩展的方式利用了拆分-变换-合并策略。我们网络中的一个模块执行一组转换,每个转换都在一个低维嵌入上,其输出通过求和聚合。我们追求这个想法的简单实现——要聚合的变换都是相同的拓扑结构(例如,图 1(右))。这种设计允许我们在没有专门设计的情况下扩展到任何大量的转换。

image-20220820154538911

左图为ResNet块,右图为C = 32 的 ResNeXt 块,复杂度大致相同。

每层的格式为(输入通道数,卷积核大小,输出通道数)

2、分组卷积

  简单来说分组卷积就是将特征图分为不同的组,再对每组特征图分别进行卷积

  在分组卷积中每个卷积核只处理部分通道,比如下图中,红色卷积核只处理红色的通道,绿色卷积核只处理绿色通道,黄色卷积核只处理黄色通道。此时每个卷积核有2个通道,每个卷积核生成一张特征图。

image-20220818214806233

image-20220818215041434

  左边标准卷积,每个卷积核处理12个通道

  右边分组卷积,假设输入的12个通道分为3组,每个卷积核只处理4个通道

3、残差单元

image-20220820155257554

  ==上图的三种block块在数学计算上面完全等价==

  通常我们使用的是(c)图,论文中说这种结构更简洁、快速。

  每层的格式为(输入通道数,卷积核大小,输出通道数)

  其实从残差快结构可以发现,ResNext-50和ResNet-50的结构非常相似,只需要把原来ResNet的中间那一层换成分组卷积即可。

  从(c)图可以看出,先使用1*1卷积降维,再使用3*3分组卷积提取特征,最后使用1*1卷积升维,如果输入和输出的shape相同,加上残差连接,(输入和输出的shape不一致的时候,考虑1*1卷积升维)。

   没啥讲的,创新点就是用了个分组卷积吧。

4、ResNext-50网络结构

image-20220820160153155

  和ResNet网络架构太像了,就是注意下那个分组卷积,C=32为分组数。

  conv3、conv4、conv5的下采样是在每个阶段的第一个块的额3*3卷积层中通过stride=2的卷积操作实现的。

5、与ResNet模型的比较

image-20220820160515197

image-20220820160527934

  只看验证集,在计算量相同的情况下,ResNext的误差比ResNet更低

image-20220820160632757

   上述都是在ImageNet-1K数据集上做的实验。

6、ResNext-50模型复现

import numpy as np
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Input, Dense, Dropout, Conv2D, MaxPool2D, Flatten, GlobalAvgPool2D, concatenate, \
    BatchNormalization, Activation, Add, ZeroPadding2D, Lambda
from tensorflow.keras.layers import ReLU
from tensorflow.keras.optimizers import Adam
import matplotlib.pyplot as plt
from tensorflow.keras.callbacks import LearningRateScheduler
from tensorflow.keras.models import Model

6.1 分组卷积模块

# 定义分组卷积
def grouped_convolution_block(init_x, strides, groups, g_channels):
    group_list = []
    # 分组进行卷积
    for c in range(groups):
        # 分组取出数据
        x = Lambda(lambda x: x[:, :, :, c * g_channels:(c + 1) * g_channels])(init_x)
        # 分组进行卷积
        x = Conv2D(filters=g_channels, kernel_size=(3, 3),
                   strides=strides, padding='same', use_bias=False)(x)
        # 存入list
        group_list.append(x)
    # 合并list中的数据
    group_merage = concatenate(group_list, axis=3)
    x = BatchNormalization(epsilon=1.001e-5)(group_merage)
    x = ReLU()(x)
    return x

6.2 定义残差单元

# 定义残差单元
def block(x, filters, strides=1, groups=32, conv_shortcut=True):
    # projection shortcut
    if conv_shortcut:
        shortcut = Conv2D(filters * 2, kernel_size=(1, 1), strides=strides, padding='same',
                          use_bias=False)(x)
        # epsilon为BN公式中防止分母为零的值
        shortcut = BatchNormalization(epsilon=1.001e-5)(shortcut)
    else:
        # identity_shortcut
        shortcut = x
    # 3个卷积层
    x = Conv2D(filters =filters, kernel_size=(1, 1), strides=1, padding='same', use_bias=False)(x)
    x = BatchNormalization(epsilon=1.001e-5)(x)
    x = ReLU()(x)
    # 计算每组的通道数
    g_channels = int(filters / groups)
    # 进行分组卷积
    x = grouped_convolution_block(x, strides, groups, g_channels)

    x = Conv2D(filters=filters * 2, kernel_size=(1, 1), strides=1, padding='same', use_bias=False)(x)
    x = BatchNormalization(epsilon=1.001e-5)(x)
    x = Add()([x, shortcut])
    x = ReLU()(x)
    return x

6.3 堆叠残差单元

  每个stack的第一个block的输入和输出的shape是不一致的,所以残差连接都需要使用1*1卷积升维后才能进行Add操作。

  而其他block的输入和输出的shape是一致的,所以可以直接执行Add操作。

# 堆叠残差单元
def stack(x, filters, blocks, strides, groups=32):
    # 每个stack的第一个block的残差连接都需要使用1*1卷积升维
    x = block(x, filters, strides=strides, groups=groups)
    for i in range(blocks):
        x = block(x, filters, groups=groups, conv_shortcut=False)
    return x

6.4 搭建ResNext-50(32*4d)网络结构

# 定义ResNext50(32*4d)网络
def ResNext50(input_shape, num_classes):
    inputs = Input(shape=input_shape)
    # 填充3圈0,[224,224,3]->[230,230,3]
    x = ZeroPadding2D((3, 3))(inputs)
    x = Conv2D(filters=64, kernel_size=(7, 7), strides=2, padding='valid')(x)
    x = BatchNormalization(epsilon=1.001e-5)(x)
    x = ReLU()(x)
    # 填充1圈0
    x = ZeroPadding2D((1, 1))(x)
    x = MaxPool2D(pool_size=(3, 3), strides=2, padding='valid')(x)
    # 堆叠残差结构
    x = stack(x, filters=128, blocks=2, strides=1)
    x = stack(x, filters=256, blocks=3, strides=2)
    x = stack(x, filters=512, blocks=5, strides=2)
    x = stack(x, filters=1024, blocks=2, strides=2)
    # 根据特征图大小进行全局平均池化
    x = GlobalAvgPool2D()(x)
    x = Dense(num_classes, activation='softmax')(x)
    # 定义模型
    model = Model(inputs=inputs, outputs=x)
    return model
  上面不使用ZeroPadding2D也是可以的,令第一个卷积和池化的padding='same'即可。

6.5 查看模型摘要

model=ResNext50(input_shape=(224,224,3),num_classes=1000)
model.summary()

6.6 用自定义数据集测试

  我的数据集共17种花,分别放在对应的文件夹中。
model=ResNext50(input_shape=(224,224,3),num_classes=17)

  数据增强

# 训练集数据进行数据增强
train_datagen = ImageDataGenerator(
    rotation_range=20,  # 随机旋转度数
    width_shift_range=0.1,  # 随机水平平移
    height_shift_range=0.1,  # 随机竖直平移
    rescale=1 / 255,  # 数据归一化
    shear_range=10,  # 随机错切变换
    zoom_range=0.1,  # 随机放大
    horizontal_flip=True,  # 水平翻转
    brightness_range=(0.7, 1.3),  # 亮度变化
    fill_mode='nearest',  # 填充方式
)
# 测试集数据只需要归一化就可以
test_datagen = ImageDataGenerator(
    rescale=1 / 255,  # 数据归一化
)

  数据生成器

# 训练集数据生成器,可以在训练时自动产生数据进行训练
# 从'data/train'获得训练集数据
# 获得数据后会把图片resize为image_size×image_size的大小
# generator每次会产生batch_size个数据
train_generator = train_datagen.flow_from_directory(
    '../data/train',
    target_size=(image_size, image_size),
    batch_size=batch_size,
)

# 测试集数据生成器
test_generator = test_datagen.flow_from_directory(
    '../data/test',
    target_size=(image_size, image_size),
    batch_size=batch_size,
)
# 字典的键为17个文件夹的名字,值为对应的分类编号
print(train_generator.class_indices)

image-20220820162355457

  回调设置

# 学习率调节函数,逐渐减小学习率
def adjust_learning_rate(epoch):
    # 前40周期
    if epoch<=40:
        lr = 1e-4
    # 前40到80周期
    elif epoch>40 and epoch<=80:
        lr = 1e-5
    # 80到100周期
    else:
        lr = 1e-6
    return lr

# 定义优化器
adam = Adam(lr=1e-4)

# 读取模型
checkpoint_save_path = "./checkpoint/ResNext-50.ckpt"
if os.path.exists(checkpoint_save_path + '.index'):
    print('-------------load the model-----------------')
    model.load_weights(checkpoint_save_path)
# 保存模型
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
                                                 save_weights_only=True,
                                                 save_best_only=True)

# 定义学习率衰减策略
callbacks = []
callbacks.append(LearningRateScheduler(adjust_learning_rate))
callbacks.append(cp_callback)

  训练

# 定义优化器,loss function,训练过程中计算准确率
model.compile(optimizer=adam,loss='categorical_crossentropy',metrics=['accuracy'])

# Tensorflow2.1版本(包括2.1)之后可以直接使用fit训练模型
history = model.fit(x=train_generator,epochs=epochs,validation_data=test_generator,callbacks=callbacks)

image-20220820162447405

  acc可视化

# 画出训练集准确率曲线图
plt.plot(np.arange(epochs),history.history['accuracy'],c='b',label='train_accuracy')
# 画出验证集准确率曲线图
plt.plot(np.arange(epochs),history.history['val_accuracy'],c='y',label='val_accuracy')
# 图例
plt.legend()
# x坐标描述
plt.xlabel('epochs')
# y坐标描述
plt.ylabel('accuracy')
# 显示图像
plt.show()

image-20220820162515339

  loss可视化

# 画出训练集loss曲线图
plt.plot(np.arange(epochs),history.history['loss'],c='b',label='train_loss')
# 画出验证集loss曲线图
plt.plot(np.arange(epochs),history.history['val_loss'],c='y',label='val_loss')
# 图例
plt.legend()
# x坐标描述
plt.xlabel('epochs')
# y坐标描述
plt.ylabel('loss')
# 显示图像
plt.show()

image-20220820162545539

References

  Saining Xie, Ross Girshick, Piotr Dollár, Zhuowen Tu, & Kaiming He (2016). Aggregated Residual Transformations for Deep Neural Networks computer vision and pattern recognition.
  ResNet架构解析

  6.1.2 ResNext网络结构

  ResNeXt 交通标志四分类,附Tensorflow完整代码

目录
相关文章
|
17天前
|
运维 持续交付 云计算
深入解析云计算中的微服务架构:原理、优势与实践
深入解析云计算中的微服务架构:原理、优势与实践
45 1
|
9天前
|
运维 监控 持续交付
微服务架构解析:跨越传统架构的技术革命
微服务架构(Microservices Architecture)是一种软件架构风格,它将一个大型的单体应用拆分为多个小而独立的服务,每个服务都可以独立开发、部署和扩展。
97 36
微服务架构解析:跨越传统架构的技术革命
|
10天前
|
机器学习/深度学习 人工智能 算法
深入解析图神经网络:Graph Transformer的算法基础与工程实践
Graph Transformer是一种结合了Transformer自注意力机制与图神经网络(GNNs)特点的神经网络模型,专为处理图结构数据而设计。它通过改进的数据表示方法、自注意力机制、拉普拉斯位置编码、消息传递与聚合机制等核心技术,实现了对图中节点间关系信息的高效处理及长程依赖关系的捕捉,显著提升了图相关任务的性能。本文详细解析了Graph Transformer的技术原理、实现细节及应用场景,并通过图书推荐系统的实例,展示了其在实际问题解决中的强大能力。
85 30
|
14天前
|
存储 Linux API
深入探索Android系统架构:从内核到应用层的全面解析
本文旨在为读者提供一份详尽的Android系统架构分析,从底层的Linux内核到顶层的应用程序框架。我们将探讨Android系统的模块化设计、各层之间的交互机制以及它们如何共同协作以支持丰富多样的应用生态。通过本篇文章,开发者和爱好者可以更深入理解Android平台的工作原理,从而优化开发流程和提升应用性能。
|
16天前
|
弹性计算 持续交付 API
构建高效后端服务:微服务架构的深度解析与实践
在当今快速发展的软件行业中,构建高效、可扩展且易于维护的后端服务是每个技术团队的追求。本文将深入探讨微服务架构的核心概念、设计原则及其在实际项目中的应用,通过具体案例分析,展示如何利用微服务架构解决传统单体应用面临的挑战,提升系统的灵活性和响应速度。我们将从微服务的拆分策略、通信机制、服务发现、配置管理、以及持续集成/持续部署(CI/CD)等方面进行全面剖析,旨在为读者提供一套实用的微服务实施指南。
|
17天前
|
SQL 数据可视化 数据库
多维度解析低代码:从技术架构到插件生态
本文深入解析低代码平台,涵盖技术架构、插件生态及应用价值。通过图形化界面和模块化设计,低代码平台降低开发门槛,提升效率,支持企业快速响应市场变化。重点分析开源低代码平台的优势,如透明架构、兼容性与扩展性、可定制化开发等,探讨其在数据处理、功能模块、插件生态等方面的技术特点,以及未来发展趋势。
|
14天前
|
SQL 安全 算法
网络安全之盾:漏洞防御与加密技术解析
在数字时代的浪潮中,网络安全和信息安全成为维护个人隐私和企业资产的重要防线。本文将深入探讨网络安全的薄弱环节—漏洞,并分析如何通过加密技术来加固这道防线。文章还将分享提升安全意识的重要性,以预防潜在的网络威胁,确保数据的安全与隐私。
31 2
|
16天前
|
安全 算法 网络安全
网络安全的盾牌与剑:漏洞防御与加密技术深度解析
在数字信息的海洋中,网络安全是航行者不可或缺的指南针。本文将深入探讨网络安全的两大支柱——漏洞防御和加密技术,揭示它们如何共同构筑起信息时代的安全屏障。从最新的网络攻击手段到防御策略,再到加密技术的奥秘,我们将一起揭开网络安全的神秘面纱,理解其背后的科学原理,并掌握保护个人和企业数据的关键技能。
23 3
|
16天前
|
SQL 数据可视化 数据库
多维度解析低代码:从技术架构到插件生态
本文深入解析低代码平台,从技术架构到插件生态,探讨其在企业数字化转型中的作用。低代码平台通过图形化界面和模块化设计降低开发门槛,加速应用开发与部署,提高市场响应速度。文章重点分析开源低代码平台的优势,如透明架构、兼容性与扩展性、可定制化开发等,并详细介绍了核心技术架构、数据处理与功能模块、插件生态及数据可视化等方面,展示了低代码平台如何支持企业在数字化转型中实现更高灵活性和创新。
39 1
|
16天前
|
SQL 数据可视化 数据库
多维度解析低代码:从技术架构到插件生态
本文深入解析低代码平台,涵盖技术架构、插件生态及应用价值。重点介绍开源低代码平台的优势,如透明架构、兼容性与扩展性、可定制化开发,以及其在数据处理、功能模块、插件生态等方面的技术特点。文章还探讨了低代码平台的安全性、权限管理及未来技术趋势,强调其在企业数字化转型中的重要作用。
32 1

推荐镜像

更多