RPN(Region Proposal Networks)候选区域网络算法解析(附PyTorch代码)

本文涉及的产品
云解析 DNS,旗舰版 1个月
全局流量管理 GTM,标准版 1个月
公共DNS(含HTTPDNS解析),每月1000万次HTTP解析
简介: RPN(Region Proposal Networks)候选区域网络算法解析(附PyTorch代码)

0. 前言

按照国际惯例,首先声明:本文只是我自己学习的理解,虽然参考了他人的宝贵见解及成果,但是内容可能存在不准确的地方。如果发现文中错误,希望批评指正,共同进步。


本文基于论文Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks对RPN候选区域网络进行解析说明,并基于PyTorch库对RPN网络进行编程。


创作本文的背景:


在我的专栏【PyTorch实例实战演练】中,已经有了4篇左右文章专门介绍图像分类相关的深度学习模型算法,对这类任务已经比较熟悉了。更进一步地,我想再学习掌握目标检测任务。


目标检测任务相比图像分类任务更加复杂,因为前者不仅要像后者一样找到特征并分类,而且还要定位不同特征的位置。由于图像分类算法已经有比较成熟的研究基础,因此目标检测任务的难点不在于分类,而在于定位!


1. 区域候选算法(RPA,Region Proposal Algorithm)


为了定位图像中的目标分类物体位置,我们可以在整个图像或者放大后的部分图像上移动一个固定大小的矩形窗口,对于每一个窗口位置,都会运行分类器来判断窗口内是否包含目标分类的对象。这种方法称为滑动窗口法,滑动窗口法是一种早期广泛应用于目标检测的传统计算机视觉技术。


滑动窗口法的缺点十分明显——过于简单粗暴!会在大量的没有意义的位置上浪费检测目标的时间。如果不采用一些“聪明的”方法先预选一些比较有可能出现检测目标的位置,或者说如果不让分类器“注意力”集中在可能出现检测目标的位置,实时的目标检测方法就不可能,我们现在诸如依赖机器视觉的自动驾驶等技术也就也可能。


我们把此类能“预选可能出现检测目标位置”的算法称为RPA(Region Proposal Algorithm)区域候选算法,常见的RPA有:


Region Proposal Algorithms (RPA) 是计算机视觉领域中的重要组件,它们主要用于目标检测任务中生成候选区域(region proposals),这些候选区域可能是潜在的对象位置。以下是几个常见的Region Proposal方法:


  1. Selective Search, SS:通过图像金字塔和图像分块策略,进行一系列的过分割、合并操作,生成高质量的候选框。
  2. Edge Boxes, EB:该算法根据边界框内的边缘密度来生成候选框,倾向于选择包含显著物体边界的框。
  3. Region Proposal Network, RPN:首次在Faster R-CNN中引入,RPN是一个全卷积网络,它可以并行地生成多个候选框及其相应的置信度评分,极大地提高了目标检测的速度和性能。


本文要介绍的算法就是RPN。


2. 区域候选网络(RPN,Region Proposal Algorithm)

在正式详细说明RPN之前,我想先总结下Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks这篇文章的几点精华,提炼要点:


  • RPN是一个全卷积网络,能同时预测候选区域的边界置信度
  • RPN的提出目的是为了提升目标检测的效率(速度),使用RPN后与检测任务相比,候选区域的选择几乎是不耗费计算资源(nearly cost-free);
  • RPN能提升目标检测效率的原因是RPN与分类器共用要检测的原图的卷积特征图
  • 首次提出锚框(anchor)的概念,在锚框提出之前候选框的选择只能通过各种尺寸的图像(pyramids of images)或者各种尺寸的算子(pyramids of filters),这也是为了提升目标检测的效率。


2.1 RPN背景说明

在当时(2016年),先进的目标检测网络如SPPnet、Fast R-CNN提出后,减少了目标识别任务的计算时间,使得区域候选成为了计算瓶颈。当时已有的区域候选算法:Selective Search, SS仅能做到2s完成一张图像的区域候选;EdgeBoxes虽然更快,能做到0.2s完成,但是这是以牺牲计算质量为代价的。


通过发现:给识别器(detector)做图像识别的卷积特征图也可以共用于生成候选区域,这样做可以减少计算生成候选区域的时间(少到每张图像的候选区域生成仅需要10ms左右),于是基于此提出了RPN网络。使用RPN区域候选+深度学习图像分类组成的目标检测识别算法能做到每秒完成5个目标检测任务,大大提升了检测速率。


