【YOLOv8改进】STA(Super Token Attention) 超级令牌注意力机制 (论文笔记+引入代码)

简介: 该专栏探讨YOLO目标检测的创新改进和实战应用,介绍了使用视觉Transformer的新方法。为解决Transformer在浅层处理局部特征时的冗余问题,提出了超级令牌(Super Tokens)和超级令牌注意力(STA)机制,旨在高效建模全局上下文。通过稀疏关联学习和自注意力处理,STA降低了计算复杂度,提升了全局依赖的捕获效率。由此构建的层次化视觉Transformer在ImageNet-1K、COCO检测和ADE20K语义分割任务上展现出优秀性能。此外,文章提供了YOLOv8中实现STA的代码示例。更多详细信息和配置可在相关链接中找到。

YOLO目标检测创新改进与实战案例专栏

专栏目录: YOLO有效改进系列及项目实战目录 包含卷积,主干 注意力,检测头等创新机制 以及 各种目标检测分割项目实战案例

专栏链接: YOLO基础解析+创新改进+实战案例

摘要

视觉Transformer在许多视觉任务上展示了卓越的性能。然而,它在浅层捕获局部特征时可能会面临高度冗余的问题。因此,使用了局部自注意力或早期阶段的卷积来减少这种冗余,但这牺牲了捕获长距离依赖的能力。一个挑战随之而来:在神经网络的早期阶段,我们是否能高效且有效地进行全局上下文建模?为解决这一问题,我们从超像素的设计中获得启示,这种设计通过减少图像基元的数量来简化后续处理,并在视觉Transformer中引入了超级令牌。超级令牌旨在为视觉内容提供有意义的语义分割,这样既减少了自注意力中的令牌数量,也保留了全局建模能力。具体而言,我们提出了一种简单而有效的超级令牌注意力(STA)机制,它包括三个步骤:首先通过稀疏关联学习从视觉令牌中抽取超级令牌,接着对这些超级令牌进行自注意力处理,最后将它们映射回原始的令牌空间。STA通过将普通的全局注意力分解为稀疏关联图与低维度注意力的乘积,极大地提高了捕获全局依赖的效率。基于STA,我们开发了一个层次化的视觉Transformer。广泛的实验显示了它在各种视觉任务上的强大性能。特别是,在没有任何额外训练数据或标签的情况下,它在ImageNet-1K上实现了86.4%的顶级准确率,以及在COCO检测任务上达到53.9的盒AP和46.8的掩码AP,在ADE20K语义分割任务上实现了51.9的mIOU。

创新点

  1. 引入超级标记(super tokens):通过引入超级标记的概念,实现了在视觉Transformer中的全局上下文建模。超级标记将原始标记聚合成具有语义意义的单元,从而减少了自注意力计算的复杂度,提高了全局信息的捕获效率。

  2. Super Token Attention(STA)机制:提出了一种简单而强大的超级标记注意力机制,包括超级标记采样、多头自注意力和标记上采样等步骤。STA通过稀疏映射和自注意力计算,在全局和局部之间实现了高效的信息交互,有效地学习全局表示。

  3. Hierarchical Vision Transformer:设计了一种层次化的视觉Transformer结构,结合了卷积层和超级标记Transformer块,以在不同层次上实现高效和有效的表示学习。这种结构在各种视觉任务上展现了出色的性能,包括图像分类、目标检测和语义分割等。

  4. 性能优越性:在多个视觉任务上进行了广泛的实验验证,包括ImageNet图像分类、COCO目标检测和ADE20K语义分割等。实验结果表明,基于超级标记的视觉Transformer在各项任务上均取得了优异的性能,超越了其他Transformer模型的表现。

yolov8 引入

