RPN(Region Proposal Networks)候选区域网络算法解析(附PyTorch代码)

简介: RPN(Region Proposal Networks)候选区域网络算法解析(附PyTorch代码)

0. 前言

按照国际惯例,首先声明:本文只是我自己学习的理解,虽然参考了他人的宝贵见解及成果,但是内容可能存在不准确的地方。如果发现文中错误,希望批评指正,共同进步。


本文基于论文Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks对RPN候选区域网络进行解析说明,并基于PyTorch库对RPN网络进行编程。


创作本文的背景:


在我的专栏【PyTorch实例实战演练】中,已经有了4篇左右文章专门介绍图像分类相关的深度学习模型算法,对这类任务已经比较熟悉了。更进一步地,我想再学习掌握目标检测任务。


目标检测任务相比图像分类任务更加复杂,因为前者不仅要像后者一样找到特征并分类,而且还要定位不同特征的位置。由于图像分类算法已经有比较成熟的研究基础,因此目标检测任务的难点不在于分类,而在于定位!


1. 区域候选算法(RPA,Region Proposal Algorithm)


为了定位图像中的目标分类物体位置,我们可以在整个图像或者放大后的部分图像上移动一个固定大小的矩形窗口,对于每一个窗口位置,都会运行分类器来判断窗口内是否包含目标分类的对象。这种方法称为滑动窗口法,滑动窗口法是一种早期广泛应用于目标检测的传统计算机视觉技术。


滑动窗口法的缺点十分明显——过于简单粗暴!会在大量的没有意义的位置上浪费检测目标的时间。如果不采用一些“聪明的”方法先预选一些比较有可能出现检测目标的位置,或者说如果不让分类器“注意力”集中在可能出现检测目标的位置,实时的目标检测方法就不可能,我们现在诸如依赖机器视觉的自动驾驶等技术也就也可能。


我们把此类能“预选可能出现检测目标位置”的算法称为RPA(Region Proposal Algorithm)区域候选算法,常见的RPA有:


Region Proposal Algorithms (RPA) 是计算机视觉领域中的重要组件,它们主要用于目标检测任务中生成候选区域(region proposals),这些候选区域可能是潜在的对象位置。以下是几个常见的Region Proposal方法:


  1. Selective Search, SS:通过图像金字塔和图像分块策略,进行一系列的过分割、合并操作,生成高质量的候选框。
  2. Edge Boxes, EB:该算法根据边界框内的边缘密度来生成候选框,倾向于选择包含显著物体边界的框。
  3. Region Proposal Network, RPN:首次在Faster R-CNN中引入,RPN是一个全卷积网络,它可以并行地生成多个候选框及其相应的置信度评分,极大地提高了目标检测的速度和性能。


本文要介绍的算法就是RPN。


2. 区域候选网络(RPN,Region Proposal Algorithm)

在正式详细说明RPN之前,我想先总结下Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks这篇文章的几点精华,提炼要点:


  • RPN是一个全卷积网络,能同时预测候选区域的边界置信度
  • RPN的提出目的是为了提升目标检测的效率(速度),使用RPN后与检测任务相比,候选区域的选择几乎是不耗费计算资源(nearly cost-free);
  • RPN能提升目标检测效率的原因是RPN与分类器共用要检测的原图的卷积特征图
  • 首次提出锚框(anchor)的概念,在锚框提出之前候选框的选择只能通过各种尺寸的图像(pyramids of images)或者各种尺寸的算子(pyramids of filters),这也是为了提升目标检测的效率。


2.1 RPN背景说明

在当时(2016年),先进的目标检测网络如SPPnet、Fast R-CNN提出后,减少了目标识别任务的计算时间,使得区域候选成为了计算瓶颈。当时已有的区域候选算法:Selective Search, SS仅能做到2s完成一张图像的区域候选;EdgeBoxes虽然更快,能做到0.2s完成,但是这是以牺牲计算质量为代价的。


