Title: Few Shot Medical Image Segmentation with Cross Attention Transformer
PDF: https://arxiv.org/pdf/2303.13867
Code: coming soon...
导读
在深度学习医学图像分割领域,训练一个性能强,可以大规模部署落地的模型,往往需要大量手动标注的数据进行监督训练,其中花费的成本是非常高的。为了解决这一挑战,少样本学习(few-shot)技术有潜力从有限的几个sample中学习新类别的能力。本文提出了一种基于交叉掩码注意力Transformer
的少样本医学图像分割新框架CAT-Net
:通过挖掘support
和query
图像之间的相关性,并限制模型仅关注有用的前景信息,来提高support
和query
特征的表达能力。同时,本文还进一步设计了一个迭代细化训练框架来优化查询query
图像分割。作者在三个公共数据集(Abd-CT,Abd-MRI和Card-MRI)上验证了所提出方法的有效性。
引言
我们都知道,图像分割在医学影像中是一项非常常见的任务。在工业界,绝大部分医学图像分割任务还是基于全监督进行训练(或许只因为它稳定且精度高,只要公司不断的堆高质量标注数据就可以了,这个成本公司还是愿意花的);但医学图像的标注是一项耗时长,成本高的任务,比如对于一个3D的volume
图像(以CT和MRI居多),图像标注则更具有挑战性:第一,标注者需要浏览每个3D扫描的数百个2D切片进行标注,量非常大;第二,有些标注外行人还真的做不好,这往往需要一些专业医生动用专业知识进行标注;第三,这些专业医生往往还有其它很多事情要做,你让医生去标注大量数据还真得考虑可行性。。。
基于这个出发点(痛点),科研界许多研究学者为了解决手动标注所带来的挑战,开辟了各种研究方向(至于公司愿不愿意去运用此类方向做产品就看情况了):例如自监督学习,半监督学习和弱监督学习。尽管利用未标记或弱标记数据的信息,这些技术仍然需要大量的训练数据,这对于医学领域中仅有有限样本的新类别可能不可行。于是,,,研究学者开辟了few-shot
这么一个研究方向:few-shot
学习范式旨在从少量标记数据(称为support
)中学习模型,然后将其应用于仅有少量标记数据的新任务(称为query
),而无需重新训练。
考虑到人体内的数百个器官和无数的疾病,few-shot
学习为各种医学图像分割任务带来了巨大的潜力,可以在数据高效的情况下轻松地研究新任务。
大多数few-shot
分割方法都在学习如何学习(旨在学习元学习器),根据support
图像及其相应的分割标签的知识预测query
图像的分割,而这里的核心是:如何有效地将知识从support
图像传递到query
图像。现有的少样本分割方法主要集中在以下两个方面:
- 如何学习一个元学习器
- 如何更好地将知识从
support
图像传递到query
图像
尽管基于原型的方法效果已经不错,但它们通常忽略了训练过程中support
和query
特征之间的交互。
因此,本文提出了一种名为CAT-Net
的新型网络结构,其基于交叉注意力Transformer,可以更好地捕捉support
图像和query
图像之间的相关性,促进support
和query
特征之间的相互作用,同时减少无用像素信息,提高特征表达能力和分割性能;此外,本文还提出了一个迭代训练框架,将先前的support
分割结果反馈到注意力Transformer中,以有效增强并细化特征和分割结果。作者在三个公共数据集上验证了CAT-Net
的有效性和性能优越性。
few-shot定义
少样本分割(Few-shot segmentation,FSS
)的目的是通过只有少量标注的样本来分割新类别。在FSS
中,数据集被分为训练集Dtrain
和测试集Dtest
,其中训练集包含基类别Ctrain
,测试集包含新类别Ctest
,且Ctrain
和Ctest
没有交集。为了获得用于FSS
的分割模型,采用了通常使用的episode
训练方法。每个训练 / 测试 实例化一个N-way, K-shot分割学习任务。具体而言: support集包含N个类别的K个样本,而query集包含同一类别的一个样本。FSS
模型通过episode
训练以预测query
图像的新类别。在模型推理测试时,模型直接在Dtest
上进行评估,无需重新训练。
方法
图1. Overview of the CAT-NET
如上图1展示了CAT-Net
网络框架图,主要由三部分组成:
- 带有mask的特征提取
MIFE
子网络,用于提取初始query
和support
特征以及query mask
- 交叉mask注意力
Transformer
模块CMAT
,其中query
和support
特征相互促进,从而提高query
预测的准确性 - 迭代细化框架,顺序应用
CMAT
模块以持续促进分割性能,整个框架以端到端的方式进行训练
Mask Incorporated Feature Extraction
CAT-Net中的Mask Incorporate Feature Extraction (MIFE)子网络。MIFE子网络接收查询和支持图像作为输入,生成它们各自的特征,同时集成支持掩膜。然后,使用一个简单的分类器来预测查询图像的分割结果。具体地,首先使用一个特征提取器网络(即ResNet-50)将查询和支持图像对Iq和Is映射到特征空间中,分别产生查询图像的多层特征图Fq和支持图像的特征图Fs。接下来,将支持掩膜与Fs进行池化,然后将其扩展并与Fq和Fs进行连接。此外,还将一个先验掩膜进一步与查询特征进行连接,通过像素级相似度图来增强查询和支持特征之间的相关性。最后,使用一个简单的分类器来处理查询特征,得到查询掩膜。关于MIFE架构的更多细节可以在补充材料中找到。
Cross Masked Attention Transformer
CMAT
模块包括三个主要组成部分:自注意力模块、交叉掩码注意力模块,和原型分割模块。其中,自注意力模块用于提取查询query
特征和支持support
特征中的全局信息;交叉掩码注意力模块用于在传递前景信息的同时消除冗余的背景信息;原型分割模块用于生成查询图像的最终预测结果。
自注意力模块
自注意力模块首先将查询特征 和支持特征 展平为1D序列,然后输入到两个相同的自注意力模块中。每个自注意力模块由一个多头注意力层和一个多层感知器MLP
层组成。给定一个输入序列 ,MHA
层首先使用不同的权重将序列投影为三个序列, 和 。然后计算注意力矩阵 ,公式为:
其中, 是输入序列的维度。注意力矩阵通过 softmax
函数归一化,并乘以值序列 以获得输出序列 。MLP层是一个简单的 卷积层,将输出序列 映射到与输入序列 相同的维度。最终,将输出序列 添加到输入序列 中,并使用层归一化(LN)对其进行规范化,以获得最终的输出序列 。自注意力对齐编码器的输出特征序列分别表示为 和 ,分别对应于查询和支持特征。
交叉掩码注意力模块
用于将查询特征和支持特征按照它们的前景信息结合起来。在attention矩阵中,通过支持和查询的掩码来限制注意力区域。具体来说,给定查询特征和来自自注意力模块的支持特征,首先使用不同的权重将输入序列投影到三个序列,和中,从而得到、、和、、。以查询特征为例,交叉注意力矩阵通过下面的公式计算得到:
其中,表示查询特征的维度。这里使用的是点积注意力的形式,通过和的点积计算查询和支持之间的相关性。通过来缩放点积,防止在较高维度时点积的大小对注意力分布的影响过大。
原型分割模块
首先,通过一个“masked average pooling”的方法,建立每个类别的原型(prototype),用于表示该类别的特征分布。
其中,是支持集中图像的数量,是一个二进制掩模,表示位置在支持特征中是否属于类别,是支持特征。具体来说,对于每个类别,该原型是在所有支持图像中该类别对应位置的特征平均值,这样可以得到每个类别的原型。
接着使用非参数度量学习方法进行分割。原型网络计算查询特征向量与原型之间的距离。对所有类别应用softmax函数,生成查询分割结果:
其中cos(·)表示余弦距离,α是一个缩放因子,有助于在训练中反向传播梯度,其中α设置为20。
Iterative Refinement framework
该模块的设计目的是优化查询和支持特征以及查询分割掩模。因此可通过迭代优化的思路进行精细化分割,第i次迭代后的结果由以下公式给出:
每个步骤的细分可表示如下:
其中CMA(·)表示自注意力和交叉掩码注意力模块,Proto(·)代表原型分割模块,该公式表示通过多次迭代应用CMA和Proto模块,来获得增强的特征和优化的分割结果。
实验结果
作者将他们的方法与目前在腹部CT、腹部MRI和心脏MRI数据集上表现最优的方法进行了比较,使用了Dice系数作为评估指标。该比较在两种不同的实验设置(I和II)下进行。
在 Abd-CT 和 Abd-MRI数据集上,相比于之前的最先进方法(SOTAs),这个提出的方法能够生成更加准确和详细的分割结果。
验证了网络中各个组件的有效性:S→Q和Q→S表示CAT-Net中用于增强支持或查询特征的一条支路,而S↔Q表示将交叉注意力应用于S和Q。
在不同迭代次数下使用CMAT模块的影响,可以观察到:增加模块数量可以提高性能,在使用5个模块时,Dice系数最大提高了2.26%。考虑到使用4和5个CMAT模块之间的性能提升不显著,因此作者选择在最终模型中使用四个CMAT模块,以在效率和性能之间取得平衡。
结论
本文提出了一种用于few-shot
医学图像分割的交叉注意力Transformer网络CAT-Net
。通过交叉掩码注意力模块实现了查询和支持特征之间的交互,增强了特征表达能力。此外,所提出的CMAT
模块可以通过迭代优化的方式以持续提高分割性能,实验结果表明了每个模块的有效性以及模型相对于SOTA
方法的卓越性能。其中论文中的各个组件属于即插即用模块,可很好的嵌入到few shot
任务中,以提高少样本分割的性能。