class StokenAttention(nn.Module):
    def __init__(self, dim, stoken_size=[8,8], n_iter=1, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()

        self.n_iter = n_iter  # 迭代次数
        self.stoken_size = stoken_size  # 空间令牌的大小

        self.scale = dim ** - 0.5  # 缩放因子

        self.unfold = Unfold(3)  # 定义Unfold实例
        self.fold = Fold(3)  # 定义Fold实例

        # 定义空间注意力机制
        self.stoken_refine = StAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=proj_drop)

    def stoken_forward(self, x):
        '''
           x: (B, C, H, W)
        '''
        B, C, H0, W0 = x.shape
        h, w = self.stoken_size

        pad_l = pad_t = 0
        pad_r = (w - W0 % w) % w
        pad_b = (h - H0 % h) % h
        if pad_r > 0 or pad_b > 0:
            x = F.pad(x, (pad_l, pad_r, pad_t, pad_b))  # 对输入张量进行填充

        _, _, H, W = x.shape

        hh, ww = H // h, W // w

        stoken_features = F.adaptive_avg_pool2d(x, (hh, ww))  # 自适应平均池化
        pixel_features = x.reshape(B, C, hh, h, ww, w).permute(0, 2, 4, 3, 5, 1).reshape(B, hh * ww, h * w, C)

        with torch.no_grad():
            for idx in range(self.n_iter):
                stoken_features = self.unfold(stoken_features)  # 展开空间令牌特征
                stoken_features = stoken_features.transpose(1, 2).reshape(B, hh * ww, C, 9)
                affinity_matrix = pixel_features @ stoken_features * self.scale  # 计算亲和矩阵
                affinity_matrix = affinity_matrix.softmax(-1)  # 对亲和矩阵进行softmax

                affinity_matrix_sum = affinity_matrix.sum(2).transpose(1, 2).reshape(B, 9, hh, ww)
                affinity_matrix_sum = self.fold(affinity_matrix_sum)
                if idx < self.n_iter - 1:
                    stoken_features = pixel_features.transpose(-1, -2) @ affinity_matrix
                    stoken_features = self.fold(stoken_features.permute(0, 2, 3, 1).reshape(B * C, 9, hh, ww)).reshape(B, C, hh, ww)
                    stoken_features = stoken_features / (affinity_matrix_sum + 1e-12)  # 归一化

        stoken_features = pixel_features.transpose(-1, -2) @ affinity_matrix
        stoken_features = self.fold(stoken_features.permute(0, 2, 3, 1).reshape(B * C, 9, hh, ww)).reshape(B, C, hh, ww)
        stoken_features = stoken_features / (affinity_matrix_sum.detach() + 1e-12)  # 归一化

        stoken_features = self.stoken_refine(stoken_features)  # 细化空间令牌特征

        stoken_features = self.unfold(stoken_features)  # 展开细化后的特征
        stoken_features = stoken_features.transpose(1, 2).reshape(B, hh * ww, C, 9)
        pixel_features = stoken_features @ affinity_matrix.transpose(-1, -2)  # 计算最终的像素特征

        pixel_features = pixel_features.reshape(B, hh, ww, C, h, w).permute(0, 3, 1, 4, 2, 5).reshape(B, C, H, W)
        if pad_r > 0 or pad_b > 0:
            pixel_features = pixel_features[:, :, :H0, :W0]  # 去除填充部分

        return pixel_features  # 返回最终的像素特征

    def direct_forward(self, x):
        B, C, H, W = x.shape
        stoken_features = x
        stoken_features = self.stoken_refine(stoken_features)
        return stoken_features  # 返回直接计算的空间令牌特征

    def forward(self, x):
        if self.stoken_size[0] > 1 or self.stoken_size[1] > 1:
            return self.stoken_forward(x)  # 使用空间令牌前向计算
        else:
            return self.direct_forward(x)  # 直接前向计算

task与yaml配置

详见:https://blog.csdn.net/shangyanaf/article/details/139113660

