RPN(Region Proposal Networks)候选区域网络算法解析(附PyTorch代码)

本文涉及的产品
全局流量管理 GTM,标准版 1个月
公共DNS(含HTTPDNS解析),每月1000万次HTTP解析
云解析 DNS,旗舰版 1个月
简介: RPN(Region Proposal Networks)候选区域网络算法解析(附PyTorch代码)

0. 前言

按照国际惯例,首先声明:本文只是我自己学习的理解,虽然参考了他人的宝贵见解及成果,但是内容可能存在不准确的地方。如果发现文中错误,希望批评指正,共同进步。


本文基于论文Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks对RPN候选区域网络进行解析说明,并基于PyTorch库对RPN网络进行编程。


创作本文的背景:


在我的专栏【PyTorch实例实战演练】中,已经有了4篇左右文章专门介绍图像分类相关的深度学习模型算法,对这类任务已经比较熟悉了。更进一步地,我想再学习掌握目标检测任务。


目标检测任务相比图像分类任务更加复杂,因为前者不仅要像后者一样找到特征并分类,而且还要定位不同特征的位置。由于图像分类算法已经有比较成熟的研究基础,因此目标检测任务的难点不在于分类,而在于定位!


1. 区域候选算法(RPA,Region Proposal Algorithm)


为了定位图像中的目标分类物体位置,我们可以在整个图像或者放大后的部分图像上移动一个固定大小的矩形窗口,对于每一个窗口位置,都会运行分类器来判断窗口内是否包含目标分类的对象。这种方法称为滑动窗口法,滑动窗口法是一种早期广泛应用于目标检测的传统计算机视觉技术。


滑动窗口法的缺点十分明显——过于简单粗暴!会在大量的没有意义的位置上浪费检测目标的时间。如果不采用一些“聪明的”方法先预选一些比较有可能出现检测目标的位置,或者说如果不让分类器“注意力”集中在可能出现检测目标的位置,实时的目标检测方法就不可能,我们现在诸如依赖机器视觉的自动驾驶等技术也就也可能。


我们把此类能“预选可能出现检测目标位置”的算法称为RPA(Region Proposal Algorithm)区域候选算法,常见的RPA有:


Region Proposal Algorithms (RPA) 是计算机视觉领域中的重要组件,它们主要用于目标检测任务中生成候选区域(region proposals),这些候选区域可能是潜在的对象位置。以下是几个常见的Region Proposal方法:


  1. Selective Search, SS:通过图像金字塔和图像分块策略,进行一系列的过分割、合并操作,生成高质量的候选框。
  2. Edge Boxes, EB:该算法根据边界框内的边缘密度来生成候选框,倾向于选择包含显著物体边界的框。
  3. Region Proposal Network, RPN:首次在Faster R-CNN中引入,RPN是一个全卷积网络,它可以并行地生成多个候选框及其相应的置信度评分,极大地提高了目标检测的速度和性能。


本文要介绍的算法就是RPN。


2. 区域候选网络(RPN,Region Proposal Algorithm)

在正式详细说明RPN之前,我想先总结下Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks这篇文章的几点精华,提炼要点:


  • RPN是一个全卷积网络,能同时预测候选区域的边界置信度
  • RPN的提出目的是为了提升目标检测的效率(速度),使用RPN后与检测任务相比,候选区域的选择几乎是不耗费计算资源(nearly cost-free);
  • RPN能提升目标检测效率的原因是RPN与分类器共用要检测的原图的卷积特征图
  • 首次提出锚框(anchor)的概念,在锚框提出之前候选框的选择只能通过各种尺寸的图像(pyramids of images)或者各种尺寸的算子(pyramids of filters),这也是为了提升目标检测的效率。


2.1 RPN背景说明

在当时(2016年),先进的目标检测网络如SPPnet、Fast R-CNN提出后,减少了目标识别任务的计算时间,使得区域候选成为了计算瓶颈。当时已有的区域候选算法:Selective Search, SS仅能做到2s完成一张图像的区域候选;EdgeBoxes虽然更快,能做到0.2s完成,但是这是以牺牲计算质量为代价的。


通过发现:给识别器(detector)做图像识别的卷积特征图也可以共用于生成候选区域,这样做可以减少计算生成候选区域的时间(少到每张图像的候选区域生成仅需要10ms左右),于是基于此提出了RPN网络。使用RPN区域候选+深度学习图像分类组成的目标检测识别算法能做到每秒完成5个目标检测任务,大大提升了检测速率。


2.2 RPN架构

RPN在目标检测算法Faster R-CNN中的架构如下图:

