NAFNet图像修复方法介绍
NAFNet是一个没有激活函数的神经网络。目前,在图像修复领域有多种类型的网络结构。
第一,多阶段的结构。如图所示,两个u型网络中间做连接,输入尺寸为。U型网络会在左侧做降采样,特征会变为(H/2)*(W/2),再做一次降采样变为(H/4)*(W/4),然后在解码阶段将分辨率进行还原。多级网络看起来复杂,但实际实现较为简单。
第二,多尺度融合结构。它与多阶段结构的最大区别在于,多阶段结构的输入只有,但它有三种尺寸的输入,可提高网络的表达,输入的信息越多,恢复能力越强。整体结构依然属于U型网络。
第三,UNet结构。其最为简单但有效,只有输入、输出,U型结构中间有跳连接。
Block是上述三种结构的基本单元,多由卷积组成。
最左侧Restormer架构的输入是LayerNorm层,接着是一个标准的self-attention结构(3个1x1的卷积分别提取key、query、value特征,key与query特征计算相似度后经过softmax归一化)。整体的大结构上做跳连接(残差连接),下半部分也是残差结构。
PlainNet的结构更简单,由两个标准残差模块组成。Baseline相对于PlainNet,其模仿了transformer,添加了LayerNorm层,将激活函数由原先的ReLU换为GELU。
NAFNet相较于Baseline,用SimpleGate取代了CELU,用SCA取代了CA。
Channel Attention(CA)主要对输入的特征图(H、C、W的3D形状)在H和W维度上求平均值,变为向量。之后接了两个卷积再经过激活函数变为新的向量,值为0到1(也可理解为权重),将权重重新乘回输入特征,相当于在输入特征的通道维度上做了重新加权,得到新特征。
而Simple Channel Attention(SCA)取消了原先Channel Attention中间2层的卷积以及激活函数,替换为1x1的卷积操作。
Simple Gate会将HxWxC的输入平均分为两份C/2xHxW,将两者相乘得到新的特征,并以此替代激活函数。
手机图像去噪训练及评估
手机图像去噪的数据集是成对的(有噪图和清晰图),本文使用的数据集是智能手机图像去噪训练数据集SIDD(Smart Image Denoising Dataset)。训练数据的格式是常见的图片格式PNG,共有320对图,噪声图和干净图一一对应。
进入ModelScope官网,搜索“数据集-智能手机图像去噪数据集”。
点击数据预览,default子数据集下,左侧为有噪图,右侧为清晰图。
另外,我们还提供了crops子数据集,主要供训练使用。原始的320对图对分辨率较大,数据加载时会存在读取瓶颈,因此,我们将大分辨率的图裁剪为图像用于训练,提高加载速度,也提高训练速度。
去噪的核心代码如上图。图片地址可以是本地路径,也可以是网络路径。然后定义一个去噪的pipeline,包含两个参数,分别是任务和model_id。model_id在对应模型名称下方,直接复制即可。
然后将图片输入pipeline中,输出去噪后图像保存到当前路径下。
在ModelScope官网搜索NAFNet图像去噪模型,点击右上角“Notebook快速开发”,选择CPU/GPU环境,填入相关信息。
在模型详情页复制相关代码至notebook。
运行代码即可。
下面进行模型训练。
首先,指定当前的工作目录(如:‘./20230424’),如果没有工作目录则会新建。然后指定model_id,snapshot_download函数会将model_id对应的所有模型文件全部下载到cache_dir,调用Config.from_file函数可以读取其中的模型配置文件。
自动下载的模型文件如上图所示。
通过MsDataset.load下载modelscope上的数据集。SIDD指数据集名称,namespace指数据集的所属,subset_name指子数据集,可选default和crops,split指用途,可选test、validation和train。另外,还需要将数据集转换为标准的PyTorch支持的加载方式。
以上图为例,数据集名称和数据集所属分别是SIDD和huizheng。
如果需要加载本地的数据集,调用本地自己写的数据集加载类,可以参考上图代码。继承PyTorch的Dataset类,定义CustomImageDataset类,然后实现三个函数:
●初始化函数__init__,负责读取配置文件以及文件路径。
●__len__函数,返回数据集的长度。比如数据集中有100张图片,则返回100。
●__getitem__函数,输入index,读取对应标号的图片并返回tensor。
下载使用其他数据集的方法参考上图。
训练配置的参数如上图。
batch_size_per_gpu的设置与GPU的显存相关,一般建议为。worker_per_gpu与CPU相关核数相关,影响数据集加载的速度。
优化器能够支持反向传播,将损失先变为梯度,然后反向传播梯度。学习率用于控制每次迭代的步长。weight_decay为参数正则化系数,在分类网络中可以提升网络泛化性。
参数写在了config文件中,如果想要修改参数,可以通cfg_modify_fn实现,修改方式参考上图。另外,需要将参数传入训练器,参数以字典(键值对的形式)传入。传入的参数有model_id、训练集、验证集、工作目录以及cfg_modify_fn函数,传入cfg_modify_fn函数后即可自动修改超参数。
执行训练后,执行日志如上图所示。epoch[1][87/512]表示epoch为1,共需要训练512条数据,当前为第87条;eta和iter_time分别表示剩余时间和迭代时间,data_load_time表示数据加载时间,memory表示显存消耗,loss表示损失函数。
训练后如果需要进行模型评估,可参考上图代码。其中tmp_dir=‘./20230424/output’指定训练后的模型文件路径。
验证模型训练后的结果,代码如上图。
先定义数据集,然后调用SiddImageDenoisingDataset将数据集转为PyTorch数据加载类,并将其传入字典kwargs。model=tmp_dir,即前文的’./20230424/output’。train_dataset设置为none,即没有训练数据集。将参数全部传入trainer类,最后调用trainer.evaluate方法即可进行评估,最后打印评估结果。
执行结果如上图。
使用训练好的模型进行推理的代码如上图。它与直接运行测试的代码不同点在于定义pipeline。直接运行的代码如上图中被注释的部分,直接指定model_id,会默认自动下载;而使用训练好的模型进行测试,可以将此前的model_id替换为存放模型配置文件的目录,也可以导入ImageDenoisePipeline,然后指定保存目录的路径。
实际运行时的文件目录结构类似于上图,如果指定epoch=2,则还会有epoch_2.pth等相关的文件。
output文件夹下包含上述文件,最重要的是configuration.json和pytorch_model.pt。
除了图像去噪,NAFNet还可用于图像去模糊以及图像去模糊压缩,效果如上图。
另外,如果有训练数据对,想要完成指定的任务,可以通过准备数据集、自定义数据加载类、调用trainer函数来实现,可以省去很多代码的编写工作。