改进UNet | 透过UCTransNet分析ResNet+UNet是不是真的有效?(一)

简介: 改进UNet | 透过UCTransNet分析ResNet+UNet是不是真的有效?(一)

1简介


最近的很多医疗语义分割方法都采用了带有编解码器结构的U-Net框架。但是U-Net采用简单的跳跃连接方案对于全局多尺度问题进行建模仍然具有挑战性:

  1. 由于编解码器阶段特征集不兼容,并不是每个跳跃连接设置都是有效的,甚至一些跳跃连接会对分割性能产生负面影响;
  2. 原有的U-Net在某些数据集上比没有跳过连接的U-Net更差。

基于研究结果,作者提出了一个新的细分框架UCTransNet(在U-Net中提出了一个CTrans模块),从通道注意力机制的视角出发。

具体来说,CTrans(Channel Transformer))模块是U-Net skip connections的替代,其中一个子模块用于与Transformer进行多尺度通道交叉融合(CCT),另一个子模块Channel-wise Cross-attention(CCA)用于引导融合的多尺度通道信息与解码器特征有效连接以消除歧义。

因此,本文提出的由CCT和CCA组成的连接能够代替原有的skip connections,解决语义空白,实现精确的医学图像自动分割。

实验结果表明,UCTransNet可以得到更精确的分割性能,并在不同数据集和传统架构(包括transformer或U-Shape框架)的语义分割方面取得了一致的改进。

本文主要贡献:

  1. 分析了skip connections在多个数据集上的有效性,表明独立简单复制是不合适的。
  2. 提出了一个新的视角来提高语义分割的性能,即通过更有效的特征融合和多尺度的通道交叉注意力来弥补low-level和high-level特征之间的语义和分辨率差距,以捕获更复杂的通道依赖。
  3. UCTransNet是第一个从通道角度重新思考Transformer自注意力机制的方法。与其他先进的分割方法相比,实验结果在公共数据集上都有更好的性能。

2Skip connection的分析


image.png

图3

发现 1

没有任何Skip connection的U-net甚至比原来的U-net更好。比较图3,可以发现“U-Net-none”在几乎所有参数的算法中表现最差MoNuSeg数据集。然而,“U-Net-none”,尽管没有任何限制,仍然在Glas数据集上取得了与“U-Net-all”非常有竞争力的性能。它表明Skip connection并不总是对语义分割有益。

发现 2

尽管UNet-all比UNet-none性能更好,但并不是所有简单复制的Skip connection都对语义分割有用。每个Skip connection的贡献是不同的。作者发现,在MoNuSeg数据集上,每个Skip connection的性能范围分别为[67.5%,76.44%]和[52.2%,62.73%]。对于不同的single skip connection,冲击变化较大。

此外,由于编码器和解码器阶段的特征集不兼容的问题,一些skip connection对分割性能有负面影响。例如,L1在Glas数据集上的Dice和IOU方面的表现比UNet-none差。这个结果并不能证明来自编码器阶段的许多特性是不能提供信息的。其背后的原因可能是简单的复制不适合特征融合。

发现 3

对于不同的数据集,skip connection的最佳组合是不同的,这取决于目标病变的规模和外观。作者进行了几个消融实验,以探索最佳侧输出设置。

注意,由于空间有限,作者忽略了两个skip connection的组合。

可以看到,skip connection并没有获得更好的性能,没有L4的模型在MoNuSeg数据集上表现最好,而令人惊讶的是,只有一个skip connection的L3在GlaS数据集上表现最好。这些观察结果表明,不同数据集的最佳组合是不同的。这进一步证实了在特征融合中引入更合适的动作而不是简单的连接的必要性。


3UCTransNet用于医学图像分割


image.png

图2 UCTransNet框架

图2展示了UCTransNet框架的概述。目前基于transformer的分割方法主要是针对U-Net的编码器进行改进,因为U-Net具有捕获远程信息的优势。这些方法,如TransUNet或TransFuse,以简单的方式将Transformer与U-Net融合,即将Transformer模块插入编码器或融合两个独立分支。但是,作者认为目前U-Net模型的潜在限制是skip connection的问题,而不是原始U-Net的编码器的问题,这足以满足大多数任务。

