【YOLOv8改进 - Backbone主干】清华大学CloFormer AttnConv :利用共享权重和上下文感知权重增强局部感知,注意力机制与卷积的完美融合

简介: 【YOLOv8改进 - Backbone主干】清华大学CloFormer AttnConv :利用共享权重和上下文感知权重增强局部感知,注意力机制与卷积的完美融合

YOLOv8目标检测创新改进与实战案例专栏

专栏目录: YOLOv8有效改进系列及项目实战目录 包含卷积,主干 注意力,检测头等创新机制 以及 各种目标检测分割项目实战案例

专栏链接: YOLOv8基础解析+创新改进+实战案例

介绍

image-20240620093857632

摘要

视觉变换器(Vision Transformers,ViTs)已被证明在各种视觉任务中具有高效性。然而,将其缩小到移动设备友好的尺寸会导致性能显著下降。因此,开发轻量级视觉变换器成为了一个重要的研究方向。本文介绍了CloFormer,这是一种利用上下文感知局部增强的轻量级视觉变换器。CloFormer探讨了在传统卷积操作中常用的全局共享权重与在注意力机制中出现的特定于token的上下文感知权重之间的关系,并提出了一种高效且简单的模块来捕获高频局部信息。在CloFormer中,我们引入了AttnConv,一种在注意力风格下的卷积操作。提出的AttnConv使用共享权重来聚合局部信息,并部署精心设计的上下文感知权重来增强局部特征。AttnConv与使用池化来减少CloFormer中FLOPs的传统注意力相结合,使模型能够感知高频和低频信息。在图像分类、目标检测和语义分割中的大量实验表明了CloFormer的优越性。

文章链接

论文地址:论文地址

代码地址:代码地址

基本原理

AttnConv是CloFormer中引入的一种卷积操作符,它采用了注意力机制的风格。所提出的 AttnConv 有效地融合了共享权重和上下文感知权重,以聚合高频的局部信息。具体地,AttnConv 首先使用深度卷积(DWconv)提取局部表示,其中 DWconv 具有共享权重。然后,其使用上下文感知权重来增强局部特征。与 Non-Local 等生成上下文感知权重的方法不同,AttnConv 使用门控机制生成上下文感知权重,引入了比常用的注意力机制更强的非线性。此外,AttnConv 将卷积算子应用于 Query 和 Key 以聚合局部信息,然后计算 Q 和 K 的哈达玛积,并对结果进行一系列线性或非线性变换,生成范围在 [-1,1] 之间的上下文感知权重。值得注意的是,AttnConv 继承了卷积的平移等变性,因为它的所有操作都基于卷积。具体公式如下:

image-20240620094510523

最后,将全局特征和局部特征合并起来,并使用一个MLP得到最终的输出。公式表示如下:

image-20240620094526335

  1. AttnConv的操作流程

    • 首先,通过线性变换得到Q、K和V,这一步与标准的注意力机制相同。

    • 接着,在V上使用共享权重(DWconv)进行信息聚合,利用深度卷积提取局部表示。

    • 然后,利用比传统注意力机制更强的非线性方法生成上下文感知权重,用这些权重增强局部特征。

      image-20240620094202359

  2. AttnConv的技术原理

    • AttnConv通过引入共享权重和上下文感知权重来提取高频局部信息。
    • 使用深度卷积(DWconv)提取局部表示,再利用上下文感知权重增强局部特征。
    • 与以往通过局部自注意力生成上下文感知权重的方法不同,AttnConv通过门控机制生成上下文感知权重,引入了比常用注意力机制更强的非线性。
    • AttnConv对Q和K应用卷积操作来聚合局部信息,然后计算Q和K的哈达玛积,并对结果进行一系列线性或非线性变换,生成范围在[-1, 1]内的上下文感知权重。
    • AttnConv保留了卷积的平移等变性特性,因为其所有操作都基于卷积。

核心代码

import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import DropPath
from typing import List
from efficientnet_pytorch.model import MemoryEfficientSwish

