手机图像去噪

简介: NAFNet是一个没有激活函数的神经网络。目前,在图像修复领域有多种类型的网络结构。

NAFNet图像修复方法介绍


NAFNet是一个没有激活函数的神经网络。目前,在图像修复领域有多种类型的网络结构。


第一,多阶段的结构。如图所示,两个u型网络中间做连接,输入尺寸为。U型网络会在左侧做降采样,特征会变为(H/2)*(W/2),再做一次降采样变为(H/4)*(W/4),然后在解码阶段将分辨率进行还原。多级网络看起来复杂,但实际实现较为简单。


第二,多尺度融合结构。它与多阶段结构的最大区别在于,多阶段结构的输入只有,但它有三种尺寸的输入,可提高网络的表达,输入的信息越多,恢复能力越强。整体结构依然属于U型网络。


第三,UNet结构。其最为简单但有效,只有输入、输出,U型结构中间有跳连接。


Block是上述三种结构的基本单元,多由卷积组成。


最左侧Restormer架构的输入是LayerNorm层,接着是一个标准的self-attention结构(3个1x1的卷积分别提取key、query、value特征,key与query特征计算相似度后经过softmax归一化)。整体的大结构上做跳连接(残差连接),下半部分也是残差结构。


PlainNet的结构更简单,由两个标准残差模块组成。Baseline相对于PlainNet,其模仿了transformer,添加了LayerNorm层,将激活函数由原先的ReLU换为GELU。


NAFNet相较于Baseline,用SimpleGate取代了CELU,用SCA取代了CA。


Channel Attention(CA)主要对输入的特征图(H、C、W的3D形状)在H和W维度上求平均值,变为向量。之后接了两个卷积再经过激活函数变为新的向量,值为0到1(也可理解为权重),将权重重新乘回输入特征,相当于在输入特征的通道维度上做了重新加权,得到新特征。


而Simple Channel Attention(SCA)取消了原先Channel Attention中间2层的卷积以及激活函数,替换为1x1的卷积操作。


Simple Gate会将HxWxC的输入平均分为两份C/2xHxW,将两者相乘得到新的特征,并以此替代激活函数。


手机图像去噪训练及评估



手机图像去噪的数据集是成对的(有噪图和清晰图),本文使用的数据集是智能手机图像去噪训练数据集SIDD(Smart Image Denoising Dataset)。训练数据的格式是常见的图片格式PNG,共有320对图,噪声图和干净图一一对应。


进入ModelScope官网,搜索“数据集-智能手机图像去噪数据集”。


点击数据预览,default子数据集下,左侧为有噪图,右侧为清晰图。


另外,我们还提供了crops子数据集,主要供训练使用。原始的320对图对分辨率较大,数据加载时会存在读取瓶颈,因此,我们将大分辨率的图裁剪为图像用于训练,提高加载速度,也提高训练速度。


去噪的核心代码如上图。图片地址可以是本地路径,也可以是网络路径。然后定义一个去噪的pipeline,包含两个参数,分别是任务和model_id。model_id在对应模型名称下方,直接复制即可。


然后将图片输入pipeline中,输出去噪后图像保存到当前路径下。


在ModelScope官网搜索NAFNet图像去噪模型,点击右上角“Notebook快速开发”,选择CPU/GPU环境,填入相关信息。


在模型详情页复制相关代码至notebook。


运行代码即可。


下面进行模型训练。


首先,指定当前的工作目录(如:‘./20230424’),如果没有工作目录则会新建。然后指定model_id,snapshot_download函数会将model_id对应的所有模型文件全部下载到cache_dir,调用Config.from_file函数可以读取其中的模型配置文件。


自动下载的模型文件如上图所示。


通过MsDataset.load下载modelscope上的数据集。SIDD指数据集名称,namespace指数据集的所属,subset_name指子数据集,可选default和crops,split指用途,可选test、validation和train。另外,还需要将数据集转换为标准的PyTorch支持的加载方式。


以上图为例,数据集名称和数据集所属分别是SIDD和huizheng。


如果需要加载本地的数据集,调用本地自己写的数据集加载类,可以参考上图代码。继承PyTorch的Dataset类,定义CustomImageDataset类,然后实现三个函数:


初始化函数__init__,负责读取配置文件以及文件路径。

__len__函数,返回数据集的长度。比如数据集中有100张图片,则返回100。

__getitem__函数,输入index,读取对应标号的图片并返回tensor。



下载使用其他数据集的方法参考上图。


训练配置的参数如上图。


batch_size_per_gpu的设置与GPU的显存相关,一般建议为。worker_per_gpu与CPU相关核数相关,影响数据集加载的速度。


优化器能够支持反向传播,将损失先变为梯度,然后反向传播梯度。学习率用于控制每次迭代的步长。weight_decay为参数正则化系数,在分类网络中可以提升网络泛化性。


