【小样本图像分割-2】UniverSeg: Universal Medical Image Segmentation

简介: UniverSeg是一种用于医学图像分割的小样本学习方法,通过大量医学图像数据集的训练,实现了对未见过的解剖结构和任务的泛化能力。该方法引入了CrossBlock机制,以支持集和查询集之间的特征交互为核心,显著提升了分割精度。实验结果显示,UniverSeg在多种任务上优于现有方法,特别是在任务多样性和支持集多样性方面表现出色。未来,该方法有望扩展到3D模型和多标签分割,进一步提高医学图像处理的灵活性和效率。

【小样本图像分割-2】UniverSeg: Universal Medical Image Segmentation

image-20240814172421551

OK,今天我们来看另外一篇小样本医学图像分割的工作,在这里的这个工作中也使用了类似于元学习的方式,作者搜集了大量的医学图像数据集。在这个医学图像数据集中作者按照任务将数据集分为了支持集和查询集,按照任务的方式对模型进行训练,并在训练过程中采用了大量的数据增强的方式,推理的时候,只需要输入支持集的原始图像和mask图像以及等待预测的图像,就能得到对应的分割结果。

文章的地址:[2304.06131] UniverSeg: Universal Medical Image Segmentation (arxiv.org)

代码的地址:https://github.com/JJGO/UniverSeg

摘要

虽然深度学习模型已经成为医学图像分割的主要方法,但它们通常无法推广到涉及新解剖结构、图像模式或标签的未知分割任务。对于新的分割任务,研究人员通常必须训练或微调模型,这既耗时又给临床研究人员带来了巨大的障碍,因为他们往往缺乏训练神经网络的资源和专业知识。我们提出了UniverSeg,一种无需额外训练即可解决看不见的医学分割任务的方法。给定一个查询图像和图像标签对的示例集来定义一个新的分割任务,UniverSeg采用一种新的CrossBlock机制来生成准确的分割映射,而不需要额外的训练。为了实现对新任务的泛化,我们收集并标准化了53个开放获取的医学分割数据集,其中包含超过22,000次扫描,我们将其称为MegaMedical。我们用这个集合来训练UniverSeg不同的解剖学和成像模式。我们证明了在看不见的任务上,UniverSeg大大优于几种相关的方法,并且对所提议的系统的重要方面进行了彻底的分析和得出见解。

作者提出的方法

首先作者给出了自己的方法和传统方法上的对比,传统的方法需要先在大规模的数据集上进行预训练,然后在小规模的数据集上进行微调,微调的效果实际上也取决于你数据集的实际规模大小。而作者提出的方法则可以在大批量的类似任务上进行学习,类似的任务指的是只需要少量的支持集图像和要预测的图像,进而通过支持集图像的特征推断出待预测图像的分割结果。

image-20240815101815388

作者给出了自己的网络结构,网络结构上和平常的UNET一类的模型相比没有特别大的差异,只是在中间的层中添加了CrossBlock的模块用于支持集特征和查询集特征之间的交互。为了整合跨空间尺度的信息,我们将CrossBlock模块组合在具有残余连接的编码器-解码器结构中,类似于流行的UNet架构(图3)。该网络将查询图像xt和支持集St = {(xt i, yt i)}n i=1的图像和标签-映射对作为输入,每个图像和标签-映射对按通道连接,并输出分割预测图。编码器路径中的每个级别都由CrossBlock组成,然后是查询和支持集表示的空间下采样操作。扩展路径中的每个级别都包括对两个表示进行上采样,使其空间分辨率加倍,将它们与编码路径中同等大小的表示连接起来,然后是CrossBlock。我们执行单个1x1卷积来将最终查询表示映射到预测。任务增强(AugT(x, y, S))。与减少训练样本过拟合的标准数据增强类似,增强训练任务对于泛化到新任务很有用,特别是那些远离训练任务分布的任务。我们引入了任务增强——使用相同类型的任务更改转换来修改所有查询和支持图像,和/或所有分割映射。示例任务增强包括分割映射的边缘检测或对所有图像和标签的水平翻转。我们在补充部分C中提供了所有增强和参数的列表。

image-20240815102150687