2.2 RPN架构

RPN在目标检测算法Faster R-CNN中的架构如下图:

Faster R-CNN = Fast R-CNN + RPN

RPN是整个 Faster R-CNN 模型的重要组成部分,原始图像首先通过一系列卷积层进行特征提取和处理得到特征图,RPN基于该特征图生成可能包含感兴趣对象的候选区域,以便后续的分类器能把“注意力”集中在这些后续区域内,提升整个目标检测任务的效率。


然后我们再来看下RPN本身的架构:

RPN是一个全卷积网络,模型计算过程如下:

  1. 使用一个n×n滑动窗口(sliding window)在共享的卷积特征图上滑动并提取特征。每个滑动窗口的中间位置都对应有k个锚框(anchors),在本文中使用3种缩放比例×3种长宽比,共k=9种锚框
  2. 对于每一个滑动窗口,通过卷积计算将其转换为一个固定维度向量(例如基础CNN为ZFnet时,该向量长度为256;基础CNN为VGG时,该向量长度为512),通常称为"中间层"或"隐藏层";
  3. 然后使用这个中间层作为输入,分别经过两个孪生全连接层分支(two sibling fully connected layers):分类层和回归层;
  4. 分类层输出的是每个格子属于各个类别(如人、车等)的概率分布,需要注意的是这里输出的是有物体和无物体,即正样本(前景)和负样本(背景)两个概率,因此输出长度为2k;而回归层则输出的是每个格子中心点相对于真实边界框偏移的距离值,中心坐标偏移加上长宽偏移共4个值,因此输出长度为4k

当然,分类层也可以只输出有物体的概率,输出长度为k。原文有一行注释:For simplicity we implement the cls layer as a two-class softmax layer. Alternatively, one may use logistic regression to produce k scores.

RPN的一个重要特性就是平移不变性(translation invariant),带来的好处就是减小模型,更小的参数量能进一步提高计算速度。

2.2 RPN的损失函数

RPN的损失函数为:

  • 为归一化参数,λ为平衡分类损失和回归损失的系数;
  • 为分类损失函数, 为回归损失函数;
  • 角标i代表第i个锚框;
  • 为模型预测的第i个锚框中有物体的概率;
  • 为第i个锚框中是否有物体的ground truth,当①锚框与真实物体的IoU=0.7 或②锚框与真实物体的IoU达到最大为正样本, 。当③锚框与真实物体的IoU<0.3(且不为最大)为负样本,

如果 不满足①②③中任意条件,比如 为0.4,且不为最大值怎么办呢?

答:这种样本不会被选来训练。在训练RPN的时候,每张图像会随机选取256个锚框进行训练,其中正样本和负样本为1:1,选取的样本肯定满足①②③条件之一。

  • 是一个四维向量,代表预测的bounding box相对于第i个锚框x, y, w, h的偏置;
  • 是一个四维向量,代表ground truth相对于第i个锚框x, y, w, h的偏置;

其中

其中

x, y 为框的中心点坐标,w, h为框的宽和高。无角标、角标为a、角标为*分别代表预测的bounding box、锚框、ground truth的框的几何尺寸。

2.3 训练RPN

在Faster R-CNN的目标检测框架中,交替训练(Alternating training)是一种分阶段优化RPN(Region Proposal Network)和Fast R-CNN检测器的策略。以下是详细的交替训练过程:

第一步:初始化与训练RPN

首先,仅使用卷积特征层训练RPN网络。RPN是一个全卷积网络,它直接从基础CNN网络(如VGG或ResNet)提取的特征图上生成候选区域(region proposals)。RPN通过学习调整一组预先定义好的锚框来预测每个锚框是否包含对象以及调整其边界框回归参数。

第二步:使用RPN生成提议并训练Fast R-CNN

使用第一步训练得到的RPN生成大量的候选区域提议,并从中采样固定数量的高质量提议用于训练Fast R-CNN分类器和边框回归器。Fast R-CNN是一个端到端可训练的模型,它对这些候选区域进行分类和精调边界框位置。

第三步:使用Fast R-CNN初始化RPN并重新训练

将经过Fast R-CNN训练后更新了卷积特征层的整个网络(包括基础CNN部分)作为初始化权重,再次启动RPN的训练。在这个迭代过程中,Fast R-CNN的部分被冻结(即保持参数不变),只训练RPN部分,利用更高质量的特征来进一步优化RPN的提议生成能力。

第四步:重复迭代

