【论文笔记】图像修复MPRNet:Multi-Stage Progressive Image Restoration 含代码解析2

本文涉及的产品
云解析 DNS,旗舰版 1个月
公共DNS(含HTTPDNS解析),每月1000万次HTTP解析
全局流量管理 GTM,标准版 1个月
简介: 【论文笔记】图像修复MPRNet:Multi-Stage Progressive Image Restoration 含代码解析

【论文笔记】图像修复MPRNet:Multi-Stage Progressive Image Restoration 含代码解析1:https://developer.aliyun.com/article/1507772

3.Stage1 Encoder

       Stage1和Stage1的Encoder有一些区别,所以分开介绍。Stage1 Encoder有一个输入和三个不同尺度的输出,为的是提取三个尺度的特征并为下面的尺度融合流程做准备;其中有多个CAB结构,可以更好的提取通道特征;下采样通过粗暴的Downsample实现,结构如下:

 图4

       代码实现:

# 位置MPRNet.py
class Encoder(nn.Module):
    def __init__(self, n_feat, kernel_size, reduction, act, bias, scale_unetfeats, csff):
        super(Encoder, self).__init__()
 
        self.encoder_level1 = [CAB(n_feat,                     kernel_size, reduction, bias=bias, act=act) for _ in range(2)]
        self.encoder_level2 = [CAB(n_feat+scale_unetfeats,     kernel_size, reduction, bias=bias, act=act) for _ in range(2)]
        self.encoder_level3 = [CAB(n_feat+(scale_unetfeats*2), kernel_size, reduction, bias=bias, act=act) for _ in range(2)]
 
        self.encoder_level1 = nn.Sequential(*self.encoder_level1)
        self.encoder_level2 = nn.Sequential(*self.encoder_level2)
        self.encoder_level3 = nn.Sequential(*self.encoder_level3)
 
        self.down12  = DownSample(n_feat, scale_unetfeats)
        self.down23  = DownSample(n_feat+scale_unetfeats, scale_unetfeats)
 
        # Cross Stage Feature Fusion (CSFF)
        if csff:
            self.csff_enc1 = nn.Conv2d(n_feat,                     n_feat,                     kernel_size=1, bias=bias)
            self.csff_enc2 = nn.Conv2d(n_feat+scale_unetfeats,     n_feat+scale_unetfeats,     kernel_size=1, bias=bias)
            self.csff_enc3 = nn.Conv2d(n_feat+(scale_unetfeats*2), n_feat+(scale_unetfeats*2), kernel_size=1, bias=bias)
 
            self.csff_dec1 = nn.Conv2d(n_feat,                     n_feat,                     kernel_size=1, bias=bias)
            self.csff_dec2 = nn.Conv2d(n_feat+scale_unetfeats,     n_feat+scale_unetfeats,     kernel_size=1, bias=bias)
            self.csff_dec3 = nn.Conv2d(n_feat+(scale_unetfeats*2), n_feat+(scale_unetfeats*2), kernel_size=1, bias=bias)
 
    def forward(self, x, encoder_outs=None, decoder_outs=None):
        enc1 = self.encoder_level1(x)
        if (encoder_outs is not None) and (decoder_outs is not None):
            enc1 = enc1 + self.csff_enc1(encoder_outs[0]) + self.csff_dec1(decoder_outs[0])
 
        x = self.down12(enc1)
 
        enc2 = self.encoder_level2(x)
        if (encoder_outs is not None) and (decoder_outs is not None):
            enc2 = enc2 + self.csff_enc2(encoder_outs[1]) + self.csff_dec2(decoder_outs[1])
 
        x = self.down23(enc2)
 
        enc3 = self.encoder_level3(x)
        if (encoder_outs is not None) and (decoder_outs is not None):
            enc3 = enc3 + self.csff_enc3(encoder_outs[2]) + self.csff_dec3(decoder_outs[2])
        
        return [enc1, enc2, enc3]

4.Stage2 Encoder

       Stage2 Encoder输入为三个,分别为上一层的输出和Stage1中的Decoder前后的特征。主流程(也就是左面竖着的那一列)和Stage1 Encoder是一样的。增加的两个输入,每个输入又分为三个尺度,每个尺度经过一个卷积层,然后相同尺度的特征图做特征融合,输出,结构如下:

图5

5.Decoder

       两个阶段的Decoder结构是一样的,所以放在一起说,有三个不用尺度的输入;通过CAB提取特征;小尺度特征通过上采样变大,通过卷积使通道变小;小尺度的特征图shape最终变成跟大尺度一样,通过残差边实现特征融合,结构如下:

图6

       代码实现:

# 位置:MPRNet.py
class Decoder(nn.Module):
    def __init__(self, n_feat, kernel_size, reduction, act, bias, scale_unetfeats):
        super(Decoder, self).__init__()
 
        self.decoder_level1 = [CAB(n_feat,                     kernel_size, reduction, bias=bias, act=act) for _ in range(2)]
        self.decoder_level2 = [CAB(n_feat+scale_unetfeats,     kernel_size, reduction, bias=bias, act=act) for _ in range(2)]
        self.decoder_level3 = [CAB(n_feat+(scale_unetfeats*2), kernel_size, reduction, bias=bias, act=act) for _ in range(2)]
 
        self.decoder_level1 = nn.Sequential(*self.decoder_level1)
        self.decoder_level2 = nn.Sequential(*self.decoder_level2)
        self.decoder_level3 = nn.Sequential(*self.decoder_level3)
 
        self.skip_attn1 = CAB(n_feat,                 kernel_size, reduction, bias=bias, act=act)
        self.skip_attn2 = CAB(n_feat+scale_unetfeats, kernel_size, reduction, bias=bias, act=act)
 
        self.up21  = SkipUpSample(n_feat, scale_unetfeats)
        self.up32  = SkipUpSample(n_feat+scale_unetfeats, scale_unetfeats)
 
    def forward(self, outs):
        enc1, enc2, enc3 = outs
        dec3 = self.decoder_level3(enc3)
 
        x = self.up32(dec3, self.skip_attn2(enc2))
        dec2 = self.decoder_level2(x)
 
        x = self.up21(dec2, self.skip_attn1(enc1))
        dec1 = self.decoder_level1(x)
 
        return [dec1,dec2,dec3]

6.SAM(Supervised Attention Module)

       SAM出现在两个阶段间,有两个输入,将上一层特征和原图作为输入,提升了特征提取的性能,,SAM作为有监督的注意模块,使用注意力图强力筛选了跨阶段间的有用特征。有两个输出,一个是经过了注意力机制的特征图,为下面的流程提供特征;一个是3通道的图片特征,为了训练阶段输出,结构如下:

图7

       代码位置:

# 位置MPRNet.py
## Supervised Attention Module
class SAM(nn.Module):
    def __init__(self, n_feat, kernel_size, bias):
        super(SAM, self).__init__()
        self.conv1 = conv(n_feat, n_feat, kernel_size, bias=bias)
        self.conv2 = conv(n_feat, 3, kernel_size, bias=bias)
        self.conv3 = conv(3, n_feat, kernel_size, bias=bias)
 
    def forward(self, x, x_img):
        x1 = self.conv1(x)
        img = self.conv2(x) + x_img
        x2 = torch.sigmoid(self.conv3(img))
        x1 = x1*x2
        x1 = x1+x
        return x1, img

7.ORSNet(Original Resolution Subnetwork)

       为了保留输入图像的细节,模型在最后一阶段引入了原始分辨率的子网(ORSNet:Original Resolution Subnetwork)。ORSNet不使用任何降采样操作,并生成空间丰富的高分辨率特征。它由多个原始分辨率块(BRB)组成,是模型的最后阶段,结构如下:

图8

       可以看到,输入为三个,分别为上一层的输出和Stage2中的Decoder前后的特征。后两个输入,每个输入又分为三个尺度,三个尺度的通道数都先变成96,然后在变成128;小尺度的size都变成和大尺度一样,最后做特征融合融合前会经过ORB(Original Resolution Block)模块。

       ORB由一连串的CAB组成,还有一个大的残差边,结构如下:

图9

       代码实现:

# 位置MPRNet.py
## Original Resolution Block (ORB)
class ORB(nn.Module):
    def __init__(self, n_feat, kernel_size, reduction, act, bias, num_cab):
        super(ORB, self).__init__()
        modules_body = []
        modules_body = [CAB(n_feat, kernel_size, reduction, bias=bias, act=act) for _ in range(num_cab)]
        modules_body.append(conv(n_feat, n_feat, kernel_size))
        self.body = nn.Sequential(*modules_body)
 
    def forward(self, x):
        res = self.body(x)
        res += x
        return res
 
