手机图像去噪

简介: NAFNet是一个没有激活函数的神经网络。目前,在图像修复领域有多种类型的网络结构。

NAFNet图像修复方法介绍


NAFNet是一个没有激活函数的神经网络。目前,在图像修复领域有多种类型的网络结构。


第一,多阶段的结构。如图所示,两个u型网络中间做连接,输入尺寸为。U型网络会在左侧做降采样,特征会变为(H/2)*(W/2),再做一次降采样变为(H/4)*(W/4),然后在解码阶段将分辨率进行还原。多级网络看起来复杂,但实际实现较为简单。


第二,多尺度融合结构。它与多阶段结构的最大区别在于,多阶段结构的输入只有,但它有三种尺寸的输入,可提高网络的表达,输入的信息越多,恢复能力越强。整体结构依然属于U型网络。


第三,UNet结构。其最为简单但有效,只有输入、输出,U型结构中间有跳连接。


Block是上述三种结构的基本单元,多由卷积组成。


最左侧Restormer架构的输入是LayerNorm层,接着是一个标准的self-attention结构(3个1x1的卷积分别提取key、query、value特征,key与query特征计算相似度后经过softmax归一化)。整体的大结构上做跳连接(残差连接),下半部分也是残差结构。


PlainNet的结构更简单,由两个标准残差模块组成。Baseline相对于PlainNet,其模仿了transformer,添加了LayerNorm层,将激活函数由原先的ReLU换为GELU。


NAFNet相较于Baseline,用SimpleGate取代了CELU,用SCA取代了CA。


Channel Attention(CA)主要对输入的特征图(H、C、W的3D形状)在H和W维度上求平均值,变为向量。之后接了两个卷积再经过激活函数变为新的向量,值为0到1(也可理解为权重),将权重重新乘回输入特征,相当于在输入特征的通道维度上做了重新加权,得到新特征。


而Simple Channel Attention(SCA)取消了原先Channel Attention中间2层的卷积以及激活函数,替换为1x1的卷积操作。


Simple Gate会将HxWxC的输入平均分为两份C/2xHxW,将两者相乘得到新的特征,并以此替代激活函数。


手机图像去噪训练及评估



手机图像去噪的数据集是成对的(有噪图和清晰图),本文使用的数据集是智能手机图像去噪训练数据集SIDD(Smart Image Denoising Dataset)。训练数据的格式是常见的图片格式PNG,共有320对图,噪声图和干净图一一对应。


进入ModelScope官网,搜索“数据集-智能手机图像去噪数据集”。


点击数据预览,default子数据集下,左侧为有噪图,右侧为清晰图。


另外,我们还提供了crops子数据集,主要供训练使用。原始的320对图对分辨率较大,数据加载时会存在读取瓶颈,因此,我们将大分辨率的图裁剪为图像用于训练,提高加载速度,也提高训练速度。


去噪的核心代码如上图。图片地址可以是本地路径,也可以是网络路径。然后定义一个去噪的pipeline,包含两个参数,分别是任务和model_id。model_id在对应模型名称下方,直接复制即可。


然后将图片输入pipeline中,输出去噪后图像保存到当前路径下。


在ModelScope官网搜索NAFNet图像去噪模型,点击右上角“Notebook快速开发”,选择CPU/GPU环境,填入相关信息。


在模型详情页复制相关代码至notebook。


运行代码即可。


下面进行模型训练。


首先,指定当前的工作目录(如:‘./20230424’),如果没有工作目录则会新建。然后指定model_id,snapshot_download函数会将model_id对应的所有模型文件全部下载到cache_dir,调用Config.from_file函数可以读取其中的模型配置文件。


自动下载的模型文件如上图所示。


通过MsDataset.load下载modelscope上的数据集。SIDD指数据集名称,namespace指数据集的所属,subset_name指子数据集,可选default和crops,split指用途,可选test、validation和train。另外,还需要将数据集转换为标准的PyTorch支持的加载方式。


以上图为例,数据集名称和数据集所属分别是SIDD和huizheng。


如果需要加载本地的数据集,调用本地自己写的数据集加载类,可以参考上图代码。继承PyTorch的Dataset类,定义CustomImageDataset类,然后实现三个函数:


初始化函数__init__,负责读取配置文件以及文件路径。

__len__函数,返回数据集的长度。比如数据集中有100张图片,则返回100。

__getitem__函数,输入index,读取对应标号的图片并返回tensor。



下载使用其他数据集的方法参考上图。


训练配置的参数如上图。


batch_size_per_gpu的设置与GPU的显存相关,一般建议为。worker_per_gpu与CPU相关核数相关,影响数据集加载的速度。


优化器能够支持反向传播,将损失先变为梯度,然后反向传播梯度。学习率用于控制每次迭代的步长。weight_decay为参数正则化系数,在分类网络中可以提升网络泛化性。


