【论文笔记】图像修复MPRNet:Multi-Stage Progressive Image Restoration 含代码解析1

本文涉及的产品
全局流量管理 GTM,标准版 1个月
公共DNS(含HTTPDNS解析),每月1000万次HTTP解析
云解析 DNS,旗舰版 1个月
简介: 【论文笔记】图像修复MPRNet:Multi-Stage Progressive Image Restoration 含代码解析

一、介绍

       论文地址:https://arxiv.org/pdf/2102.02808.pdf

       代码地址:http://github.com/swz30/MPRNet

      恢复图像任务,需要在空间细节和高级上下文特征之间取得复杂的平衡。于是作者设计了一个多阶段的模型,模型首先使用编解码器架构来学习上下文的特征,然后将它们与保留局部信息的高分辨率分支结合起来。


       打个比方,我要修复一张蛇的图片,编解码器负责提取高级上下文特征,告诉模型要在蛇身上“画”鳞片,而不是羽毛或其他东西;然后高分辨率分支负责细化鳞片的图案。


       MPRNet细节很多,但最主要的创新还是“多阶段”,模型共有三个阶段,前两个阶段是编解码器子网络,用来学习较大感受野的上下文特征,最后一个阶段是高分辨率分支,用于在最终的输出图像中构建所需的纹理。作者给出了Deblurring、Denoising、Deraining三个任务的项目,三个项目的backbone是一样的,只是参数规模有所不同(Deblurring>Denoising>Deraining),下面我们以最大的Deblurring为例进行介绍。

二、使用方法

       MPRNet项目分为Deblurring、Denoising和Deraining  三个子项目。作者没有用稀奇古怪的库,也没用高级的编程技巧,非常适合拿来研究学习,使用方法也很简单,几句话技能说完。

1.推理

       (1)下载预训练模型:预训练模型分别存在三个子项目的pretrained_models文件夹,下载地址在每个pretrained_models文件夹的 README.md中,需要科学上网,我放在了网盘里:

       链接:https://pan.baidu.com/s/1sxfidMvlU_pIeO5zD1tKZg 提取码:faye

       (2)准备测试图片:将退化图片放在目录samples/input/中

       (3)执行demo.py

# 执行Deblurring
python demo.py --task  Deblurring
 
# 执行Denoising
python demo.py --task  Denoising
 
# 执行Deraining
python demo.py --task  Deraining

    (4)结果放在目录samples/output/中。

2.训练

       (1)根据Dataset文件夹内的README.md文件中的地址下载数据集。

       (2)查看training.yml是否需要修改,主要是最后的数据集地址。

       (3)执行训练

python train.py

三、MPRNet结构

       我将按照官方代码实现来介绍模型结构,一些重要模块的划分可能跟论文有区别,但是整体结构是一样的。

1.整体结构

       MPRNet官方给出的结构图如下:

图1

       这个图总体概括了MPRNet的结构,但是很多细节没有表现出来,通过阅读代码我给出更加详细的模型结构介绍。下面的图中输入统一512x512,我们以Deblurring为例,并且batch_size=1。

       整体结构图如下:

图2

       图中的三个Input都是原图,整个模型三个阶段,整体流程如下:

       1.1 输入图片采用multi-patch方式分成四份,分成左上、右上、左下、右下;


       1.2 每个patch经过一个3x3的卷积扩充维度,为的是后面能提取更丰富的特征信息;


       1.3 经过CAB(Channel Attention Block),利用注意力机制提取每个维度上的特征;


       1.4 Encoder,编码三种尺度的图像特征,提取多尺度上下文特征,同时也是提取更深层的语义特征;


       1.5 合并深特征,将四个batch的同尺度特征合并成左右两个尺度,送入Decoder;


       1.6 Decoder,提取合并后的每个尺度的特征;


       1.7 输入图片采用multi-patch方式分成两份,分成左、右;


       1.8 将左右两个batch分别与Stage1 Decoder输出的大尺度特征图送入SAM(Supervised Attention Module),SAM在训练的时候可以利用GT为当前阶段的恢复过程提供有用的控制信号;


       1.9 SAM的输出分成两部分,一部分是第二次输入的原图特征,它将继续下面的流程;一部分用于训练时的Stage1输出,可以利用GT更快更好的让模型收敛。


       2.0 经过Stage2的卷积扩充通道和CAB操作,将Stage1中的Decoder前后的特征送入Stage2的Encoder。


       2.2 经过和Stage1相似的Decoder,也产生两个部分的输出,一部分继续Stage3,一部分输出与GT算损失;


       3.1 Stage3的原图输入不在切分,目的是利用完整的上下文信息恢复图片细节。


       3.2 将原图经过卷积做升维处理;


       3.3 将Stage2中的Decoder前后的特征送入Stage3的ORSNet(Original Resolution Subnetwork),ORSNet不使用任何降采样操作,并生成空间丰富的高分辨率特征。


       3.4 最后经过一个卷积将维度降为3,输出。