通过发现:给识别器(detector)做图像识别的卷积特征图也可以共用于生成候选区域,这样做可以减少计算生成候选区域的时间(少到每张图像的候选区域生成仅需要10ms左右),于是基于此提出了RPN网络。使用RPN区域候选+深度学习图像分类组成的目标检测识别算法能做到每秒完成5个目标检测任务,大大提升了检测速率。


2.2 RPN架构

RPN在目标检测算法Faster R-CNN中的架构如下图:

Faster R-CNN = Fast R-CNN + RPN

RPN是整个 Faster R-CNN 模型的重要组成部分,原始图像首先通过一系列卷积层进行特征提取和处理得到特征图,RPN基于该特征图生成可能包含感兴趣对象的候选区域,以便后续的分类器能把“注意力”集中在这些后续区域内,提升整个目标检测任务的效率。


然后我们再来看下RPN本身的架构:

RPN是一个全卷积网络,模型计算过程如下:

  1. 使用一个n×n滑动窗口(sliding window)在共享的卷积特征图上滑动并提取特征。每个滑动窗口的中间位置都对应有k个锚框(anchors),在本文中使用3种缩放比例×3种长宽比,共k=9种锚框
  2. 对于每一个滑动窗口,通过卷积计算将其转换为一个固定维度向量(例如基础CNN为ZFnet时,该向量长度为256;基础CNN为VGG时,该向量长度为512),通常称为"中间层"或"隐藏层";
  3. 然后使用这个中间层作为输入,分别经过两个孪生全连接层分支(two sibling fully connected layers):分类层和回归层;
  4. 分类层输出的是每个格子属于各个类别(如人、车等)的概率分布,需要注意的是这里输出的是有物体和无物体,即正样本(前景)和负样本(背景)两个概率,因此输出长度为2k;而回归层则输出的是每个格子中心点相对于真实边界框偏移的距离值,中心坐标偏移加上长宽偏移共4个值,因此输出长度为4k

当然,分类层也可以只输出有物体的概率,输出长度为k。原文有一行注释:For simplicity we implement the cls layer as a two-class softmax layer. Alternatively, one may use logistic regression to produce k scores.

RPN的一个重要特性就是平移不变性(translation invariant),带来的好处就是减小模型,更小的参数量能进一步提高计算速度。

2.2 RPN的损失函数

RPN的损失函数为:

  • 为归一化参数,λ为平衡分类损失和回归损失的系数;
  • 为分类损失函数, 为回归损失函数;
  • 角标i代表第i个锚框;
  • 为模型预测的第i个锚框中有物体的概率;
  • 为第i个锚框中是否有物体的ground truth,当①锚框与真实物体的IoU=0.7 或②锚框与真实物体的IoU达到最大为正样本, 。当③锚框与真实物体的IoU<0.3(且不为最大)为负样本,

如果 不满足①②③中任意条件,比如 为0.4,且不为最大值怎么办呢?

答:这种样本不会被选来训练。在训练RPN的时候,每张图像会随机选取256个锚框进行训练,其中正样本和负样本为1:1,选取的样本肯定满足①②③条件之一。

  • 是一个四维向量,代表预测的bounding box相对于第i个锚框x, y, w, h的偏置;
  • 是一个四维向量,代表ground truth相对于第i个锚框x, y, w, h的偏置;

其中

其中

x, y 为框的中心点坐标,w, h为框的宽和高。无角标、角标为a、角标为*分别代表预测的bounding box、锚框、ground truth的框的几何尺寸。

2.3 训练RPN

在Faster R-CNN的目标检测框架中,交替训练(Alternating training)是一种分阶段优化RPN(Region Proposal Network)和Fast R-CNN检测器的策略。以下是详细的交替训练过程:

第一步:初始化与训练RPN

首先,仅使用卷积特征层训练RPN网络。RPN是一个全卷积网络,它直接从基础CNN网络(如VGG或ResNet)提取的特征图上生成候选区域(region proposals)。RPN通过学习调整一组预先定义好的锚框来预测每个锚框是否包含对象以及调整其边界框回归参数。

