超好用!图像去雾算法C2PNet介绍与使用指南

简介: 超好用!图像去雾算法C2PNet介绍与使用指南


引言

本文主要介绍一个开源的C2PNet去雾算法的使用。这篇文章主要研究了单图像去雾问题,并提出了一个新的去雾网络C2PNet。C2PNet使用了课程对比正则化和物理感知的双分支单元来提高去雾模型的解释性和性能。文章首先介绍了去雾问题的挑战,然后详细阐述了C2PNet的设计原理和实现方法,包括物理感知双分支单元和共识负样本对比正则化。最后,通过在合成数据集和真实世界数据集上的定量评估,证明了C2PNet在去雾性能上的优越性。

去雾效果

基本原理

C2PNet(Curricular Contrastive Regularization for Physics-aware Single Image Dehazing Network)的设计原理和实现方法如下:

1. 物理感知双分支单元(Physics-aware Dual-branch Unit, PDU):

  • 设计背景:传统的单张图像去雾方法要么在原始空间直接估计未知因素(传输图和大气光),要么在特征空间中忽略这些物理特性。为了结合物理模型的优势并避免累积误差,设计了PDU。
  • 原理:基于大气散射模型,PDU将传输图和大气光的估计任务分配给两个并行的分支,每个分支分别学习与相应物理因素相关的特征表示。这样可以更精确地合成符合物理模型的潜在清晰图像。
  • 实现:通过一系列卷积层和非线性激活函数,两个分支各自提取图像特征,然后通过加权求和的方式结合这些特征,最终得到去雾化的图像输出。

2. 共识负样本对比正则化(Consensual Contrastive Regularization, CR):

  • 设计背景:为了提高特征空间的可解释性和引导网络学习更有区分度的特征表示,引入了基于对比学习的正则化方法。
  • 原理:对比学习的核心思想是区分正样本对(锚点与其正样本)和负样本对(锚点与其负样本)。在去雾任务中,正样本对是指同一场景下的清晰图像和模糊图像,负样本对则是不同场景下的清晰图像。通过最小化正样本之间的距离和最大化负样本之间的距离,可以约束解决方案的空间,从而提高去雾效果。
  • 实现:在训练过程中,对于每个清晰图像和对应的模糊图像对,网络会学习一个对比损失函数。这个损失函数会随着训练进程动态调整,以平衡正样本和负样本的贡献。具体来说,容易区分的样本(例如,PSNR大于30的样本)会被视为“容易的”样本,而其他样本则被视为“非容易的”样本,并给予更高的权重。这样,网络就会首先学习到容易样本的特征,然后再逐渐聚焦于更难以区分的样本,从而实现一种渐进式的学习策略。

3. 网络整体结构:

  • C2PNet由多个PDU模块串联而成,形成一个多阶段的去雾网络。每个PDU模块负责处理输入图像的不同分辨率版本,从而逐步恢复出高分辨率的清晰图像。在每个PDU模块中,都会对输入图像执行上采样操作,以逐渐重建出全分辨率的清晰图像。

4. 训练策略:

  • 除了传统的L1损失用于直接衡量网络预测的去雾图像和真实清晰图像之间的差异外,C2PNet还采用了CR作为正则化手段来提升特征学习的质量。
  • 在训练过程中,网络会为每个清晰图像和其对应的模糊图像对生成对比损失。同时,为了使网络能够从难易程度不同的样本中学习到有效的特征表示,网络会根据样本的难度动态调整对比损失的权重。

5. 实现细节:

  • C2PNet使用PyTorch 1.11.0在NVIDIA RTX 3090 GPU上实现。
  • 为了评估C2PNet的效果,论文中使用了多个合成数据集和真实世界的去雾数据集,并与其他几种先进的去雾算法进行了比较。实验结果表明,C2PNet在各种数据集上均取得了领先的性能。

总结,C2PNet的设计原理在于结合物理模型和对比学习正则化来提升单张图像去雾的效果。通过物理感知双分支单元对传输图和大气光进行分别建模,以及通过共识负样本对比正则化引导网络学习更加鲁棒和有区分度的特征表示,C2PNet能够在去雾任务中取得显著的性能提升。

模型结果对比

模型使用完整代码

我们直接使用onnx模型进行图片去雾推理,将图片对比结果存入results目录中:

import argparse
import os
import cv2
import onnxruntime
import numpy as np
class C2PNet:
    def __init__(self, modelpath):
        # Initialize model
        self.onnx_session = onnxruntime.InferenceSession(modelpath)
        self.input_name = self.onnx_session.get_inputs()[0].name
        _, _, self.input_height, self.input_width = self.onnx_session.get_inputs()[0].shape
    def detect(self, image):
        input_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if isinstance(self.input_height ,int) and isinstance(self.input_width, int):
            input_image = cv2.resize(input_image, (self.input_width, self.input_height)) ###固定输入分辨率, HXW.onnx文件是动态输入分辨率的
        input_image = input_image.astype(np.float32) / 255.0
        input_image = input_image.transpose(2, 0, 1)
        input_image = np.expand_dims(input_image, axis=0)
        result = self.onnx_session.run(None, {self.input_name: input_image}) ###opencv-dnn推理时,结果图全黑
        
        # Post process:squeeze, RGB->BGR, Transpose, uint8 cast
        output_image = np.squeeze(result[0])
        output_image = output_image.transpose(1, 2, 0)
        output_image = output_image * 255
        output_image = np.clip(output_image, 0, 255)
        output_image = output_image.astype(np.uint8)
        output_image = cv2.cvtColor(output_image.astype(np.uint8), cv2.COLOR_RGB2BGR)
        output_image = cv2.resize(output_image, (image.shape[1], image.shape[0]))
        return output_image