训练是以任务为基础的,每个任务由支持集和测试集构成,和平常的深度学习的语义分割任务相比,这样的元学习的结构输入的部分包含了两个,其中一个是输入的支持集的原始图像和标签,这个相当于是数据库,另外一个是需要进行查询的图像,通过这一系列图像的特征交互,让网络自己按照任务进行学习,并可以从中学习出支持集中原始图像和标签图像中的关联,进而可以对语义分割的结果进行推测。其中网络的关键在于图像中对于特征之间的交互,也就是文中提到的CrossBlock模块。

推理的过程中,相当于是一个特殊的单一的任务,通过训练好的模型,已经可以学习到支持集和查询集之间的结果,这个时候查询集的结果取决于支持集中的图像数据量,这是为什么在n-way k-shot问题中,k的数字越大,查询集的预测结果相对也越好。这里作者放了训练的公式,可以进行参考。

image-20240815102826503

这里放一下这个的代码,从这里的代码也可以看出,输入的部分主要是分为了3个,分别是原始的要预测的图像,支持集的图像和支持集图像的标签。

image-20240815105921995

其中作者自己实现了一个交叉的卷积的算子。大概内容是在x和y上进行特征的concat也就是连接。

class CrossConv2d(nn.Conv2d):
    """
    Compute pairwise convolution between all element of x and all elements of y.
    x, y are tensors of size B,_,C,H,W where _ could be different number of elements in x and y
    essentially, we do a meshgrid of the elements to get B,Sx,Sy,C,H,W tensors, and then
    pairwise conv.
    Args:
        x (tensor): B,Sx,Cx,H,W
        y (tensor): B,Sy,Cy,H,W
    Returns:
        tensor: B,Sx,Sy,Cout,H,W
    """
    """
    CrossConv2d is a convolutional layer that performs pairwise convolutions between elements of two input tensors.

    Parameters
    ----------
    in_channels : int or tuple of ints
        Number of channels in the input tensor(s).
        If the tensors have different number of channels, in_channels must be a tuple
    out_channels : int
        Number of output channels.
    kernel_size : int or tuple of ints
        Size of the convolutional kernel.
    stride : int or tuple of ints, optional
        Stride of the convolution. Default is 1.
    padding : int or tuple of ints, optional
        Zero-padding added to both sides of the input. Default is 0.
    dilation : int or tuple of ints, optional
        Spacing between kernel elements. Default is 1.
    groups : int, optional
        Number of blocked connections from input channels to output channels. Default is 1.
    bias : bool, optional
        If True, adds a learnable bias to the output. Default is True.
    padding_mode : str, optional
        Padding mode. Default is "zeros".
    device : str, optional
        Device on which to allocate the tensor. Default is None.
    dtype : torch.dtype, optional
        Data type assigned to the tensor. Default is None.

    Returns
    -------
    torch.Tensor
        Tensor resulting from the pairwise convolution between the elements of x and y.

    Notes
    -----
    x and y are tensors of size (B, Sx, Cx, H, W) and (B, Sy, Cy, H, W), respectively,
    The function does the cartesian product of the elements of x and y to obtain a tensor
    of size (B, Sx, Sy, Cx + Cy, H, W), and then performs the same convolution for all 
    (B, Sx, Sy) in the batch dimension. Runtime and memory are O(Sx * Sy).

    Examples
    --------
    >>> x = torch.randn(2, 3, 4, 32, 32)
    >>> y = torch.randn(2, 5, 6, 32, 32)
    >>> conv = CrossConv2d(in_channels=(4, 6), out_channels=7, kernel_size=3, padding=1)
    >>> output = conv(x, y)
    >>> output.shape  #(2, 3, 5, 7, 32, 32)
    """

    @validate_arguments
    def __init__(
        self,
        in_channels: size2t,
        out_channels: int,
        kernel_size: size2t,
        stride: size2t = 1,
        padding: size2t = 0,
        dilation: size2t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        device=None,
        dtype=None,
    ) -> None:

        if isinstance(in_channels, (list, tuple)):
            concat_channels = sum(in_channels)
        else:
            concat_channels = 2 * in_channels

        super().__init__(
            in_channels=concat_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype,
        )

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        Compute pairwise convolution between all elements of x and all elements of y.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of size (B, Sx, Cx, H, W).
        y : torch.Tensor
            Input tensor of size (B, Sy, Cy, H, W).

        Returns
        -------
        torch.Tensor
            Tensor resulting from the cross-convolution between the elements of x and y.
            Has size (B, Sx, Sy, Co, H, W), where Co is the number of output channels.
        """
        B, Sx, *_ = x.shape
        _, Sy, *_ = y.shape

        xs = E.repeat(x, "B Sx Cx H W -> B Sx Sy Cx H W", Sy=Sy)
        ys = E.repeat(y, "B Sy Cy H W -> B Sx Sy Cy H W", Sx=Sx)

        xy = torch.cat([xs, ys], dim=3,)

        batched_xy = E.rearrange(xy, "B Sx Sy C2 H W -> (B Sx Sy) C2 H W")
        batched_output = super().forward(batched_xy)

        output = E.rearrange(
            batched_output, "(B Sx Sy) Co H W -> B Sx Sy Co H W", B=B, Sx=Sx, Sy=Sy
        )
        return output

这里也给出了作者提到的关键的cross block的实现。这里的输出和上面的图中保持了一致,一个是分层级输出的support,一个是单独的target。

@validate_arguments_init
@dataclass(eq=False, repr=False)
class CrossBlock(nn.Module):

    in_channels: size2t
    cross_features: int
    conv_features: Optional[int] = None
    cross_kws: Optional[Dict[str, Any]] = None
    conv_kws: Optional[Dict[str, Any]] = None

    def __post_init__(self):
        super().__init__()

        conv_features = self.conv_features or self.cross_features
        cross_kws = self.cross_kws or {
   }
        conv_kws = self.conv_kws or {
   }

        self.cross = CrossOp(self.in_channels, self.cross_features, **cross_kws)
        self.target = Vmap(ConvOp(self.cross_features, conv_features, **conv_kws))
        self.support = Vmap(ConvOp(self.cross_features, conv_features, **conv_kws))

    def forward(self, target, support):
        target, support = self.cross(target, support)
        target = self.target(target)
        support = self.support(support)
        return target, support

本方法的一些效果

下面是作者给出部分图像的结果,除了指标上的结果。从图像的结果来看,我个人认为作者主要是用了自己比较好的结果,因为他对比的方法在我们第一期有做过类似的比较的,其中上一篇论文中展现的模型的能力结果强大。

image-20240815103436765

下面的表格是对比实验中展示的对比结果,这里的结果可以作为后续自己实验的对比结果。(注意这里是ICCV 2023的论文,其中的实验结果还是比较新鲜)

image-20240815103816555

这里的消融实验的对比结果。消融实验的部分主要是为了说明自己数据增强所使用到的策略。我们对我们在培训期间用于增加数据和任务多样性的三种主要技术进行了消融研究:任务内增强、任务增强和合成任务。

image-20240815103916544

结论

我们介绍了UniverSeg,一种用于医学图像分割的单任务不可知论模型的学习方法。我们使用大量不同的开放获取医学分割数据集来训练UniverSeg,它能够泛化到看不见的解剖结构和任务。我们引入了一种新的交叉卷积操作,它可以在不同的尺度上交互查询和支持表示。在我们的实验中,UniverSeg在所有hold out数据集中的表现都明显优于现有的few-shot方法。通过广泛的消融研究,我们得出结论,UniverSeg性能强烈依赖于训练期间的任务多样性和推理期间的支持集多样性。这突出了UniverSeg促进可变大小支持集的实用性,为潜在用户的数据集提供了灵活性。

为了的展望。这里的展望主要是帮助我们做后面的实验使用。在这项工作中,我们重点展示和深入分析了UniverSeg的核心思想,使用二维数据和单个标签。我们对未来使用2.5D或3D模型和多标签地图分割3D体的扩展感到兴奋,并进一步缩小与上界的差距。UniverSeg承诺可以轻松适应科学家和临床研究人员确定的新分割任务,而无需对模型进行再训练,这对他们来说通常是不切实际的

目录
相关文章
|
3天前
|
机器学习/深度学习 编解码 算法
论文精度笔记(二):《Deep Learning based Face Liveness Detection in Videos 》
论文提出了基于深度学习的面部欺骗检测技术,使用LRF-ELM和CNN两种模型,在NUAA和CASIA数据库上进行实验,发现LRF-ELM在检测活体面部方面更为准确。
6 1
论文精度笔记(二):《Deep Learning based Face Liveness Detection in Videos 》
|
2天前
|
机器学习/深度学习 人工智能 文件存储
【小样本图像分割-3】HyperSegNAS: Bridging One-Shot Neural Architecture Search with 3D Medical Image Segmentation using HyperNet
本文介绍了一种名为HyperSegNAS的新方法,该方法结合了一次性神经架构搜索(NAS)与3D医学图像分割,旨在解决传统NAS方法在3D医学图像分割中计算成本高、搜索时间长的问题。HyperSegNAS通过引入HyperNet来优化超级网络的训练,能够在保持高性能的同时,快速找到适合不同计算约束条件的最优网络架构。该方法在医疗分割十项全能(MSD)挑战的多个任务中展现了卓越的性能,特别是在胰腺数据集上的表现尤为突出。
5 0
【小样本图像分割-3】HyperSegNAS: Bridging One-Shot Neural Architecture Search with 3D Medical Image Segmentation using HyperNet
|
2天前
|
机器学习/深度学习 计算机视觉
【小样本图像分割-1】PANet: Few-Shot Image Semantic Segmentation with Prototype Alignment
本文介绍了ICCV 2019的一篇关于小样本图像语义分割的论文《PANet: Few-Shot Image Semantic Segmentation With Prototype Alignment》。PANet通过度量学习方法,从支持集中的少量标注样本中学习类的原型表示,并通过非参数度量学习对查询图像进行分割。该方法在PASCAL-5i数据集上取得了显著的性能提升,1-shot和5-shot设置下的mIoU分别达到48.1%和55.7%。PANet还引入了原型对齐正则化,以提高模型的泛化能力。
7 0
【小样本图像分割-1】PANet: Few-Shot Image Semantic Segmentation with Prototype Alignment
|
机器学习/深度学习 编解码 自然语言处理
Swin-Unet: Unet-like Pure Transformer for Medical Image Segmentation论文解读
在过去的几年中,卷积神经网络(CNN)在医学图像分析方面取得了里程碑式的进展。特别是基于U型结构和跳跃连接的深度神经网络在各种医学图像任务中得到了广泛的应用。
666 0
|
机器学习/深度学习 资源调度 数据可视化
【计算机视觉 | 目标检测】Detecting Twenty-thousand Classes using Image-level Supervision
本文提出的方法也采用了经典的两阶段范式,在第一阶段采用直接提取RPN的方法,第二阶段对做细化的具体类别进行assign和识别。
|
机器学习/深度学习 编解码 自然语言处理
FCT: The Fully Convolutional Transformer for Medical Image Segmentation 论文解读
我们提出了一种新的transformer,能够分割不同形态的医学图像。医学图像分析的细粒度特性所带来的挑战意味着transformer对其分析的适应仍处于初级阶段。
231 0
|
机器学习/深度学习 计算机视觉
【计算机视觉 | 目标检测】RegionCLIP: Region-based language-image pretraining
RegionCLIP的目的便是实现从image-text pairs的匹配到region-text pairs的匹配。构建一个模型进行图像区域的推理研究(如目标检测),目的是学习一个包含丰富的对象概念的区域视觉-语义空间,以便它可以用于开放词汇的目标检测。实质上就是训练一个视觉编码器V,使它可以编码图像区域,并将它们与语言编码器L编码的区域描述相匹配。
|
编解码 资源调度 自然语言处理
【计算机视觉】Open-Vocabulary Semantic Segmentation with Mask-adapted CLIP(OVSeg)
基于掩码的开放词汇语义分割。 从效果上来看,OVSeg 可以与 Segment Anything 结合,完成细粒度的开放语言分割。
|
机器学习/深度学习 存储 机器人
LF-YOLO: A Lighter and Faster YOLO for Weld Defect Detection of X-ray Image
高效的特征提取EFE模块作为主干单元,它可以用很少的参数和低计算量提取有意义的特征,有效地学习表征。大大减少了特征提取的消耗
144 0
|
机器学习/深度学习 编解码 数据可视化
图像目标分割_2 FCN(Fully Convolutional Networks for Semantic Segmentation)
图像语义分割:给定一张图片,对图片上每一个像素点进行分类!但是与图像分类目的不同,语义分割模型要具有像素级的密集预测能力才可以。
230 0