##########################################################################
class ORSNet(nn.Module):
    def __init__(self, n_feat, scale_orsnetfeats, kernel_size, reduction, act, bias, scale_unetfeats, num_cab):
        super(ORSNet, self).__init__()
 
        self.orb1 = ORB(n_feat+scale_orsnetfeats, kernel_size, reduction, act, bias, num_cab)
        self.orb2 = ORB(n_feat+scale_orsnetfeats, kernel_size, reduction, act, bias, num_cab)
        self.orb3 = ORB(n_feat+scale_orsnetfeats, kernel_size, reduction, act, bias, num_cab)
 
        self.up_enc1 = UpSample(n_feat, scale_unetfeats)
        self.up_dec1 = UpSample(n_feat, scale_unetfeats)
 
        self.up_enc2 = nn.Sequential(UpSample(n_feat+scale_unetfeats, scale_unetfeats), UpSample(n_feat, scale_unetfeats))
        self.up_dec2 = nn.Sequential(UpSample(n_feat+scale_unetfeats, scale_unetfeats), UpSample(n_feat, scale_unetfeats))
 
        self.conv_enc1 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias)
        self.conv_enc2 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias)
        self.conv_enc3 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias)
 
        self.conv_dec1 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias)
        self.conv_dec2 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias)
        self.conv_dec3 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias)
 
    def forward(self, x, encoder_outs, decoder_outs):
        x = self.orb1(x)
        x = x + self.conv_enc1(encoder_outs[0]) + self.conv_dec1(decoder_outs[0])
 
        x = self.orb2(x)
        x = x + self.conv_enc2(self.up_enc1(encoder_outs[1])) + self.conv_dec2(self.up_dec1(decoder_outs[1]))
 
        x = self.orb3(x)
        x = x + self.conv_enc3(self.up_enc2(encoder_outs[2])) + self.conv_dec3(self.up_dec2(decoder_outs[2]))
 
        return x

四、损失函数

       MPRNet主要使用了两个损失函数CharbonnierLoss和EdgeLoss,公式如下:


        其中累加是因为训练的时候三个阶段都有输出,都需要个GT计算损失(如图2的三个output);该模型不是直接预测恢复的图像 ,而是预测残差图像 ,添加退化的输入图像 得到:  


     Deblurring和Deraining两个任务CharbonnierLoss和EdgeLoss做了加权求和,比例1:0.05;只使用了CharbonnierLoss,我感觉是因为这里使用的噪声是某种分布(入高斯分布、泊松分布)的噪声,不会引起剧烈的边缘差异,所以Denoising没有使用EdgeLoss。

       下面简单介绍一下两种损失。

1.CharbonnierLoss

       公式如下:


  CharbonnierLoss在零点附近由于常数 的存在,梯度不会变成零,避免梯度消失。函数曲线近似L1损失,相比L2损失而言,对异常值不敏感,避免过分放大误差。

       代码实现:

# 位置losses.py
class CharbonnierLoss(nn.Module):
    """Charbonnier Loss (L1)"""
 
    def __init__(self, eps=1e-3):
        super(CharbonnierLoss, self).__init__()
        self.eps = eps
 
    def forward(self, x, y):
        diff = x - y
        # loss = torch.sum(torch.sqrt(diff * diff + self.eps))
        loss = torch.mean(torch.sqrt((diff * diff) + (self.eps*self.eps)))
        return loss

2.EdgeLoss

       L1或者L2损失注重的是全局,没有很好地考虑一些显著特征的影响, 而显著的结构和纹理信息与人的主观感知效果高度相关,是不能忽视的。

       边缘损失主要考虑纹理部分的差异,可以很好地考虑高频的纹理结构信息, 提高生成图像的细节表现,公示如下:


    其中 表示Laplacian边缘检测中的核函数, 表示对 做边缘检测,公式中其他部分和CharbonnierLoss类似。

       代码实现:

# 位置losses.py
 