如skip connection分析部分所述,作者观察到编码器的特征与解码器的特征不一致,即在某些情况下,由于浅层编码器与解码器之间存在语义差异,语义信息较少的浅层特征可能会通过简单的skip connection损害最终性能。受此启发,作者通过在普通U-Net编码器和解码器之间设计一个通道化的Transformer模块来构建UCTransNet框架,以更好地融合编码器特性,减少语义差距。

具体来说,本文提出了一种通道转换器(Channel Transformer, CTrans)来替代U-Net中的skip connection,它由两个模块组成:用于多尺度编码器特征融合的CCT(Channel-wise Cross Fusion Transformer)和用于解码器特征与增强CCT特征融合的CCA(Channel-wise Cross Attention)。


4CCT


为了解决前面提到的skip connection问题提出了一种新的通道交叉融合Transformer(CCT),利用Transformer的长依赖建模优势融合多尺度编码器特征。CCT模块包括3个步骤:

  • 多尺度特征嵌入
  • 多通道交叉注意力
  • 多层感知器(MLP)

给定4个skip connection层的输出,首先对特征进行token化,将特征reshape为patch大小分别为{}的flattend 2D patch序列,使这些patch可以在4种尺度下映射到编码器特征的相同区域。在这个过程中,保持原来的通道尺寸。然后,连接4个层的Token ; 作为key和作为value。


5Multi-head Cross-Attention


token被输入到多头通道交叉注意力模块,然后是具有残差结构的多层感知器(MLP),以编码通道关系和依赖,使用多尺度特征从每个U-Net编码器级别提取特征。

图5

如图5所示,本文提出的CCT模块包含5个输入,其中4个token 作为query,一个连接token 作为key和value:

image.png

其中;;为不同输入的权值,d为序列长度(patch编号),为4个skip connection层的通道尺寸。在实现中;;;。

与;;,产生相似矩阵, 通过交叉注意力(CA)机制对V进行加权:

image.png

其中和分别表示实例归一化和softmax函数。

image.png

图4

与原始自注意力的主要区别在于,本文沿着通道轴而不是patch轴进行注意力操作(见图4),并且在相似图上使用实例归一化,使得梯度可以平滑地传播。在N头注意力情况下,多头交叉注意力后的输出计算如下:

image.png

N是Head数。下面,应用简单的MLP和残差算子,得到输出如下:

image.png

为简便起见,在方程中省略了层标准化(LN)。将式(4)中的操作重复L次,构建L层变压器。在实现中,N和L都被设置为4。最后,对第l层的4个输出、、和进行上采样重构,再进行卷积层重构,并分别与解码器特征、、和连接。


6CCA


为了更好地融合Channel Transformer与U-Net解码器之间语义不一致的特征,本文提出了一个面向通道的交叉注意力模块,该模块可以指导Channel和information filtering of the Transformer特征,消除与解码器特征的歧义。

数学上,将第级Transformer输出和第级解码器特征映射作为通道交叉注意力的输入。空间压缩由全局平均池化(GAP)层执行,产生向量及其第k个通道。使用这个操作来嵌入全局空间信息,然后生成注意力Mask:

image.png

其中和和为2个线性层和ReLU算子的权重。Eq.(5)中的这个操作对通道依赖进行编码。根据ECA-Net的经验表明,避免降维对学习通道注意力很重要,使用单一线性层和s形函数来构建通道注意力图。结果向量用于重新校准或激发到,其中激活表示每个通道的重要性。最后,将mask的与第i级解码器的上采样特征连接起来。


7实验


表1报告了实验结果,其中最好的结果用粗体表示。表1显示,本文的方法具有一致的改进之前的效果,如Glas数据集,与基于Transformer模型相比性能分别提升2.42%(3.59%),4.05%(7.07%)的Dice (IoU)较U-Net基础模型和从1.80%(2.98%),3.65%(6.12%)。