参数写在了config文件中,如果想要修改参数,可以通cfg_modify_fn实现,修改方式参考上图。另外,需要将参数传入训练器,参数以字典(键值对的形式)传入。传入的参数有model_id、训练集、验证集、工作目录以及cfg_modify_fn函数,传入cfg_modify_fn函数后即可自动修改超参数。


执行训练后,执行日志如上图所示。epoch[1][87/512]表示epoch为1,共需要训练512条数据,当前为第87条;eta和iter_time分别表示剩余时间和迭代时间,data_load_time表示数据加载时间,memory表示显存消耗,loss表示损失函数。


训练后如果需要进行模型评估,可参考上图代码。其中tmp_dir=‘./20230424/output’指定训练后的模型文件路径。



验证模型训练后的结果,代码如上图。


先定义数据集,然后调用SiddImageDenoisingDataset将数据集转为PyTorch数据加载类,并将其传入字典kwargs。model=tmp_dir,即前文的’./20230424/output’。train_dataset设置为none,即没有训练数据集。将参数全部传入trainer类,最后调用trainer.evaluate方法即可进行评估,最后打印评估结果。


执行结果如上图。


使用训练好的模型进行推理的代码如上图。它与直接运行测试的代码不同点在于定义pipeline。直接运行的代码如上图中被注释的部分,直接指定model_id,会默认自动下载;而使用训练好的模型进行测试,可以将此前的model_id替换为存放模型配置文件的目录,也可以导入ImageDenoisePipeline,然后指定保存目录的路径。


实际运行时的文件目录结构类似于上图,如果指定epoch=2,则还会有epoch_2.pth等相关的文件。


output文件夹下包含上述文件,最重要的是configuration.json和pytorch_model.pt。


除了图像去噪,NAFNet还可用于图像去模糊以及图像去模糊压缩,效果如上图。


另外,如果有训练数据对,想要完成指定的任务,可以通过准备数据集、自定义数据加载类、调用trainer函数来实现,可以省去很多代码的编写工作。


相关文章
|
1月前
|
机器学习/深度学习 自然语言处理 搜索推荐
手机上0.2秒出图、当前速度之最,谷歌打造超快扩散模型MobileDiffusion
【2月更文挑战第17天】手机上0.2秒出图、当前速度之最,谷歌打造超快扩散模型MobileDiffusion
36 2
手机上0.2秒出图、当前速度之最,谷歌打造超快扩散模型MobileDiffusion
|
机器学习/深度学习 机器人 vr&ar
照片转视频,像航拍一样丝滑,NeRF原班人马打造Zip-NeRF(1)
照片转视频,像航拍一样丝滑,NeRF原班人马打造Zip-NeRF
289 0
照片转视频,像航拍一样丝滑,NeRF原班人马打造Zip-NeRF(1)
|
编解码 算法 数据可视化
照片转视频,像航拍一样丝滑,NeRF原班人马打造Zip-NeRF(2)
照片转视频,像航拍一样丝滑,NeRF原班人马打造Zip-NeRF
236 0
|
机器学习/深度学习 编解码 达摩院
一键抹去瑕疵、褶皱:深入解读达摩院高清人像美肤模型ABPN
一键抹去瑕疵、褶皱:深入解读达摩院高清人像美肤模型ABPN
209 0
|
达摩院 算法 计算机视觉
一键抹去瑕疵、褶皱:深入解读达摩院高清人像美肤模型ABPN(2)
一键抹去瑕疵、褶皱:深入解读达摩院高清人像美肤模型ABPN
341 0
|
机器学习/深度学习 传感器 算法
【指纹识别】基于模板匹配实现指纹识别附matlab代码
【指纹识别】基于模板匹配实现指纹识别附matlab代码
|
算法 数据处理 计算机视觉
砥砺的前行|基于labview的机器视觉图像处理(七)——图像双边处理自适应亮度
砥砺的前行|基于labview的机器视觉图像处理(七)——图像双边处理自适应亮度
125 0
砥砺的前行|基于labview的机器视觉图像处理(七)——图像双边处理自适应亮度
|
传感器 数据挖掘 流计算
手机行为预测
手机行为预测
115 0
手机行为预测
|
机器学习/深度学习 传感器 算法
【指纹识别】基于Gabor滤波器的指纹识别研究附matlab代码
【指纹识别】基于Gabor滤波器的指纹识别研究附matlab代码
|
机器学习/深度学习 人工智能 数据可视化
程序人生 - Nature封面:脑机接口突破,可将脑中“笔迹”转为屏幕字句,速度创纪录,准确率超高
程序人生 - Nature封面:脑机接口突破,可将脑中“笔迹”转为屏幕字句,速度创纪录,准确率超高
105 0
程序人生 - Nature封面:脑机接口突破,可将脑中“笔迹”转为屏幕字句,速度创纪录,准确率超高