参数写在了config文件中,如果想要修改参数,可以通cfg_modify_fn实现,修改方式参考上图。另外,需要将参数传入训练器,参数以字典(键值对的形式)传入。传入的参数有model_id、训练集、验证集、工作目录以及cfg_modify_fn函数,传入cfg_modify_fn函数后即可自动修改超参数。


执行训练后,执行日志如上图所示。epoch[1][87/512]表示epoch为1,共需要训练512条数据,当前为第87条;eta和iter_time分别表示剩余时间和迭代时间,data_load_time表示数据加载时间,memory表示显存消耗,loss表示损失函数。


训练后如果需要进行模型评估,可参考上图代码。其中tmp_dir=‘./20230424/output’指定训练后的模型文件路径。



验证模型训练后的结果,代码如上图。


先定义数据集,然后调用SiddImageDenoisingDataset将数据集转为PyTorch数据加载类,并将其传入字典kwargs。model=tmp_dir,即前文的’./20230424/output’。train_dataset设置为none,即没有训练数据集。将参数全部传入trainer类,最后调用trainer.evaluate方法即可进行评估,最后打印评估结果。


执行结果如上图。


使用训练好的模型进行推理的代码如上图。它与直接运行测试的代码不同点在于定义pipeline。直接运行的代码如上图中被注释的部分,直接指定model_id,会默认自动下载;而使用训练好的模型进行测试,可以将此前的model_id替换为存放模型配置文件的目录,也可以导入ImageDenoisePipeline,然后指定保存目录的路径。


实际运行时的文件目录结构类似于上图,如果指定epoch=2,则还会有epoch_2.pth等相关的文件。


output文件夹下包含上述文件,最重要的是configuration.json和pytorch_model.pt。


除了图像去噪,NAFNet还可用于图像去模糊以及图像去模糊压缩,效果如上图。


另外,如果有训练数据对,想要完成指定的任务,可以通过准备数据集、自定义数据加载类、调用trainer函数来实现,可以省去很多代码的编写工作。


相关文章
|
1月前
|
机器学习/深度学习 自然语言处理 搜索推荐
手机上0.2秒出图、当前速度之最,谷歌打造超快扩散模型MobileDiffusion
【2月更文挑战第17天】手机上0.2秒出图、当前速度之最,谷歌打造超快扩散模型MobileDiffusion
40 2
手机上0.2秒出图、当前速度之最,谷歌打造超快扩散模型MobileDiffusion
|
11月前
检测使用校准的立体摄像头拍摄的视频中的人物并确定其与摄像头的距离
检测使用校准的立体摄像头拍摄的视频中的人物,并确定他们与摄像头的距离。
109 0
|
机器学习/深度学习 机器人 vr&ar
照片转视频,像航拍一样丝滑,NeRF原班人马打造Zip-NeRF(1)
照片转视频,像航拍一样丝滑,NeRF原班人马打造Zip-NeRF
294 0
照片转视频,像航拍一样丝滑,NeRF原班人马打造Zip-NeRF(1)
|
传感器 Web App开发 机器学习/深度学习
计算机视觉教程0-3:为何拍照会有死亡视角?详解相机矩阵与畸变
计算机视觉教程0-3:为何拍照会有死亡视角?详解相机矩阵与畸变
461 0
计算机视觉教程0-3:为何拍照会有死亡视角?详解相机矩阵与畸变
|
机器学习/深度学习 传感器 算法
【指纹识别】基于模板匹配实现指纹识别附matlab代码
【指纹识别】基于模板匹配实现指纹识别附matlab代码
|
编解码 算法 数据可视化
照片转视频,像航拍一样丝滑,NeRF原班人马打造Zip-NeRF(2)
照片转视频,像航拍一样丝滑,NeRF原班人马打造Zip-NeRF
241 0
|
传感器 数据挖掘 流计算
手机行为预测
手机行为预测
115 0
手机行为预测
|
机器学习/深度学习 传感器 算法
【指纹识别】基于Gabor滤波器的指纹识别研究附matlab代码
【指纹识别】基于Gabor滤波器的指纹识别研究附matlab代码
|
人工智能 自动驾驶 图形学
英伟达开发最快 NeRf 技术:数秒内将 2D 照片合成为 3D 场景
英伟达开发最快 NeRf 技术:数秒内将 2D 照片合成为 3D 场景
240 0
|
机器学习/深度学习 人工智能 自然语言处理
高保真音色媲美真人,StyleTTS为QQ浏览器「听书」语音注入情感
QQ 浏览器「听书」背后的 StyleTTS 让合成语音有了情感的温度。
252 0
高保真音色媲美真人,StyleTTS为QQ浏览器「听书」语音注入情感