class AttnMap(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.act_block = nn.Sequential(
                            nn.Conv2d(dim, dim, 1, 1, 0),
                            MemoryEfficientSwish(),
                            nn.Conv2d(dim, dim, 1, 1, 0)
                            #nn.Identity()
                         )
    def forward(self, x):
        return self.act_block(x)

class EfficientAttention(nn.Module):

    def __init__(self, dim, num_heads, group_split: List[int], kernel_sizes: List[int], window_size=7, 
                 attn_drop=0., proj_drop=0., qkv_bias=True):
        super().__init__()
        assert sum(group_split) == num_heads
        assert len(kernel_sizes) + 1 == len(group_split)
        self.dim = dim
        self.num_heads = num_heads
        self.dim_head = dim // num_heads
        self.scalor = self.dim_head ** -0.5
        self.kernel_sizes = kernel_sizes
        self.window_size = window_size
        self.group_split = group_split
        convs = []
        act_blocks = []
        qkvs = []
        #projs = []
        for i in range(len(kernel_sizes)):
            kernel_size = kernel_sizes[i]
            group_head = group_split[i]
            if group_head == 0:
                continue
            convs.append(nn.Conv2d(3*self.dim_head*group_head, 3*self.dim_head*group_head, kernel_size,
                         1, kernel_size//2, groups=3*self.dim_head*group_head))
            act_blocks.append(AttnMap(self.dim_head*group_head))
            qkvs.append(nn.Conv2d(dim, 3*group_head*self.dim_head, 1, 1, 0, bias=qkv_bias))
            #projs.append(nn.Linear(group_head*self.dim_head, group_head*self.dim_head, bias=qkv_bias))
        if group_split[-1] != 0:
            self.global_q = nn.Conv2d(dim, group_split[-1]*self.dim_head, 1, 1, 0, bias=qkv_bias)
            self.global_kv = nn.Conv2d(dim, group_split[-1]*self.dim_head*2, 1, 1, 0, bias=qkv_bias)
            #self.global_proj = nn.Linear(group_split[-1]*self.dim_head, group_split[-1]*self.dim_head, bias=qkv_bias)
            self.avgpool = nn.AvgPool2d(window_size, window_size) if window_size!=1 else nn.Identity()

        self.convs = nn.ModuleList(convs)
        self.act_blocks = nn.ModuleList(act_blocks)
        self.qkvs = nn.ModuleList(qkvs)
        self.proj = nn.Conv2d(dim, dim, 1, 1, 0, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)

    def high_fre_attntion(self, x: torch.Tensor, to_qkv: nn.Module, mixer: nn.Module, attn_block: nn.Module):
        '''
        x: (b c h w)
        '''
        b, c, h, w = x.size()
        qkv = to_qkv(x) #(b (3 m d) h w)
        qkv = mixer(qkv).reshape(b, 3, -1, h, w).transpose(0, 1).contiguous() #(3 b (m d) h w)
        q, k, v = qkv #(b (m d) h w)
        attn = attn_block(q.mul(k)).mul(self.scalor)
        attn = self.attn_drop(torch.tanh(attn))
        res = attn.mul(v) #(b (m d) h w)
        return res

    def low_fre_attention(self, x : torch.Tensor, to_q: nn.Module, to_kv: nn.Module, avgpool: nn.Module):
        '''
        x: (b c h w)
        '''
        b, c, h, w = x.size()

        q = to_q(x).reshape(b, -1, self.dim_head, h*w).transpose(-1, -2).contiguous() #(b m (h w) d)
        kv = avgpool(x) #(b c h w)
        kv = to_kv(kv).view(b, 2, -1, self.dim_head, (h*w)//(self.window_size**2)).permute(1, 0, 2, 4, 3).contiguous() #(2 b m (H W) d)
        k, v = kv #(b m (H W) d)
        attn = self.scalor * q @ k.transpose(-1, -2) #(b m (h w) (H W))
        attn = self.attn_drop(attn.softmax(dim=-1))
        res = attn @ v #(b m (h w) d)
        res = res.transpose(2, 3).reshape(b, -1, h, w).contiguous()
        return res

    def forward(self, x: torch.Tensor):
        '''
        x: (b c h w)
        '''
        res = []
        for i in range(len(self.kernel_sizes)):
            if self.group_split[i] == 0:
                continue
            res.append(self.high_fre_attntion(x, self.qkvs[i], self.convs[i], self.act_blocks[i]))
        if self.group_split[-1] != 0:
            res.append(self.low_fre_attention(x, self.global_q, self.global_kv, self.avgpool))
        return self.proj_drop(self.proj(torch.cat(res, dim=1)))

task与yaml配置

详见:https://blog.csdn.net/shangyanaf/article/details/139824105

相关文章
|
6月前
|
机器学习/深度学习
YOLOv8改进 | 2023注意力篇 | MLCA混合局部通道注意力(轻量化注意力机制)
YOLOv8改进 | 2023注意力篇 | MLCA混合局部通道注意力(轻量化注意力机制)
365 1
|
6月前
|
机器学习/深度学习
YOLOv5改进 | 2023注意力篇 | MLCA混合局部通道注意力(轻量化注意力机制)
YOLOv5改进 | 2023注意力篇 | MLCA混合局部通道注意力(轻量化注意力机制)
409 0
|
6月前
|
机器学习/深度学习 Ruby
YOLOv5改进 | 2023注意力篇 | iRMB倒置残差块注意力机制(轻量化注意力机制)
YOLOv5改进 | 2023注意力篇 | iRMB倒置残差块注意力机制(轻量化注意力机制)
348 0
|
6月前
|
机器学习/深度学习 Ruby
YOLOv8改进 | 2023注意力篇 | iRMB倒置残差块注意力机制(轻量化注意力机制)
YOLOv8改进 | 2023注意力篇 | iRMB倒置残差块注意力机制(轻量化注意力机制)
693 0
|
4月前
|
机器学习/深度学习 数据采集 自然语言处理
注意力机制中三种掩码技术详解和Pytorch实现
**注意力机制中的掩码在深度学习中至关重要,如Transformer模型所用。掩码类型包括:填充掩码(忽略填充数据)、序列掩码(控制信息流)和前瞻掩码(自回归模型防止窥视未来信息)。通过创建不同掩码,如上三角矩阵,模型能正确处理变长序列并保持序列依赖性。在注意力计算中,掩码修改得分,确保模型学习的有效性。这些技术在现代NLP和序列任务中是核心组件。**
181 12
|
4月前
|
机器学习/深度学习 计算机视觉
【YOLOv8改进 - 注意力机制】SENetV2: 用于通道和全局表示的聚合稠密层,结合SE模块和密集层来增强特征表示
【YOLOv8改进 - 注意力机制】SENetV2: 用于通道和全局表示的聚合稠密层,结合SE模块和密集层来增强特征表示
|
4月前
|
机器学习/深度学习 图计算 计算机视觉
【YOLOv8改进 - 注意力机制】 CascadedGroupAttention:级联组注意力,增强视觉Transformer中多头自注意力机制的效率和有效性
YOLO目标检测专栏探讨了Transformer在视觉任务中的效能与计算成本问题,提出EfficientViT,一种兼顾速度和准确性的模型。EfficientViT通过创新的Cascaded Group Attention(CGA)模块减少冗余,提高多样性,节省计算资源。在保持高精度的同时,与MobileNetV3-Large相比,EfficientViT在速度上有显著提升。论文和代码已公开。CGA通过特征分割和级联头部增加注意力多样性和模型容量,降低了计算负担。核心代码展示了CGA模块的实现。
|
4月前
|
机器学习/深度学习 计算机视觉
【YOLOv10改进-注意力机制】 MSDA:多尺度空洞注意力 (论文笔记+引入代码)
YOLO目标检测专栏探讨了ViT的改进,提出DilateFormer,它结合多尺度扩张注意力(MSDA)来平衡计算效率和关注域大小。MSDA利用局部稀疏交互减少冗余,通过不同头部的扩张率捕获多尺度特征。DilateFormer在保持高性能的同时,计算成本降低70%,在ImageNet-1K、COCO和ADE20K任务上取得领先结果。YOLOv8引入了MultiDilatelocalAttention模块,用于实现膨胀注意力。更多详情及配置见相关链接。
|
5月前
|
机器学习/深度学习 关系型数据库
【YOLOv8改进 - 注意力机制】NAM:基于归一化的注意力模块,将权重稀疏惩罚应用于注意力机制中,提高效率性能
**NAM: 提升模型效率的新颖归一化注意力模块,抑制非显著权重,结合通道和空间注意力,通过批量归一化衡量重要性。在Resnet和Mobilenet上的实验显示优于其他三种机制。源码见[GitHub](https://github.com/Christian-lyc/NAM)。**
|
5月前
|
机器学习/深度学习 计算机视觉
YOLOv8改进 | 注意力机制 | 添加混合局部通道注意力——MLCA【原理讲解】
YOLOv8专栏介绍了混合局部通道注意力(MLCA)模块,它结合通道、空间和局部信息,提升目标检测性能,同时保持低复杂度。文章提供MLCA原理、代码实现及如何将其集成到YOLOv8中,助力读者实战深度学习目标检测。[YOLOv8改进——更新各种有效涨点方法](https://blog.csdn.net/m0_67647321/category_12548649.html)