Faster R-CNN = Fast R-CNN + RPN

RPN是整个 Faster R-CNN 模型的重要组成部分,原始图像首先通过一系列卷积层进行特征提取和处理得到特征图,RPN基于该特征图生成可能包含感兴趣对象的候选区域,以便后续的分类器能把“注意力”集中在这些后续区域内,提升整个目标检测任务的效率。


然后我们再来看下RPN本身的架构:

RPN是一个全卷积网络,模型计算过程如下:

  1. 使用一个n×n滑动窗口(sliding window)在共享的卷积特征图上滑动并提取特征。每个滑动窗口的中间位置都对应有k个锚框(anchors),在本文中使用3种缩放比例×3种长宽比,共k=9种锚框
  2. 对于每一个滑动窗口,通过卷积计算将其转换为一个固定维度向量(例如基础CNN为ZFnet时,该向量长度为256;基础CNN为VGG时,该向量长度为512),通常称为"中间层"或"隐藏层";
  3. 然后使用这个中间层作为输入,分别经过两个孪生全连接层分支(two sibling fully connected layers):分类层和回归层;
  4. 分类层输出的是每个格子属于各个类别(如人、车等)的概率分布,需要注意的是这里输出的是有物体和无物体,即正样本(前景)和负样本(背景)两个概率,因此输出长度为2k;而回归层则输出的是每个格子中心点相对于真实边界框偏移的距离值,中心坐标偏移加上长宽偏移共4个值,因此输出长度为4k

当然,分类层也可以只输出有物体的概率,输出长度为k。原文有一行注释:For simplicity we implement the cls layer as a two-class softmax layer. Alternatively, one may use logistic regression to produce k scores.

RPN的一个重要特性就是平移不变性(translation invariant),带来的好处就是减小模型,更小的参数量能进一步提高计算速度。

2.2 RPN的损失函数

RPN的损失函数为:

  • 为归一化参数,λ为平衡分类损失和回归损失的系数;
  • 为分类损失函数, 为回归损失函数;
  • 角标i代表第i个锚框;
  • 为模型预测的第i个锚框中有物体的概率;
  • 为第i个锚框中是否有物体的ground truth,当①锚框与真实物体的IoU=0.7 或②锚框与真实物体的IoU达到最大为正样本, 。当③锚框与真实物体的IoU<0.3(且不为最大)为负样本,

如果 不满足①②③中任意条件,比如 为0.4,且不为最大值怎么办呢?

答:这种样本不会被选来训练。在训练RPN的时候,每张图像会随机选取256个锚框进行训练,其中正样本和负样本为1:1,选取的样本肯定满足①②③条件之一。

  • 是一个四维向量,代表预测的bounding box相对于第i个锚框x, y, w, h的偏置;
  • 是一个四维向量,代表ground truth相对于第i个锚框x, y, w, h的偏置;

其中

其中

x, y 为框的中心点坐标,w, h为框的宽和高。无角标、角标为a、角标为*分别代表预测的bounding box、锚框、ground truth的框的几何尺寸。

2.3 训练RPN

在Faster R-CNN的目标检测框架中,交替训练(Alternating training)是一种分阶段优化RPN(Region Proposal Network)和Fast R-CNN检测器的策略。以下是详细的交替训练过程:

第一步:初始化与训练RPN

首先,仅使用卷积特征层训练RPN网络。RPN是一个全卷积网络,它直接从基础CNN网络(如VGG或ResNet)提取的特征图上生成候选区域(region proposals)。RPN通过学习调整一组预先定义好的锚框来预测每个锚框是否包含对象以及调整其边界框回归参数。

第二步:使用RPN生成提议并训练Fast R-CNN

使用第一步训练得到的RPN生成大量的候选区域提议,并从中采样固定数量的高质量提议用于训练Fast R-CNN分类器和边框回归器。Fast R-CNN是一个端到端可训练的模型,它对这些候选区域进行分类和精调边界框位置。

第三步:使用Fast R-CNN初始化RPN并重新训练

将经过Fast R-CNN训练后更新了卷积特征层的整个网络(包括基础CNN部分)作为初始化权重,再次启动RPN的训练。在这个迭代过程中,Fast R-CNN的部分被冻结(即保持参数不变),只训练RPN部分,利用更高质量的特征来进一步优化RPN的提议生成能力。

第四步:重复迭代

这个过程可以多次迭代,每次都是将更新后的模型反过来影响RPN的训练,然后再用改进后的RPN提议去提升Fast R-CNN的表现,直到两个子网络收敛达到较好的联合性能。