这个过程可以多次迭代,每次都是将更新后的模型反过来影响RPN的训练,然后再用改进后的RPN提议去提升Fast R-CNN的表现,直到两个子网络收敛达到较好的联合性能。

通过这种方式,RPN和Fast R-CNN可以互相促进,共同优化目标检测的整体性能。在实际训练过程中,可能需要几个交替迭代周期才能使得RPN产生的提议质量和最终检测结果达到最优状态。

3. 基于PyTorch框架的RPN

注意:由于RPN与其他网络模型关系十分密切,难易独立分割开来。以下代码只是一个非常基础的示意,并未包含诸如锚框生成、前向传播中的空间尺寸变换、损失函数定义以及后处理步骤(如非极大值抑制NMS)等内容。在实际项目中,请参考Faster R-CNN论文或其他开源实现来完善整个RPN模块的功能。

import torch
import torch.nn as nn
from torchvision.models import vgg16
import torch.nn.functional as F
 
# RPN网络定义
class RPN(nn.Module):  #softmax???   长宽计算方法??
    def __init__(self, in_channels=512):
        super(RPN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, 512, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.cls_score = nn.Conv2d(512, 2 * 9, kernel_size=1)  #k=9
        self.bbox_pred = nn.Conv2d(512, 4 * 9, kernel_size=1)
 
    def forward(self, x):
        x = self.relu(self.conv1(x))
        rpn_cls_score = self.cls_score(x)  #生成预测正负样本(有无物体的概率)[pos_label, neg_label]
        rpn_cls_score_softmax = F.softmax(rpn_cls_score)
        rpn_bbox_pred = self.bbox_pred(x)  #生成bounding box的[tx, ty, tw, th]
        return rpn_cls_score_softmax, rpn_bbox_pred
 
# 使用预训练的VGG16作为特征提取器
backbone = vgg16(pretrained=True).features[:-1]  # 模型微调,去掉最后一个池化层
rpn = RPN(in_channels=512)  # VGG16最后一层输出维度为512
 
# 定义损失函数和优化器
cls_loss_func = nn.CrossEntropyLoss()
bbox_loss_func = nn.SmoothL1Loss()
optimizer = torch.optim.SGD(rpn.parameters(), lr=0.001, momentum=0.9)
 
# 假设我们有一个训练数据加载器,格式为images, (gt_boxes, gt_labels)
data_loader = ...
 
# 训练过程,这里仅说明RPN的训练过程,即交替训练的第一步!
num_epochs = 10
for epoch in range(num_epochs):
    for images, (gt_boxes, gt_labels) in enumerate(data_loader):
        # 前向传播
        features = backbone(images)
        rpn_cls_scores_softmax, rpn_bbox_preds = rpn(features)
 
        # 数据预处理,将预测结果调整到与ground truth相匹配的格式
        # 这部分会根据你的具体实现有所不同,这里仅作示例
        rpn_cls_scores_view = rpn_cls_scores_softmax.permute(0, 2, 3, 1).contiguous().view(-1, 2)
        rpn_bbox_preds_view = rpn_bbox_preds.permute(0, 2, 3, 1).contiguous().view(-1, 4)
        gt_labels_view = gt_labels.view(-1)
 
        # 计算损失
        rpn_cls_loss = cls_loss_func(rpn_cls_scores_view, gt_labels_view)
        rpn_bbox_loss = bbox_loss_func(rpn_bbox_preds_view, gt_boxes)
 
        # 总损失,这里暂时忽略Ncls,Nreg,λ
        loss = rpn_cls_loss + rpn_bbox_loss
 
        # 反向传播和优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
 
        # 打印训练信息
        if (images + 1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{images+1}/{len(data_loader)}], Loss: {loss.item()}')
 
# 训练结束
print('Training finished.')


相关文章
|
3天前
|
安全 虚拟化
在数字化时代,网络项目的重要性日益凸显。本文从前期准备、方案内容和注意事项三个方面,详细解析了如何撰写一个优质高效的网络项目实施方案,帮助企业和用户实现更好的体验和竞争力
在数字化时代,网络项目的重要性日益凸显。本文从前期准备、方案内容和注意事项三个方面,详细解析了如何撰写一个优质高效的网络项目实施方案,帮助企业和用户实现更好的体验和竞争力。通过具体案例,展示了方案的制定和实施过程,强调了目标明确、技术先进、计划周密、风险可控和预算合理的重要性。
15 5
|
5天前
|
SQL 安全 网络安全
网络安全的护城河:漏洞防御与加密技术的深度解析
【10月更文挑战第37天】在数字时代的浪潮中,网络安全成为守护个人隐私与企业资产的坚固堡垒。本文将深入探讨网络安全的两大核心要素——安全漏洞和加密技术,以及如何通过提升安全意识来强化这道防线。文章旨在揭示网络攻防战的复杂性,并引导读者构建更为稳固的安全体系。
16 1
|
14天前
|
SQL 安全 测试技术
网络安全的盾牌与剑——漏洞防御与加密技术解析
【10月更文挑战第28天】 在数字时代的浪潮中,网络空间安全成为我们不可忽视的战场。本文将深入探讨网络安全的核心问题,包括常见的网络安全漏洞、先进的加密技术以及提升个人和组织的安全意识。通过实际案例分析和代码示例,我们将揭示黑客如何利用漏洞进行攻击,展示如何使用加密技术保护数据,并强调培养网络安全意识的重要性。让我们一同揭开网络安全的神秘面纱,为打造更加坚固的数字防线做好准备。
34 3
RS-485网络中的标准端接与交流电端接应用解析
RS-485,作为一种广泛应用的差分信号传输标准,因其传输距离远、抗干扰能力强、支持多点通讯等优点,在工业自动化、智能建筑、交通运输等领域得到了广泛应用。在构建RS-485网络时,端接技术扮演着至关重要的角色,它直接影响到网络的信号完整性、稳定性和通信质量。
|
3天前
|
网络协议 网络安全 网络虚拟化
本文介绍了十个重要的网络技术术语,包括IP地址、子网掩码、域名系统(DNS)、防火墙、虚拟专用网络(VPN)、路由器、交换机、超文本传输协议(HTTP)、传输控制协议/网际协议(TCP/IP)和云计算
本文介绍了十个重要的网络技术术语,包括IP地址、子网掩码、域名系统(DNS)、防火墙、虚拟专用网络(VPN)、路由器、交换机、超文本传输协议(HTTP)、传输控制协议/网际协议(TCP/IP)和云计算。通过这些术语的详细解释,帮助读者更好地理解和应用网络技术,应对数字化时代的挑战和机遇。
22 3
|
3天前
|
存储 网络协议 安全
30 道初级网络工程师面试题,涵盖 OSI 模型、TCP/IP 协议栈、IP 地址、子网掩码、VLAN、STP、DHCP、DNS、防火墙、NAT、VPN 等基础知识和技术,帮助小白们充分准备面试,顺利踏入职场
本文精选了 30 道初级网络工程师面试题,涵盖 OSI 模型、TCP/IP 协议栈、IP 地址、子网掩码、VLAN、STP、DHCP、DNS、防火墙、NAT、VPN 等基础知识和技术,帮助小白们充分准备面试,顺利踏入职场。
13 2
|
22天前
|
存储 安全 网络安全
网络安全的屏障与钥匙:漏洞防御与加密技术深度解析
【10月更文挑战第20天】在数字世界的迷宫中,网络安全是守护我们数据宝藏的坚固盾牌和锋利钥匙。本篇文章将带您穿梭于网络的缝隙之间,揭示那些潜藏的脆弱点—网络安全漏洞,同时探索如何通过现代加密技术加固我们的数字堡垒。从基本概念到实战策略,我们将一同揭开网络安全的神秘面纱,提升您的安全意识,保护个人信息不受侵犯。
51 25
|
17天前
|
边缘计算 自动驾驶 5G
|
11天前
|
SQL 安全 算法
网络安全的屏障与钥匙:漏洞防护与加密技术解析
【10月更文挑战第31天】在数字世界的海洋中,网络安全是航船的坚固屏障,而信息安全则是守护宝藏的金钥匙。本文将深入探讨网络安全的薄弱环节——漏洞,以及如何通过加密技术加固这道屏障。从常见网络漏洞的类型到最新的加密算法,我们不仅提供理论知识,还将分享实用的安全实践技巧,帮助读者构建起一道更加坚不可摧的防线。
21 1
|
1月前
|
机器学习/深度学习 网络架构 计算机视觉
目标检测笔记(一):不同模型的网络架构介绍和代码
这篇文章介绍了ShuffleNetV2网络架构及其代码实现,包括模型结构、代码细节和不同版本的模型。ShuffleNetV2是一个高效的卷积神经网络,适用于深度学习中的目标检测任务。
68 1
目标检测笔记(一):不同模型的网络架构介绍和代码