新兴的Segment Anything(SAM)在自然图像的零样本分割方面表现出了令人印象深刻的能力。然而,当应用于医学图像时,SAM的性能明显下降。为了使SAM成为计算机视觉社区真正的“基础模型”,找到一种为医学图像数据集定制SAM的有效方法至关重要。
在这项工作中,作者建议冻结SAM编码器并微调轻量级任务特定预测Head,因为SAM中的大多数权重都是由编码器贡献的。此外,SAM是一个可prompt的模型,而prompt不一定在所有应用程序案例中都可用,并且多类分割的精确prompt也很耗时。
因此,作者在这项工作中探索了3种类型的无prompt预测Head,包括ViT、CNN和线性层。对于ViT Head,作者删除SAM的Mask解码器中的prompt Token ,该解码器被命名为AutoSAM。AutoSAM还可以在修改后通过一个单一的推理为不同的类生成Mask。
为了评估作者的微调方法的标签效率,作者在有限的标签数据的公共医学图像分割数据集上比较了这3个预测Head的结果。
实验表明,即使只有一个标记的volumes,微调SAM也能显著提高其在医学图像数据集上的性能。此外,在缺乏标注的情况下,AutoSAM和CNN预测Head也比从Head开始训练和自监督学习方法具有更好的分割精度。
1、简介
生成预训练Transformer(Generative Pre-trained Transformer,GPT)系列模型的成功表明,如果在大规模数据上进行训练,大型语言模型在零样本和非可视域中的少量快照任务上的性能与最新技术相当。
受GPT的启发,Segment Anything(SAM)为图像分割任务引入了一个“基础模型”。他们收集了1100万张图像,并设计了一个半自动数据引擎,平均每张图像产生约100个Mask,从而总共产生10亿个Mask。然后,SAM在该SAM-1B数据集上使用Vision Transformer(ViT)Backbone训练一个大型可prompt模型。在23多个数据集上使用各种零样本任务进行评估后,SAM显示出对大多数自然图像的推广前景。
然而,随着SAM在医学图像领域引起人们的关注,可以观察到SAM在零样本设置下不能很好地推广到医学图像。将用自然图像训练的模型转换为医学图像的挑战可归因于2个主要因素:
- 外观上的巨大差异:自然图像和医学图像在颜色、亮度和对比度方面表现出显著差异。由于所使用的成像模式,例如CT扫描、MRI或超声波,医学图像通常具有不同的特征;
- 目标物体的模糊边界:医学图像经常显示不同组织和器官之间的模糊边界。受过训练的医学专家对解剖结构有必要的了解,并且能够识别出对于仅根据自然图像训练的模型来说可能不明显的细微边界。
考虑到收集与SAM-1B大小相当的医学分割数据集的困难,探索预训练的SAM中是否有可用于医学图像分割的知识是至关重要的。
此外,基于prompt的分割可能不太适合真实世界的应用场景,原因如下:
- 为多类提供prompt很耗时。对于大多数公共医学图像分割的挑战,它总是需要同时分割多个类别。为每个类别输入准确的prompt可能会变得麻烦,尤其是当器官和组织很小并且彼此相邻时;
- 分割性能在很大程度上取决于prompt质量。制作精确的prompt需要特定领域的专家知识,而这并非适用于所有情况。
考虑到这些限制,本文提出了一种在医学图像数据集上微调SAM的直接方法,即冻结SAM编码器的权重,并在其上添加预测Head进行训练。冻结权重的原因是SAM是一个大模型,并且大多数权重由编码器贡献。根据实验结果,由于硬件要求高,对编码器和解码器进行微调不仅对所有开发人员来说不太容易,而且还会导致较差的分割性能。
另一方面,为了提高SAM在临床应用中的可行性,作者将SAM中的Mask解码器替换为不需要prompt进行训练和推理的预测Head。本文评估了三种不同类型的预测Head,包括视觉Transformer(ViT)、卷积神经网络(CNN)和线性层。ViT预测Head采用SAM Mask解码器,命名为AutoSAM,由轻量级交叉注意力模块和转置卷积层组成。作者移除prompt标记并复制图像嵌入以及其他辅助嵌入,以便解码器可以同时为不同的类生成多个Mask。
为了展示作者方法的标记效率,作者在Few-Shot Head学习环境中进行了实验,其中仅使用1或5次标记的MRI扫描来微调模型。在公开可用的医学图像分割数据集上获得的结果突出表明,与零样本即时驱动SAM相比,定制预训练SAM取得了显著改进。
此外,作者的方法在很大程度上优于从Head开始的训练和最先进的自监督学习方法,强调SAM在医学领域的应用潜力。
2、相关工作
2.1、大语言模型
在大型语言模型(LLM)出现之后,一些工作致力于在LLM中引入图像来完成多模态任务。例如,CLIP和ALIGN利用对比学习在嵌入空间中对齐网络图像及其标题。他们发现这个简单的预训练任务可以很好地推广到其他零样本下游任务,如视频中的目标分类和动作识别。
此外,DALL-E通过一个用于生成零样本文本到图像的尺度自回归变换实现了很好的泛化。然而,这些大规模的视觉模型未能解决广泛的所有计算机视觉任务,如图像分割。对于大型图像分割模型来说,获取标签Mask的难度是关键。
SAM(Segment Anything)是第一个开发可prompt的分割模型并自行在广泛的数据集上对其进行预训练的工作。给定适当的prompt,SAM能够在没有特定任务训练的情况下为目标生成可能的Mask。另一方面,DINOv2根据数据和模型大小对ViT模型的预训练进行缩放,以产生通用的视觉特征,利用这些特征可以更容易地微调下游任务。
2.2、为医学图像定制大模型
这一系列工作主要集中在针对特定分割数据集微调SAM,因为SAM在医学图像上表现出显著的性能退化。MedSAM通过30多个医学图像数据集上的标签Mask生成的prompt,对SAM解码器进行了微调,结果表明,与使用prompt生成的零样本预测相比,性能得到了改进。张凯东等人将基于低秩的微调策略应用于SAN编码器,并将其与SAM解码器一起训练,以定制SAM以执行腹部分割任务。吴俊德等人冻结SAM模型的权重,并在SAM中添加可训练的自适应模块,以降低重新训练的成本。
3、本文方法
3.1、背景
首先,作者将简要介绍SAM模型作为背景知识。SAM中有3个主要组件,
- 图像编码器
- prompt编码器
- Mask解码器
图像编码器具有与视觉Transformer(ViT)相同的架构,并在其自己收集的SAM-1B数据集上使用MAE[10]进行预训练。它们提供了三种不同比例的图像编码器ViT-H、ViT-l和ViT-V的权重,作为实时性能和准确性之间权衡的选项。图像编码器获取任何大小的输入图像,并将其整形为1024×1024。然后将图像转换为具有patch大小16×16和嵌入大小256的顺序patch嵌入。经过几个具有窗口注意和残差传播的Transformer块之后,图像编码器的输出具有(64×64,256)的维度。
prompt编码器同时支持稀疏prompt(点、框、文本)和密集prompt(Mask)。稀疏prompt被投影到prompt Token 中并与图像嵌入连接,而密集prompt则使用卷积嵌入并与图像植入逐元素求和。
Mask解码器首先在输出 Token 、prompt Token 和图像嵌入上应用双向注意力模块。然后通过两个转置卷积层对图像嵌入进行上采样,并对放大后的图像嵌入与输出 Token 之间的逐点乘积进行预测。
3.2、Prediction Head
为了以有效的方式使SAM适应特定的医学图像数据集,作者在SAM编码器中保留权重,并附加一个额外的特定任务预测Head进行微调。此外,作者将预测Head设计为不可prompt的,并且唯一的输入是来自SAM编码器的图像嵌入。作者探讨了3种最常见的体系结构类型,ViT、CNN和线性层。
1、Vision Transformer
作者注意到SAM中的原始Mask解码器具有ViT Backbone,因此作者可以对其进行轻微修改,以便预测Head不仅不可prompt,而且能够利用SAM Mask解码器中的权重。
如图2所示,对于SAM解码器,除了prompt Token 和图像嵌入之外,还有可训练的输出 Token ,包括用于生成Mask的Mask Token 和用于预测Mask置信度的IoU Token 。
此外,Mask Token 包括前景Mask Token 和背景Mask Token 。输出 Token 与prompt Token 连接,作者将其命名为辅助嵌入。在双向注意力模块中,每一层都进行自注意力和交叉注意力。关于交叉注意力,它包括从 Token 到图像嵌入,以及从图像嵌入到 Token (作为密钥和值)。然后,通过2个转置的conv层对图像嵌入进行放大,并选择前景Mask Token 与放大的嵌入进行逐点乘积以获得Mask。
相比之下,AutoSAM删除辅助嵌入中的prompt标记,使其不再是可prompt的模型。另一种修改是通过类的数量复制辅助嵌入和图像嵌入,以生成多个类的Mask。每对的计算可以并行进行,因此与生成额外Mask相关的开销是可以忽略的。为一个推理生成多个Mask的替代方法是简单地在输出 Token 中添加更多前景Mask Token 。然而,作者选择第一种策略是因为,直观地说,一组辅助嵌入表示SAM中要分割的一个目标。AutoSAM独立地为每个类启用生成Mask。
2、Convolutional Neural Network
这种类型的预测Head是许多流行的医学图像分割模型中解码器的表示,如UNet、UNet++、TransUNet和Swin-UNetr。作者首先将嵌入的图像Reshape为大小为(256,64,64)的特征图。根据UNet中的结构,CNN Head部有k个阶段(k>=2),每个阶段由Stride为1的conv层和Stride为2的转置conv层组成。
在实验部分尝试了不同的k值,当k>2时,在k−2阶段,转置的conv层被替换为conv层,使得输出特征图总是放大4x。最后,应用kernel-size为1的逐点conv层来生成每个类的预测Mask。
3、Linear Layer
简单的分类Head总是用于评估在预训练任务中学习的特征表示的泛化。在这项工作中,作者还应用线性Head来测试是否存在SAM编码器提取的高级语义信息。与CNN相同,作者将嵌入的图像重新映射为2D特征图,然后直接部署2个转置conv层。然后,作者使用2个kernel-size为1的conv层来代替MLP来获得每个像素的分类。
4、实验
4.1、Dataset
ACDC(自动心脏诊断挑战)数据集是MICCAI 2017挑战的一部分,该挑战包含100名患者的心脏结构的MRI扫描,每个患者有2个3Dvolumes。该数据集还提供了左心室、右心室和心肌的专家分割Mask。
作者根据患者将MRI扫描随机分为三部分,训练集、验证集和测试集,比例为70:15:15。对于预处理,作者对每个volumes进行归一化,以便volumes中的所有像素都是零均值和单位方差。然后,作者将像素值转换为RGB格式,并将volumes内的每个切片存储为PNG文件,因为SAM是在RGB图像上训练的,作者的目标是保持输入格式的一致性。在此之前,尽管MRI扫描是以3Dvolumes进行的,但分割是在2D图像上进行的。
作者计算测试集中每个volumes的Dice分数和平均对称表面距离(ASSD),然后重新生成分割并重复实验。报告了4次的平均得分和标准差。
4.2、训练细节
训练的实施基于深度学习包PyTorch。使用的GPU设备是NVIDIA特斯拉V100,内存为16GB,比A100更容易访问。相比之下,SAM将训练分布在256个A100 GPU中。在训练过程中,作者对输入图像随机应用数据增强,包括高斯噪声、亮度修改、弹性变形和旋转。训练损失是交叉熵损失和Dice Loss的组合。用于更新的优化器算法基于Adam。学习率设置为0.0005,其中ββ。对于所有3个预测Head,单个GPU的最大batch-size为4。默认的训练Epoch是120,因为作者观察到在该Epoch数量之后验证集上的损失收敛。
4.3、Baselines
为了验证作者提出的方法的有效性,作者在相同的设置下对一些基线方法进行了实验作为比较。第一种是从Head开始训练UNet,这是获得特定数据集的自动分割模型的最常见方法。其次,作者还尝试了一种自监督学习方法SimCLR,该方法被广泛用于医学图像领域的标签高效分割。
该SimCLR基线包括两个阶段,预训练和微调。
在训练阶段,作者使用训练集中的所有数据,而不使用任何标注信息。作者从输入图像中获得两个随机视图,并使用UNet编码器将它们投影到特征空间中。然后应用对比损失来最大化两个视图的嵌入之间的一致性。
在微调过程中,UNet的编码器用预先训练的权重进行初始化,并且模型中的所有参数都在标记数据上进行训练。最后,作者在没有任何微调的情况下尝试原始SAM,以解决将SAM自定义到特定数据集的必要性。关于prompt,作者使用box-style的prompt,并且box坐标是基于GT Mask计算的。
4.4、实验结果
1、Label-efficient Adaptation
当在新的数据集上微调模型时,为了降低标注成本,希望微调仅在有限的标注图像的情况下实现有希望的结果。因此,在表1中,作者只提供了1或5个标记的volumes来评估作者方法的数据效率。以下是从表1中得出的主要观察结果。
- 首先,对于这两种设置,AutoSAM和CNN Head显示出与所有其他方法相比最好的分割精度。特别是当只使用1个标记时,AutoSAM的平均 Dice 分数为39.32,几乎是UNet和SimCLR的两倍。这提供了令人信服的证据,证明在SAM编码器中学习到的特征足够通用,可以转移到医学图像中。
就统计显著性而言,很难说AutoSAM或CNN是否具有更高的 Dice 分数,为什么这也意味着SAM的强大威力主要是由图像编码器而不是Mask解码器提取的代表性特征的结果。此外,作者观察到AutoSAM与CNN Head部相比具有更低的ASSD。这种差异可能归因于SAM解码器的训练,该解码器旨在生成集中在prompt位置附近的目标的Mask。相比之下,CNN Head部没有从SAM解码器加载信息,导致ASSD值更高。 - 其次,与AutoSAM和CNN编码器相比,即使仅用1个volumes训练,SAM也表现出更差的分割性能,这有力地支持了微调SAM是解决其在医学图像数据集上性能下降的有效方法。然而,也注意到,SAM的ASSD比其他方法低得多。这一观察结果有助于SAM受益于嵌入框prompt中的局部信息。该定位信息迫使预测Mask位于框区域周围。另一方面,SAM的LV Dice 分数始终为0。根据图4,作者可以发现Myo是一个由其他两个类包围的细圆,边界也很模糊。由于Myo的框接近RB的框,因此Myo实际上被误认为是RV的一部分,因此所有LV区域都被预测为Myo。
- 如表1所示,线性预测Head具有比其他两个预测Head差得多的性能。特别是,当标记数据的数量从1个增加到5个时,线性Head不能获得很大的分割精度提高。作者认为,这一结果是由于极轻的架构。当SAM编码器产生的视觉特征不具有丰富的医学图像语义信息时,这种简单的预测Head会导致模型能力较弱,并可能出现不足。
2、Ablation Study
作者进行的第一项消融研究是关于CNN预测Head中的深度数量如何影响微调结果。在表2中, Dice 随着深度的增加而增加,直到 Depth=4为止。如上所述,线性预测Head可能会出现装配不足的问题。当Depth< 4时,更大的预测Head会带来更好的模型能力。然而,当Depth > 4时,从增加预测Head中的参数所获得的好处开始减少。在这一点上,图像嵌入或预测Head架构的质量成为决定性能的更关键的因素。
作者还评估了AutoSAM和Encoder+CNN在SAM提供的不同编码器尺寸(即ViT-b、ViT-l和ViT-h)下的性能。
表3显示,通常较大的模型大小会在下游任务上产生更好的微调结果,但AutoSAM对编码器架构的敏感性不如Encoder+CNN。当使用ViT-h Backbone时,CNNHead部的 Dice 得分明显高于AutoSAM,尽管它仍然有更高的ASSD。表3也可以作为关于效率和性能之间切换的参考,因为与ViT-b相比,ViT-h导致更长的微调时间和更高的推理延迟。
最后,作者在图5中绘制了使用更多标记数据进行微调的结果。作者发现,当标记的卷数小于10时,AutoSAM仅比UNet(没有额外信息)和SimCLR(在同一数据集上预训练的知识)具有优势。这是因为SAM是在大规模图像数据集上预训练的,并且图像编码器能够提取语义信息,这有利于下游的分割任务。
然而,由于SAM从未接触过医学图像,因此这种语义信息可能是有偏见的,并且特定于自然图像。似乎有了足够的标注数据,从自然图像中获得的知识在将预测Head专门用于医学图像领域时会产生负面影响。因此,为了为所有图像模态建立一个真正的“基础模型”,未来需要一个大规模的医学图像数据集来预训练SAM。
5、总结
尽管SAM在自然图像中取得了成功,但如何有效地将SAM适应分布外的医学图像数据集仍然是一个悬而未决的问题。与现有工作不同,本文为解决这一问题提供了一个新的视角,即冻结SAM图像编码器中的权重,并添加一个轻量级的任务专用预测Head。
为了促进广泛的应用,作者将SAM修改为不可prompt的,并能够生成多类Mask。作者探索了三种类型的预测Head,ViT(称为AutoSAM)、CNN和线性层,其中AutoSAM和CNN Head在Few-Shot Head学习设置中显示出有希望的结果。仅用一个标记进行微调比框prompt的SAM具有更好的性能,这一事实证明了为新数据集定制SAM的必要性。由于标注的数量有限,作者的方法优于从Head开始训练和自监督学习基线。
6、参考
[1].How to Efficiently Adapt Large Segmentation Model(SAM) to Medical Images.