相关文章
|
机器学习/深度学习 编解码 文件存储
YOLOv8改进 | 融合改进篇 | BiFPN+ RepViT(教你如何融合改进机制)
YOLOv8改进 | 融合改进篇 | BiFPN+ RepViT(教你如何融合改进机制)
1609 1
|
机器学习/深度学习 编解码 数据可视化
【即插即用】涨点神器AFF:注意力特征融合(已经开源,附论文和源码链接)
【即插即用】涨点神器AFF:注意力特征融合(已经开源,附论文和源码链接)
8312 1
|
机器学习/深度学习 计算机视觉
YOLOv8改进 | Conv篇 | 在线重参数化卷积OREPA助力二次创新(提高推理速度 + FPS)
YOLOv8改进 | Conv篇 | 在线重参数化卷积OREPA助力二次创新(提高推理速度 + FPS)
640 0
|
10月前
|
并行计算 PyTorch Shell
YOLOv11改进策略【Neck】| 有效且轻量的动态上采样算子:DySample
YOLOv11改进策略【Neck】| 有效且轻量的动态上采样算子:DySample
973 11
YOLOv11改进策略【Neck】| 有效且轻量的动态上采样算子:DySample
|
11月前
|
机器学习/深度学习 算法 计算机视觉
YOLOv11改进策略【SPPF】| SimSPPF,简化设计,提高计算效率
YOLOv11改进策略【SPPF】| SimSPPF,简化设计,提高计算效率
2236 8
YOLOv11改进策略【SPPF】| SimSPPF,简化设计,提高计算效率
|
机器学习/深度学习 数据可视化 测试技术
YOLO11实战:新颖的多尺度卷积注意力(MSCA)加在网络不同位置的涨点情况 | 创新点如何在自己数据集上高效涨点,解决不涨点掉点等问题
本文探讨了创新点在自定义数据集上表现不稳定的问题,分析了不同数据集和网络位置对创新效果的影响。通过在YOLO11的不同位置引入MSCAAttention模块,展示了三种不同的改进方案及其效果。实验结果显示,改进方案在mAP50指标上分别提升了至0.788、0.792和0.775。建议多尝试不同配置,找到最适合特定数据集的解决方案。
3316 0
|
10月前
|
计算机视觉
YOLOv11改进策略【卷积层】| RCS-OSA 通道混洗的重参数化卷积 二次创新C3k2
YOLOv11改进策略【卷积层】| RCS-OSA 通道混洗的重参数化卷积 二次创新C3k2
491 0
YOLOv11改进策略【卷积层】| RCS-OSA 通道混洗的重参数化卷积 二次创新C3k2
|
10月前
|
机器学习/深度学习 计算机视觉
YOLOv11改进策略【卷积层】| CVPR-2023 部分卷积 PConv 轻量化卷积,降低内存占用
YOLOv11改进策略【卷积层】| CVPR-2023 部分卷积 PConv 轻量化卷积,降低内存占用
1047 0
YOLOv11改进策略【卷积层】| CVPR-2023 部分卷积 PConv 轻量化卷积,降低内存占用
|
10月前
|
机器学习/深度学习 计算机视觉
YOLOv11改进策略【Head】| AFPN渐进式自适应特征金字塔,增加针对小目标的检测头(附模块详解和完整配置步骤)
YOLOv11改进策略【Head】| AFPN渐进式自适应特征金字塔,增加针对小目标的检测头(附模块详解和完整配置步骤)
1416 12
YOLOv11改进策略【Head】| AFPN渐进式自适应特征金字塔,增加针对小目标的检测头(附模块详解和完整配置步骤)
|
机器学习/深度学习 编解码 PyTorch
CVPR 2023 | 主干网络FasterNet 核心解读 代码分析
本文分享来自CVPR 2023的论文,提出了一种快速的主干网络,名为FasterNet。核心算子是PConv,partial convolution,部分卷积,通过减少冗余计算和内存访问来更有效地提取空间特征。
10434 58