第二步:使用RPN生成提议并训练Fast R-CNN

使用第一步训练得到的RPN生成大量的候选区域提议,并从中采样固定数量的高质量提议用于训练Fast R-CNN分类器和边框回归器。Fast R-CNN是一个端到端可训练的模型,它对这些候选区域进行分类和精调边界框位置。

第三步:使用Fast R-CNN初始化RPN并重新训练

将经过Fast R-CNN训练后更新了卷积特征层的整个网络(包括基础CNN部分)作为初始化权重,再次启动RPN的训练。在这个迭代过程中,Fast R-CNN的部分被冻结(即保持参数不变),只训练RPN部分,利用更高质量的特征来进一步优化RPN的提议生成能力。

第四步:重复迭代

这个过程可以多次迭代,每次都是将更新后的模型反过来影响RPN的训练,然后再用改进后的RPN提议去提升Fast R-CNN的表现,直到两个子网络收敛达到较好的联合性能。

通过这种方式,RPN和Fast R-CNN可以互相促进,共同优化目标检测的整体性能。在实际训练过程中,可能需要几个交替迭代周期才能使得RPN产生的提议质量和最终检测结果达到最优状态。

3. 基于PyTorch框架的RPN

注意:由于RPN与其他网络模型关系十分密切,难易独立分割开来。以下代码只是一个非常基础的示意,并未包含诸如锚框生成、前向传播中的空间尺寸变换、损失函数定义以及后处理步骤(如非极大值抑制NMS)等内容。在实际项目中,请参考Faster R-CNN论文或其他开源实现来完善整个RPN模块的功能。

import torch
import torch.nn as nn
from torchvision.models import vgg16
import torch.nn.functional as F
 