class EdgeLoss(nn.Module):
    def __init__(self):
        super(EdgeLoss, self).__init__()
        k = torch.Tensor([[.05, .25, .4, .25, .05]])
        self.kernel = torch.matmul(k.t(),k).unsqueeze(0).repeat(3,1,1,1)
        if torch.cuda.is_available():
            self.kernel = self.kernel.cuda()
        self.loss = CharbonnierLoss()
 
    def conv_gauss(self, img):
        n_channels, _, kw, kh = self.kernel.shape
        img = F.pad(img, (kw//2, kh//2, kw//2, kh//2), mode='replicate')
        return F.conv2d(img, self.kernel, groups=n_channels)
 
    def laplacian_kernel(self, current):
        filtered    = self.conv_gauss(current)    # filter
        down        = filtered[:,:,::2,::2]               # downsample
        new_filter  = torch.zeros_like(filtered)
        new_filter[:,:,::2,::2] = down*4                  # upsample
        filtered    = self.conv_gauss(new_filter) # filter
        diff = current - filtered
        return diff
 
    def forward(self, x, y):
        loss = self.loss(self.laplacian_kernel(x), self.laplacian_kernel(y))
        return loss
相关文章
|
3月前
|
机器学习/深度学习 人工智能 自然语言处理
Hugging Face 论文平台 Daily Papers 功能全解析
【9月更文挑战第23天】Hugging Face 是一个专注于自然语言处理领域的开源机器学习平台。其推出的 Daily Papers 页面旨在帮助开发者和研究人员跟踪 AI 领域的最新进展,展示经精心挑选的高质量研究论文,并提供个性化推荐、互动交流、搜索、分类浏览及邮件提醒等功能,促进学术合作与知识共享。
|
1月前
|
存储 安全 Java
系统安全架构的深度解析与实践:Java代码实现
【11月更文挑战第1天】系统安全架构是保护信息系统免受各种威胁和攻击的关键。作为系统架构师,设计一套完善的系统安全架构不仅需要对各种安全威胁有深入理解,还需要熟练掌握各种安全技术和工具。
91 10
|
1月前
|
前端开发 JavaScript 开发者
揭秘前端高手的秘密武器:深度解析递归组件与动态组件的奥妙,让你代码效率翻倍!
【10月更文挑战第23天】在Web开发中,组件化已成为主流。本文深入探讨了递归组件与动态组件的概念、应用及实现方式。递归组件通过在组件内部调用自身,适用于处理层级结构数据,如菜单和树形控件。动态组件则根据数据变化动态切换组件显示,适用于不同业务逻辑下的组件展示。通过示例,展示了这两种组件的实现方法及其在实际开发中的应用价值。
34 1
|
3月前
|
设计模式 Java 关系型数据库
【Java笔记+踩坑汇总】Java基础+JavaWeb+SSM+SpringBoot+SpringCloud+瑞吉外卖/谷粒商城/学成在线+设计模式+面试题汇总+性能调优/架构设计+源码解析
本文是“Java学习路线”专栏的导航文章,目标是为Java初学者和初中高级工程师提供一套完整的Java学习路线。
440 37
|
2月前
|
机器学习/深度学习 人工智能 算法
揭开深度学习与传统机器学习的神秘面纱:从理论差异到实战代码详解两者间的选择与应用策略全面解析
【10月更文挑战第10天】本文探讨了深度学习与传统机器学习的区别,通过图像识别和语音处理等领域的应用案例,展示了深度学习在自动特征学习和处理大规模数据方面的优势。文中还提供了一个Python代码示例,使用TensorFlow构建多层感知器(MLP)并与Scikit-learn中的逻辑回归模型进行对比,进一步说明了两者的不同特点。
74 2
|
2月前
|
存储 搜索推荐 数据库
运用LangChain赋能企业规章制度制定:深入解析Retrieval-Augmented Generation(RAG)技术如何革新内部管理文件起草流程,实现高效合规与个性化定制的完美结合——实战指南与代码示例全面呈现
【10月更文挑战第3天】构建公司规章制度时,需融合业务实际与管理理论,制定合规且促发展的规则体系。尤其在数字化转型背景下,利用LangChain框架中的RAG技术,可提升规章制定效率与质量。通过Chroma向量数据库存储规章制度文本,并使用OpenAI Embeddings处理文本向量化,将现有文档转换后插入数据库。基于此,构建RAG生成器,根据输入问题检索信息并生成规章制度草案,加快更新速度并确保内容准确,灵活应对法律与业务变化,提高管理效率。此方法结合了先进的人工智能技术,展现了未来规章制度制定的新方向。
36 3
|
2月前
|
编解码 算法 测试技术
Imagen论文简要解析
Imagen论文简要解析
35 0
|
2月前
|
SQL 监控 关系型数据库
SQL错误代码1303解析与处理方法
在SQL编程和数据库管理中,遇到错误代码是常有的事,其中错误代码1303在不同数据库系统中可能代表不同的含义
|
2月前
|
SQL 安全 关系型数据库
SQL错误代码1303解析与解决方案:深入理解并应对权限问题
在数据库管理和开发过程中,遇到错误代码是常见的事情,每个错误代码都代表着一种特定的问题
|
3月前
|
敏捷开发 安全 测试技术
软件测试的艺术:从代码到用户体验的全方位解析
本文将深入探讨软件测试的重要性和实施策略,通过分析不同类型的测试方法和工具,展示如何有效地提升软件质量和用户满意度。我们将从单元测试、集成测试到性能测试等多个角度出发,详细解释每种测试方法的实施步骤和最佳实践。此外,文章还将讨论如何通过持续集成和自动化测试来优化测试流程,以及如何建立有效的测试团队来应对快速变化的市场需求。通过实际案例的分析,本文旨在为读者提供一套系统而实用的软件测试策略,帮助读者在软件开发过程中做出更明智的决策。