if __name__ == '__main__':
    path = 'testimgs/outdoor'
    for each in os.listdir(path):
        parser = argparse.ArgumentParser()
        parser.add_argument('--imgpath', type=str,
                            default=os.path.join(path, each), help="image path")
        parser.add_argument('--modelpath', type=str,
                            default='weights/c2pnet_outdoor_HxW.onnx', help="onnx path")
        args = parser.parse_args()
        mynet = C2PNet(args.modelpath)
        srcimg = cv2.imread(args.imgpath)
        dstimg = mynet.detect(srcimg)
        if srcimg.shape[0] > srcimg.shape[1]:
            boundimg = np.zeros((10, srcimg.shape[1], 3), dtype=srcimg.dtype)+255  ###中间分开原图和结果
            combined_img = np.vstack([srcimg, boundimg, dstimg])
        else:
            boundimg = np.zeros((srcimg.shape[0], 10, 3), dtype=srcimg.dtype)+255
            combined_img = np.hstack([srcimg, boundimg, dstimg])
        cv2.imwrite(os.path.join('results',each), combined_img)
        winName = 'Deep learning Image Dehaze use onnxruntime'
        cv2.namedWindow(winName, 0)
        cv2.imshow(winName, combined_img)  ###原图和结果图也可以分开窗口显示
        cv2.waitKey(0)
        cv2.destroyAllWindows()

运行后,保存结果如下:


相关文章
|
1月前
|
机器学习/深度学习 算法 机器人
【水下图像增强融合算法】基于融合的水下图像与视频增强研究(Matlab代码实现)
【水下图像增强融合算法】基于融合的水下图像与视频增强研究(Matlab代码实现)
195 0
|
3月前
|
编解码 算法
改进SIFT算法实现光学图像和SAR图像配准
改进SIFT算法实现光学图像和SAR图像配准
|
1月前
|
机器学习/深度学习 算法 自动驾驶
基于导向滤波的暗通道去雾算法在灰度与彩色图像可见度复原中的研究(Matlab代码实现)
基于导向滤波的暗通道去雾算法在灰度与彩色图像可见度复原中的研究(Matlab代码实现)
146 8
|
2月前
|
存储 监控 算法
基于文化优化算法图像量化(Matlab代码实现)
基于文化优化算法图像量化(Matlab代码实现)
101 1
|
2月前
|
存储 算法 生物认证
基于Zhang-Suen算法的图像细化处理FPGA实现,包含testbench和matlab验证程序
本项目基于Zhang-Suen算法实现图像细化处理,支持FPGA与MATLAB双平台验证。通过对比,FPGA细化效果与MATLAB一致,可有效减少图像数据量,便于后续识别与矢量化处理。算法适用于字符识别、指纹识别等领域,配套完整仿真代码及操作说明。
|
2月前
|
机器学习/深度学习 监控 并行计算
【图像增强】局部对比度增强的CLAHE算法直方图增强研究(Matlab代码实现)
【图像增强】局部对比度增强的CLAHE算法直方图增强研究(Matlab代码实现)
302 0
|
4月前
|
机器学习/深度学习 监控 算法
基于单尺度Retinex和多尺度Retinex的图像增强算法实现
基于单尺度Retinex(SSR)和多尺度Retinex(MSR)的图像增强算法实现
416 1
|
4月前
|
存储 算法 数据安全/隐私保护
基于FPGA的图像退化算法verilog实现,分别实现横向和纵向运动模糊,包括tb和MATLAB辅助验证
本项目基于FPGA实现图像运动模糊算法,包含横向与纵向模糊处理流程。使用Vivado 2019.2与MATLAB 2022A,通过一维卷积模拟点扩散函数,完成图像退化处理,并可在MATLAB中预览效果。
|
4月前
|
监控 算法 决策智能
基于盲源分离与贝叶斯非局部均值的图像降噪算法
基于盲源分离与贝叶斯非局部均值的图像降噪算法
155 0
|
5月前
|
算法 数据安全/隐私保护
基于混沌加密的遥感图像加密算法matlab仿真
本项目实现了一种基于混沌加密的遥感图像加密算法MATLAB仿真(测试版本:MATLAB2022A)。通过Logistic映射与Baker映射生成混沌序列,对遥感图像进行加密和解密处理。程序分析了加解密后图像的直方图、像素相关性、信息熵及解密图像质量等指标。结果显示,加密图像具有良好的随机性和安全性,能有效保护遥感图像中的敏感信息。该算法适用于军事、环境监测等领域,具备加密速度快、密钥空间大、安全性高的特点。

热门文章

最新文章