# RPN网络定义
class RPN(nn.Module):  #softmax???   长宽计算方法??
    def __init__(self, in_channels=512):
        super(RPN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, 512, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.cls_score = nn.Conv2d(512, 2 * 9, kernel_size=1)  #k=9
        self.bbox_pred = nn.Conv2d(512, 4 * 9, kernel_size=1)
 
    def forward(self, x):
        x = self.relu(self.conv1(x))
        rpn_cls_score = self.cls_score(x)  #生成预测正负样本(有无物体的概率)[pos_label, neg_label]
        rpn_cls_score_softmax = F.softmax(rpn_cls_score)
        rpn_bbox_pred = self.bbox_pred(x)  #生成bounding box的[tx, ty, tw, th]
        return rpn_cls_score_softmax, rpn_bbox_pred
 
# 使用预训练的VGG16作为特征提取器
backbone = vgg16(pretrained=True).features[:-1]  # 模型微调,去掉最后一个池化层
rpn = RPN(in_channels=512)  # VGG16最后一层输出维度为512
 
# 定义损失函数和优化器
cls_loss_func = nn.CrossEntropyLoss()
bbox_loss_func = nn.SmoothL1Loss()
optimizer = torch.optim.SGD(rpn.parameters(), lr=0.001, momentum=0.9)
 
# 假设我们有一个训练数据加载器,格式为images, (gt_boxes, gt_labels)
data_loader = ...
 
# 训练过程,这里仅说明RPN的训练过程,即交替训练的第一步!
num_epochs = 10
for epoch in range(num_epochs):
    for images, (gt_boxes, gt_labels) in enumerate(data_loader):
        # 前向传播
        features = backbone(images)
        rpn_cls_scores_softmax, rpn_bbox_preds = rpn(features)
 
        # 数据预处理,将预测结果调整到与ground truth相匹配的格式
        # 这部分会根据你的具体实现有所不同,这里仅作示例
        rpn_cls_scores_view = rpn_cls_scores_softmax.permute(0, 2, 3, 1).contiguous().view(-1, 2)
        rpn_bbox_preds_view = rpn_bbox_preds.permute(0, 2, 3, 1).contiguous().view(-1, 4)
        gt_labels_view = gt_labels.view(-1)
 
        # 计算损失
        rpn_cls_loss = cls_loss_func(rpn_cls_scores_view, gt_labels_view)
        rpn_bbox_loss = bbox_loss_func(rpn_bbox_preds_view, gt_boxes)
 
        # 总损失,这里暂时忽略Ncls,Nreg,λ
        loss = rpn_cls_loss + rpn_bbox_loss
 
        # 反向传播和优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
 
        # 打印训练信息
        if (images + 1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{images+1}/{len(data_loader)}], Loss: {loss.item()}')
 
# 训练结束
print('Training finished.')


相关文章
|
6天前
|
机器学习/深度学习 数据采集 自然语言处理
理解并应用机器学习算法:神经网络深度解析
【5月更文挑战第15天】本文深入解析了神经网络的基本原理和关键组成,包括神经元、层、权重、偏置及损失函数。介绍了神经网络在图像识别、NLP等领域的应用,并涵盖了从数据预处理、选择网络结构到训练与评估的实践流程。理解并掌握这些知识,有助于更好地运用神经网络解决实际问题。随着技术发展,神经网络未来潜力无限。
|
1天前
|
机器学习/深度学习 存储 并行计算
深入解析xLSTM:LSTM架构的演进及PyTorch代码实现详解
xLSTM的新闻大家可能前几天都已经看过了,原作者提出更强的xLSTM,可以将LSTM扩展到数十亿参数规模,我们今天就来将其与原始的lstm进行一个详细的对比,然后再使用Pytorch实现一个简单的xLSTM。
16 2
|
2天前
|
域名解析 网络协议 网络性能优化
如何提升自建DNS服务下的网络体验
网络质量和网络体验是通信过程中的两个不同层面,质量涉及设备上下行表现,而体验关乎端到端通信效果。衡量质量常用带宽、延迟、丢包率等指标;体验则关注可访问性,DNS解析速度和服务位置等。现代路由器能自动调整网络质量,普通用户无需过多干预。自建DNS服务时,选择权威DNS能解决可访问性,但可能不提供最优体验。AdguardHome和Clash等工具能进一步优化DNS解析,提升网络体验。
26 6
如何提升自建DNS服务下的网络体验
|
3天前
|
Linux 网络安全
CentOS系统openssh-9,网络安全大厂面试真题解析大全
CentOS系统openssh-9,网络安全大厂面试真题解析大全
|
3天前
|
Linux 网络安全 Windows
网络安全笔记-day8,DHCP部署_dhcp搭建部署,源码解析
网络安全笔记-day8,DHCP部署_dhcp搭建部署,源码解析
|
3天前
|
运维 网络协议 Linux
Docker网络_docker 网络,来看看这份超全面的《Linux运维面试题及解析》
Docker网络_docker 网络,来看看这份超全面的《Linux运维面试题及解析》
|
4天前
|
机器学习/深度学习 JSON PyTorch
图神经网络入门示例:使用PyTorch Geometric 进行节点分类
本文介绍了如何使用PyTorch处理同构图数据进行节点分类。首先,数据集来自Facebook Large Page-Page Network,包含22,470个页面,分为四类,具有不同大小的特征向量。为训练神经网络,需创建PyTorch Data对象,涉及读取CSV和JSON文件,处理不一致的特征向量大小并进行归一化。接着,加载边数据以构建图。通过`Data`对象创建同构图,之后数据被分为70%训练集和30%测试集。训练了两种模型:MLP和GCN。GCN在测试集上实现了80%的准确率,优于MLP的46%,展示了利用图信息的优势。
10 1
|
4天前
|
机器学习/深度学习 PyTorch 算法框架/工具
神经网络基本概念以及Pytorch实现,多线程编程面试题
神经网络基本概念以及Pytorch实现,多线程编程面试题
|
5天前
|
机器学习/深度学习 算法 Go
YOLOv5网络结构解析
YOLOv5网络结构解析
|
6天前
|
机器学习/深度学习 存储 算法
卷积神经网络(CNN)的数学原理解析
卷积神经网络(CNN)的数学原理解析
35 1
卷积神经网络(CNN)的数学原理解析