代码实现:

#位置:MPRNet.py
class MPRNet(nn.Module):
    def __init__(self, in_c=3, out_c=3, n_feat=96, scale_unetfeats=48, scale_orsnetfeats=32, num_cab=8, kernel_size=3, reduction=4, bias=False):
        super(MPRNet, self).__init__()
 
        act=nn.PReLU()
        self.shallow_feat1 = nn.Sequential(conv(in_c, n_feat, kernel_size, bias=bias), CAB(n_feat,kernel_size, reduction, bias=bias, act=act))
        self.shallow_feat2 = nn.Sequential(conv(in_c, n_feat, kernel_size, bias=bias), CAB(n_feat,kernel_size, reduction, bias=bias, act=act))
        self.shallow_feat3 = nn.Sequential(conv(in_c, n_feat, kernel_size, bias=bias), CAB(n_feat,kernel_size, reduction, bias=bias, act=act))
 
        # Cross Stage Feature Fusion (CSFF)
        self.stage1_encoder = Encoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats, csff=False)
        self.stage1_decoder = Decoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats)
 
        self.stage2_encoder = Encoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats, csff=True)
        self.stage2_decoder = Decoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats)
 
        self.stage3_orsnet = ORSNet(n_feat, scale_orsnetfeats, kernel_size, reduction, act, bias, scale_unetfeats, num_cab)
 
        self.sam12 = SAM(n_feat, kernel_size=1, bias=bias)
        self.sam23 = SAM(n_feat, kernel_size=1, bias=bias)
        
        self.concat12  = conv(n_feat*2, n_feat, kernel_size, bias=bias)
        self.concat23  = conv(n_feat*2, n_feat+scale_orsnetfeats, kernel_size, bias=bias)
        self.tail     = conv(n_feat+scale_orsnetfeats, out_c, kernel_size, bias=bias)
 
    def forward(self, x3_img):
        # Original-resolution Image for Stage 3
        H = x3_img.size(2)
        W = x3_img.size(3)
 
        # Multi-Patch Hierarchy: Split Image into four non-overlapping patches
 
        # Two Patches for Stage 2
        x2top_img  = x3_img[:,:,0:int(H/2),:]
        x2bot_img  = x3_img[:,:,int(H/2):H,:]
 
        # Four Patches for Stage 1
        x1ltop_img = x2top_img[:,:,:,0:int(W/2)]
        x1rtop_img = x2top_img[:,:,:,int(W/2):W]
        x1lbot_img = x2bot_img[:,:,:,0:int(W/2)]
        x1rbot_img = x2bot_img[:,:,:,int(W/2):W]
 
        ##-------------------------------------------
        ##-------------- Stage 1---------------------
        ##-------------------------------------------
        ## Compute Shallow Features
        x1ltop = self.shallow_feat1(x1ltop_img)
        x1rtop = self.shallow_feat1(x1rtop_img)
        x1lbot = self.shallow_feat1(x1lbot_img)
        x1rbot = self.shallow_feat1(x1rbot_img)
        
        ## Process features of all 4 patches with Encoder of Stage 1
        feat1_ltop = self.stage1_encoder(x1ltop)
        feat1_rtop = self.stage1_encoder(x1rtop)
        feat1_lbot = self.stage1_encoder(x1lbot)
        feat1_rbot = self.stage1_encoder(x1rbot)
        
        ## Concat deep features
        feat1_top = [torch.cat((k,v), 3) for k,v in zip(feat1_ltop,feat1_rtop)]
        feat1_bot = [torch.cat((k,v), 3) for k,v in zip(feat1_lbot,feat1_rbot)]
        
        ## Pass features through Decoder of Stage 1
        res1_top = self.stage1_decoder(feat1_top)
        res1_bot = self.stage1_decoder(feat1_bot)
 
        ## Apply Supervised Attention Module (SAM)
        x2top_samfeats, stage1_img_top = self.sam12(res1_top[0], x2top_img)
        x2bot_samfeats, stage1_img_bot = self.sam12(res1_bot[0], x2bot_img)
 
        ## Output image at Stage 1
        stage1_img = torch.cat([stage1_img_top, stage1_img_bot],2) 
        ##-------------------------------------------
        ##-------------- Stage 2---------------------
        ##-------------------------------------------
        ## Compute Shallow Features
        x2top  = self.shallow_feat2(x2top_img)
        x2bot  = self.shallow_feat2(x2bot_img)
 
        ## Concatenate SAM features of Stage 1 with shallow features of Stage 2
        x2top_cat = self.concat12(torch.cat([x2top, x2top_samfeats], 1))
        x2bot_cat = self.concat12(torch.cat([x2bot, x2bot_samfeats], 1))
 
        ## Process features of both patches with Encoder of Stage 2
        feat2_top = self.stage2_encoder(x2top_cat, feat1_top, res1_top)
        feat2_bot = self.stage2_encoder(x2bot_cat, feat1_bot, res1_bot)
 
        ## Concat deep features
        feat2 = [torch.cat((k,v), 2) for k,v in zip(feat2_top,feat2_bot)]
 
        ## Pass features through Decoder of Stage 2
        res2 = self.stage2_decoder(feat2)
 
        ## Apply SAM
        x3_samfeats, stage2_img = self.sam23(res2[0], x3_img)
 
 
        ##-------------------------------------------
        ##-------------- Stage 3---------------------
        ##-------------------------------------------
        ## Compute Shallow Features
        x3     = self.shallow_feat3(x3_img)
 
        ## Concatenate SAM features of Stage 2 with shallow features of Stage 3
        x3_cat = self.concat23(torch.cat([x3, x3_samfeats], 1))
        
        x3_cat = self.stage3_orsnet(x3_cat, feat2, res2)
 
        stage3_img = self.tail(x3_cat)
 
        return [stage3_img+x3_img, stage2_img, stage1_img]

       图中还有一些模块细节没有表现出来,下面我将详细介绍。