通过这种方式,RPN和Fast R-CNN可以互相促进,共同优化目标检测的整体性能。在实际训练过程中,可能需要几个交替迭代周期才能使得RPN产生的提议质量和最终检测结果达到最优状态。

3. 基于PyTorch框架的RPN

注意:由于RPN与其他网络模型关系十分密切,难易独立分割开来。以下代码只是一个非常基础的示意,并未包含诸如锚框生成、前向传播中的空间尺寸变换、损失函数定义以及后处理步骤(如非极大值抑制NMS)等内容。在实际项目中,请参考Faster R-CNN论文或其他开源实现来完善整个RPN模块的功能。

import torch
import torch.nn as nn
from torchvision.models import vgg16
import torch.nn.functional as F
 
# RPN网络定义
class RPN(nn.Module):  #softmax???   长宽计算方法??
    def __init__(self, in_channels=512):
        super(RPN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, 512, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.cls_score = nn.Conv2d(512, 2 * 9, kernel_size=1)  #k=9
        self.bbox_pred = nn.Conv2d(512, 4 * 9, kernel_size=1)
 
    def forward(self, x):
        x = self.relu(self.conv1(x))
        rpn_cls_score = self.cls_score(x)  #生成预测正负样本(有无物体的概率)[pos_label, neg_label]
        rpn_cls_score_softmax = F.softmax(rpn_cls_score)
        rpn_bbox_pred = self.bbox_pred(x)  #生成bounding box的[tx, ty, tw, th]
        return rpn_cls_score_softmax, rpn_bbox_pred
 
# 使用预训练的VGG16作为特征提取器
backbone = vgg16(pretrained=True).features[:-1]  # 模型微调,去掉最后一个池化层
rpn = RPN(in_channels=512)  # VGG16最后一层输出维度为512
 
# 定义损失函数和优化器
cls_loss_func = nn.CrossEntropyLoss()
bbox_loss_func = nn.SmoothL1Loss()
optimizer = torch.optim.SGD(rpn.parameters(), lr=0.001, momentum=0.9)
 
# 假设我们有一个训练数据加载器,格式为images, (gt_boxes, gt_labels)
data_loader = ...
 
# 训练过程,这里仅说明RPN的训练过程,即交替训练的第一步!
num_epochs = 10
for epoch in range(num_epochs):
    for images, (gt_boxes, gt_labels) in enumerate(data_loader):
        # 前向传播
        features = backbone(images)
        rpn_cls_scores_softmax, rpn_bbox_preds = rpn(features)
 
        # 数据预处理,将预测结果调整到与ground truth相匹配的格式
        # 这部分会根据你的具体实现有所不同,这里仅作示例
        rpn_cls_scores_view = rpn_cls_scores_softmax.permute(0, 2, 3, 1).contiguous().view(-1, 2)
        rpn_bbox_preds_view = rpn_bbox_preds.permute(0, 2, 3, 1).contiguous().view(-1, 4)
        gt_labels_view = gt_labels.view(-1)
 
        # 计算损失
        rpn_cls_loss = cls_loss_func(rpn_cls_scores_view, gt_labels_view)
        rpn_bbox_loss = bbox_loss_func(rpn_bbox_preds_view, gt_boxes)
 
        # 总损失,这里暂时忽略Ncls,Nreg,λ
        loss = rpn_cls_loss + rpn_bbox_loss
 
        # 反向传播和优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
 
        # 打印训练信息
        if (images + 1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{images+1}/{len(data_loader)}], Loss: {loss.item()}')
 
# 训练结束
print('Training finished.')


相关文章
|
7天前
|
机器学习/深度学习 人工智能 算法
猫狗宠物识别系统Python+TensorFlow+人工智能+深度学习+卷积网络算法
宠物识别系统使用Python和TensorFlow搭建卷积神经网络,基于37种常见猫狗数据集训练高精度模型,并保存为h5格式。通过Django框架搭建Web平台,用户上传宠物图片即可识别其名称,提供便捷的宠物识别服务。
121 55
|
17天前
|
机器学习/深度学习 人工智能 算法
【宠物识别系统】Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+图像识别
宠物识别系统,本系统使用Python作为主要开发语言,基于TensorFlow搭建卷积神经网络算法,并收集了37种常见的猫狗宠物种类数据集【'阿比西尼亚猫(Abyssinian)', '孟加拉猫(Bengal)', '暹罗猫(Birman)', '孟买猫(Bombay)', '英国短毛猫(British Shorthair)', '埃及猫(Egyptian Mau)', '缅因猫(Maine Coon)', '波斯猫(Persian)', '布偶猫(Ragdoll)', '俄罗斯蓝猫(Russian Blue)', '暹罗猫(Siamese)', '斯芬克斯猫(Sphynx)', '美国斗牛犬
100 29
【宠物识别系统】Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+图像识别
|
16天前
|
机器学习/深度学习 人工智能 算法
深入解析图神经网络:Graph Transformer的算法基础与工程实践
Graph Transformer是一种结合了Transformer自注意力机制与图神经网络(GNNs)特点的神经网络模型,专为处理图结构数据而设计。它通过改进的数据表示方法、自注意力机制、拉普拉斯位置编码、消息传递与聚合机制等核心技术,实现了对图中节点间关系信息的高效处理及长程依赖关系的捕捉,显著提升了图相关任务的性能。本文详细解析了Graph Transformer的技术原理、实现细节及应用场景,并通过图书推荐系统的实例,展示了其在实际问题解决中的强大能力。
106 30
|
4天前
|
JSON 算法 Java
Nettyの网络聊天室&扩展序列化算法
通过本文的介绍,我们详细讲解了如何使用Netty构建一个简单的网络聊天室,并扩展序列化算法以提高数据传输效率。Netty的高性能和灵活性使其成为实现各种网络应用的理想选择。希望本文能帮助您更好地理解和使用Netty进行网络编程。
23 12
|
11天前
|
机器学习/深度学习 算法 信息无障碍
基于GoogleNet深度学习网络的手语识别算法matlab仿真
本项目展示了基于GoogleNet的深度学习手语识别算法,使用Matlab2022a实现。通过卷积神经网络(CNN)识别手语手势,如"How are you"、"I am fine"、"I love you"等。核心在于Inception模块,通过多尺度处理和1x1卷积减少计算量,提高效率。项目附带完整代码及操作视频。
|
20天前
|
存储 算法
深入解析PID控制算法:从理论到实践的完整指南
前言 大家好,今天我们介绍一下经典控制理论中的PID控制算法,并着重讲解该算法的编码实现,为实现后续的倒立摆样例内容做准备。 众所周知,掌握了 PID ,就相当于进入了控制工程的大门,也能为更高阶的控制理论学习打下基础。 在很多的自动化控制领域。都会遇到PID控制算法,这种算法具有很好的控制模式,可以让系统具有很好的鲁棒性。 基本介绍 PID 深入理解 (1)闭环控制系统:讲解 PID 之前,我们先解释什么是闭环控制系统。简单说就是一个有输入有输出的系统,输入能影响输出。一般情况下,人们也称输出为反馈,因此也叫闭环反馈控制系统。比如恒温水池,输入就是加热功率,输出就是水温度;比如冷库,
150 15
|
14天前
|
机器学习/深度学习 算法 数据安全/隐私保护
基于深度学习网络的宝石类型识别算法matlab仿真
本项目利用GoogLeNet深度学习网络进行宝石类型识别,实验包括收集多类宝石图像数据集并按7:1:2比例划分。使用Matlab2022a实现算法,提供含中文注释的完整代码及操作视频。GoogLeNet通过其独特的Inception模块,结合数据增强、学习率调整和正则化等优化手段,有效提升了宝石识别的准确性和效率。
|
20天前
|
机器学习/深度学习 算法 数据安全/隐私保护
基于贝叶斯优化CNN-GRU网络的数据分类识别算法matlab仿真
本项目展示了使用MATLAB2022a实现的贝叶斯优化、CNN和GRU算法优化效果。优化前后对比显著,完整代码附带中文注释及操作视频。贝叶斯优化适用于黑盒函数,CNN用于时间序列特征提取,GRU改进了RNN的长序列处理能力。
|
21天前
|
SQL 安全 算法
网络安全之盾:漏洞防御与加密技术解析
在数字时代的浪潮中,网络安全和信息安全成为维护个人隐私和企业资产的重要防线。本文将深入探讨网络安全的薄弱环节—漏洞,并分析如何通过加密技术来加固这道防线。文章还将分享提升安全意识的重要性,以预防潜在的网络威胁,确保数据的安全与隐私。
39 2
|
22天前
|
安全 算法 网络安全
网络安全的盾牌与剑:漏洞防御与加密技术深度解析
在数字信息的海洋中,网络安全是航行者不可或缺的指南针。本文将深入探讨网络安全的两大支柱——漏洞防御和加密技术,揭示它们如何共同构筑起信息时代的安全屏障。从最新的网络攻击手段到防御策略,再到加密技术的奥秘,我们将一起揭开网络安全的神秘面纱,理解其背后的科学原理,并掌握保护个人和企业数据的关键技能。
27 3

推荐镜像

更多