在表2中,可以做类似的观察和结论,这再次验证了UCTransNet优于其他所有公司。此外,预训练方案不仅收敛速度更快,而且在MoNuSeg数据集上取得了比其他方法更好的性能,甚至优于联合学习方案。这些观察结果表明,这两个提出的模块可以纳入预先训练的U-Net模型,以提高分割性能。

图6图7

对比模型的分割结果图6和图7。红框突出显示UCTransNet比其他方法表现更好的区域。结果表明,UCTransNet可以产生更好的分割结果,与Baseline模型的分割结果相比,UCTransNet的分割结果更接近ground truth。可以看出,提出的方法不仅突出了右侧显著区域,消除了混淆的假阳性病变,而且产生了连贯的边界。这些观察结果表明UCTransNet能够在保留详细形状信息的同时进行更精细的分割。

如表3所示,在所有数据集上,“Base+CCT+CCA”总体上优于其他Baseline。通过将CCT和CCA集成到U-Net在Dice和IoU方面分别提高了1.12%和1.22%,说明了两个模块组合的有效性。研究结果揭示了多尺度多通道特征融合在编码器-解码器框架中对提高分割性能的重要性。


8参考


[1].UCTransNet: Rethinking the Skip Connections in U-Net from a Channel-wise Perspective with Transformer

相关文章
|
编解码 数据库
详细分析ResNet | 用CarNet教你如何一步一步设计轻量化模型(二)
详细分析ResNet | 用CarNet教你如何一步一步设计轻量化模型(二)
248 0
|
6月前
|
机器学习/深度学习 PyTorch 测试技术
|
5月前
|
机器学习/深度学习 PyTorch 算法框架/工具
【从零开始学习深度学习】31. 卷积神经网络之残差网络(ResNet)介绍及其Pytorch实现
【从零开始学习深度学习】31. 卷积神经网络之残差网络(ResNet)介绍及其Pytorch实现
|
机器学习/深度学习 PyTorch 算法框架/工具
ResNet代码复现+超详细注释(PyTorch)
ResNet代码复现+超详细注释(PyTorch)
2182 1
|
6月前
|
机器学习/深度学习 PyTorch 语音技术
Pytorch迁移学习使用Resnet50进行模型训练预测猫狗二分类
深度学习在图像分类、目标检测、语音识别等领域取得了重大突破,但是随着网络层数的增加,梯度消失和梯度爆炸问题逐渐凸显。随着层数的增加,梯度信息在反向传播过程中逐渐变小,导致网络难以收敛。同时,梯度爆炸问题也会导致网络的参数更新过大,无法正常收敛。 为了解决这些问题,ResNet提出了一个创新的思路:引入残差块(Residual Block)。残差块的设计允许网络学习残差映射,从而减轻了梯度消失问题,使得网络更容易训练。
539 0
|
6月前
|
机器学习/深度学习 数据采集 PyTorch
PyTorch搭建卷积神经网络(ResNet-50网络)进行图像分类实战(附源码和数据集)
PyTorch搭建卷积神经网络(ResNet-50网络)进行图像分类实战(附源码和数据集)
223 1
|
机器学习/深度学习 人工智能 PyTorch
ResNet详解:网络结构解读与PyTorch实现教程
ResNet详解:网络结构解读与PyTorch实现教程
1634 0
|
机器学习/深度学习 人工智能 PyTorch
【图像分类】基于OpenVINO实现PyTorch ResNet50图像分类
【图像分类】基于OpenVINO实现PyTorch ResNet50图像分类
310 0
|
机器学习/深度学习 存储 人工智能
模型推理加速系列 | 03:Pytorch模型量化实践并以ResNet18模型量化为例(附代码)
本文主要简要介绍Pytorch模型量化相关,并以ResNet18模型为例进行量化实践。
|
算法 PyTorch 算法框架/工具
pytorch实现空洞卷积+残差网络实验(torch实现)
pytorch实现空洞卷积+残差网络实验(torch实现)
398 0

热门文章

最新文章