2.CAB(Channel Attention Block)

       顾名思义,CAB就是利用注意力机制提取每个通道的特征,输出输入特征图形状不变,结构图如下:

图3

       可以看到,经过了两个卷积和GAP之后得到了一个概率图(就是那个残差边),在经过两个卷积和Sigmoid之后与概率图相乘,就实现了一个通道注意力机制。

       代码实现:

# 位置MPRNet.py
## Channel Attention Layer
class CALayer(nn.Module):
    def __init__(self, channel, reduction=16, bias=False):
        super(CALayer, self).__init__()
        # global average pooling: feature --> point
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # feature channel downscale and upscale --> channel weight
        self.conv_du = nn.Sequential(
                nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=bias),
                nn.ReLU(inplace=True),
                nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=bias),
                nn.Sigmoid()
        )
 
    def forward(self, x):
        y = self.avg_pool(x)
        y = self.conv_du(y)
        return x * y
 
 
##########################################################################
## Channel Attention Block (CAB)
class CAB(nn.Module):
    def __init__(self, n_feat, kernel_size, reduction, bias, act):
        super(CAB, self).__init__()
        modules_body = []
        modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias))
        modules_body.append(act)
        modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias))
 
        self.CA = CALayer(n_feat, reduction, bias=bias)
        self.body = nn.Sequential(*modules_body)
 
    def forward(self, x):
        res = self.body(x)
        res = self.CA(res)
        res += x
        return res


【论文笔记】图像修复MPRNet:Multi-Stage Progressive Image Restoration 含代码解析2:https://developer.aliyun.com/article/1507775

相关文章
|
26天前
|
设计模式 Java 关系型数据库
【Java笔记+踩坑汇总】Java基础+JavaWeb+SSM+SpringBoot+SpringCloud+瑞吉外卖/谷粒商城/学成在线+设计模式+面试题汇总+性能调优/架构设计+源码解析
本文是“Java学习路线”专栏的导航文章,目标是为Java初学者和初中高级工程师提供一套完整的Java学习路线。
219 37
|
23天前
|
敏捷开发 安全 测试技术
软件测试的艺术:从代码到用户体验的全方位解析
本文将深入探讨软件测试的重要性和实施策略,通过分析不同类型的测试方法和工具,展示如何有效地提升软件质量和用户满意度。我们将从单元测试、集成测试到性能测试等多个角度出发,详细解释每种测试方法的实施步骤和最佳实践。此外,文章还将讨论如何通过持续集成和自动化测试来优化测试流程,以及如何建立有效的测试团队来应对快速变化的市场需求。通过实际案例的分析,本文旨在为读者提供一套系统而实用的软件测试策略,帮助读者在软件开发过程中做出更明智的决策。
http数据包抓包解析课程笔记
http数据包抓包解析课程笔记
|
13天前
|
SQL 人工智能 机器人
遇到的代码部份解析
/ 模拟后端返回的数据
14 0
|
14天前
|
设计模式 存储 算法
PHP中的设计模式:策略模式的深入解析与应用在软件开发的浩瀚海洋中,PHP以其独特的魅力和强大的功能吸引了无数开发者。作为一门历史悠久且广泛应用的编程语言,PHP不仅拥有丰富的内置函数和扩展库,还支持面向对象编程(OOP),为开发者提供了灵活而强大的工具集。在PHP的众多特性中,设计模式的应用尤为引人注目,它们如同精雕细琢的宝石,镶嵌在代码的肌理之中,让程序更加优雅、高效且易于维护。今天,我们就来深入探讨PHP中使用频率颇高的一种设计模式——策略模式。
本文旨在深入探讨PHP中的策略模式,从定义到实现,再到应用场景,全面剖析其在PHP编程中的应用价值。策略模式作为一种行为型设计模式,允许在运行时根据不同情况选择不同的算法或行为,极大地提高了代码的灵活性和可维护性。通过实例分析,本文将展示如何在PHP项目中有效利用策略模式来解决实际问题,并提升代码质量。
|
2月前
|
开发者 图形学 Java
揭秘Unity物理引擎核心技术:从刚体动力学到关节连接,全方位教你如何在虚拟世界中重现真实物理现象——含实战代码示例与详细解析
【8月更文挑战第31天】Unity物理引擎对于游戏开发至关重要,它能够模拟真实的物理效果,如刚体运动、碰撞检测及关节连接等。通过Rigidbody和Collider组件,开发者可以轻松实现物体间的互动与碰撞。本文通过具体代码示例介绍了如何使用Unity物理引擎实现物体运动、施加力、使用关节连接以及模拟弹簧效果等功能,帮助开发者提升游戏的真实感与沉浸感。
42 1
|
2月前
|
开发者 图形学 API
从零起步,深度揭秘:运用Unity引擎及网络编程技术,一步步搭建属于你的实时多人在线对战游戏平台——详尽指南与实战代码解析,带你轻松掌握网络化游戏开发的核心要领与最佳实践路径
【8月更文挑战第31天】构建实时多人对战平台是技术与创意的结合。本文使用成熟的Unity游戏开发引擎,从零开始指导读者搭建简单的实时对战平台。内容涵盖网络架构设计、Unity网络API应用及客户端与服务器通信。首先,创建新项目并选择适合多人游戏的模板,使用推荐的网络传输层。接着,定义基本玩法,如2D多人射击游戏,创建角色预制件并添加Rigidbody2D组件。然后,引入网络身份组件以同步对象状态。通过示例代码展示玩家控制逻辑,包括移动和发射子弹功能。最后,设置服务器端逻辑,处理客户端连接和断开。本文帮助读者掌握构建Unity多人对战平台的核心知识,为进一步开发打下基础。
74 0
|
2月前
|
开发者 图形学 C#
揭秘游戏沉浸感的秘密武器:深度解析Unity中的音频设计技巧,从背景音乐到动态音效,全面提升你的游戏氛围艺术——附实战代码示例与应用场景指导
【8月更文挑战第31天】音频设计在游戏开发中至关重要,不仅能增强沉浸感,还能传递信息,构建氛围。Unity作为跨平台游戏引擎,提供了丰富的音频处理功能,助力开发者轻松实现复杂音效。本文将探讨如何利用Unity的音频设计提升游戏氛围,并通过具体示例代码展示实现过程。例如,在恐怖游戏中,阴森的背景音乐和突然的脚步声能增加紧张感;在休闲游戏中,轻快的旋律则让玩家感到愉悦。
51 0
|
2月前
|
开发者 图形学 C#
深度解密:Unity游戏开发中的动画艺术——Mecanim状态机如何让游戏角色栩栩如生:从基础设置到高级状态切换的全面指南,助你打造流畅自然的游戏动画体验
【8月更文挑战第31天】Unity动画系统是游戏开发的关键部分,尤其适用于复杂角色动画。本文通过具体案例讲解Mecanim动画状态机的使用方法及原理。我们创建一个游戏角色并设计行走、奔跑和攻击动画,详细介绍动画状态机设置及脚本控制。首先导入动画资源并添加Animator组件,然后创建Animator Controller并设置状态间的转换条件。通过编写C#脚本(如PlayerMovement)控制动画状态切换,实现基于玩家输入的动画过渡。此方法不仅适用于游戏角色,还可用于任何需动态动画响应的对象,增强游戏的真实感与互动性。
59 0
|
2月前
|
图形学 开发者
【Unity光照艺术手册】掌握这些技巧,让你的游戏场景瞬间提升档次:从基础光源到全局光照,打造24小时不间断的视觉盛宴——如何运用代码与烘焙创造逼真光影效果全解析
【8月更文挑战第31天】在Unity中,合理的光照与阴影设置对于打造逼真环境至关重要。本文介绍Unity支持的多种光源类型,如定向光、点光源、聚光灯等,并通过具体示例展示如何使用着色器和脚本控制光照强度,模拟不同时间段的光照变化。此外,还介绍了动态和静态阴影、全局光照及光照探针等高级功能,帮助开发者创造丰富多样的光影效果,提升游戏沉浸感。
41 0

热门文章

最新文章

推荐镜像

更多