【YOLOv8改进】BRA(bi-level routing attention ):双层路由注意力(论文笔记+引入代码)

简介: **BiFormer和HCANet摘要**BiFormer是CVPR2023提出的一种新型视觉Transformer,采用双层路由注意力机制实现动态稀疏注意力,优化计算效率和内存使用,适用于图像分类、目标检测和语义分割任务。代码可在GitHub获取。另一方面,HCANet是针对高光谱图像去噪的深度学习模型,融合CNN和Transformer,强化全局和局部特征建模,通过多尺度前馈网络提升去噪效果。HCANet在HSI数据集上表现优秀,其代码同样开放源代码。

摘要

作为视觉Transformers的核心构建模块,注意力机制是一种强大的工具,用于捕捉长程依赖关系。然而,这种强大功能也带来了代价:计算代价巨大且内存占用高,因为需要计算所有空间位置上成对的token交互。为缓解这一问题,一系列研究尝试通过引入手工设计且内容无关的稀疏性来改进注意力机制,例如将注意力操作限制在局部窗口、轴向条带或膨胀窗口内。与这些方法不同,我们提出了一种新颖的动态稀疏注意力机制,通过双层路由实现更加灵活且具有内容感知的计算分配。具体而言,对于一个查询,首先在粗略的区域级别过滤掉无关的键值对,然后在剩余候选区域(即路由区域)的联合中应用细粒度的token-to-token注意力。我们提供了一个简单而有效的双层路由注意力的实现,该实现利用稀疏性来节省计算和内存,同时仅涉及GPU友好的稠密矩阵乘法。基于所提出的双层路由注意力,我们提出了一种新的通用视觉Transformer,命名为BiFormer。BiFormer在查询自适应的方式下关注一小部分相关token,而不受其他无关token的干扰,因而在密集预测任务中享有良好的性能和高计算效率。在图像分类、目标检测和语义分割等多个计算机视觉任务中的实验证明了我们设计的有效性。代码可在https://github.com/rayleizhu/BiFormer获得。

摘要

摘要——高光谱图像(HSI)去噪对于高光谱数据的有效分析和解释至关重要。然而,同时建模全局和局部特征以增强HSI去噪的研究却很少。在本文中,我们提出了一种混合卷积和注意力网络(HCANet),该网络结合了卷积神经网络(CNN)和Transformers的优势。为了增强全局和局部特征的建模,我们设计了一个卷积和注意力融合模块,旨在捕捉长距离依赖关系和邻域光谱相关性。此外,为了改进多尺度信息聚合,我们设计了一个多尺度前馈网络,通过在不同尺度上提取特征来增强去噪性能。在主流HSI数据集上的实验结果表明,所提出的HCANet具有合理性和有效性。所提出的模型在去除各种复杂噪声方面表现出色。我们的代码可在https://github.com/summitgao/HCANet获得。

文章链接

论文地址:论文地址

代码地址:代码地址

参考代码:代码地址

基本原理

Bi-Level Routing Attention (BRA)是一种注意力机制,旨在解决多头自注意力机制(MHSA)的可扩展性问题。传统的注意力机制要求每个查询都要关注所有的键-值对,这在处理大规模数据时可能会导致计算和存储资源的浪费。BRA通过引入动态的、查询感知的稀疏注意力机制来解决这一问题。

BRA的关键思想是在粗粒度的区域级别上过滤出大部分不相关的键-值对,只保留少量的路由区域。然后,在这些路由区域的并集上应用细粒度的令牌-令牌注意力。这种方法使得每个查询只需关注少量相关的键-值对,从而提高了计算效率和内存利用率。

具体来说,BRA的实现包括以下步骤:

  1. 构建和修剪区域级别的有向图,以过滤出大部分不相关的键-值对。
  2. 在路由区域的并集上应用细粒度的令牌-令牌注意力,以实现动态的、查询感知的稀疏性。

yolov8 代码引入

 class BiLevelRoutingAttention(nn.Module):
    """
    n_win: number of windows in one side (so the actual number of windows is n_win*n_win)
    kv_per_win: for kv_downsample_mode='ada_xxxpool' only, number of key/values per window. Similar to n_win, the actual number is kv_per_win*kv_per_win.
    topk: topk for window filtering
    param_attention: 'qkvo'-linear for q,k,v and o, 'none': param free attention
    param_routing: extra linear for routing
    diff_routing: wether to set routing differentiable
    soft_routing: wether to multiply soft routing weights 
    """
    def __init__(self, dim, num_heads=8, n_win=7, qk_dim=None, qk_scale=None,
                 kv_per_win=4, kv_downsample_ratio=4, kv_downsample_kernel=None, kv_downsample_mode='identity',
                 topk=4, param_attention="qkvo", param_routing=False, diff_routing=False, soft_routing=False, side_dwconv=3,
                 auto_pad=False):
        super().__init__()
        # local attention setting
        self.dim = dim
        self.n_win = n_win  # Wh, Ww
        self.num_heads = num_heads
        self.qk_dim = qk_dim or dim
        assert self.qk_dim % num_heads == 0 and self.dim % num_heads==0, 'qk_dim and dim must be divisible by num_heads!'
        self.scale = qk_scale or self.qk_dim ** -0.5


        ################side_dwconv (i.e. LCE in ShuntedTransformer)###########
        self.lepe = nn.Conv2d(dim, dim, kernel_size=side_dwconv, stride=1, padding=side_dwconv//2, groups=dim) if side_dwconv > 0 else \
                    lambda x: torch.zeros_like(x)

        ################ global routing setting #################
        self.topk = topk
        self.param_routing = param_routing
        self.diff_routing = diff_routing
        self.soft_routing = soft_routing
        # router
        assert not (self.param_routing and not self.diff_routing) # cannot be with_param=True and diff_routing=False
        self.router = TopkRouting(qk_dim=self.qk_dim,
                                  qk_scale=self.scale,
                                  topk=self.topk,
                                  diff_routing=self.diff_routing,
                                  param_routing=self.param_routing)
        if self.soft_routing: # soft routing, always diffrentiable (if no detach)
            mul_weight = 'soft'
        elif self.diff_routing: # hard differentiable routing
            mul_weight = 'hard'
        else:  # hard non-differentiable routing
            mul_weight = 'none'
        self.kv_gather = KVGather(mul_weight=mul_weight)

        # qkv mapping (shared by both global routing and local attention)
        self.param_attention = param_attention
        if self.param_attention == 'qkvo':
            self.qkv = QKVLinear(self.dim, self.qk_dim)
            self.wo = nn.Linear(dim, dim)
        elif self.param_attention == 'qkv':
            self.qkv = QKVLinear(self.dim, self.qk_dim)
            self.wo = nn.Identity()
        else:
            raise ValueError(f'param_attention mode {self.param_attention} is not surpported!')

        self.kv_downsample_mode = kv_downsample_mode
        self.kv_per_win = kv_per_win
        self.kv_downsample_ratio = kv_downsample_ratio
        self.kv_downsample_kenel = kv_downsample_kernel
        if self.kv_downsample_mode == 'ada_avgpool':
            assert self.kv_per_win is not None
            self.kv_down = nn.AdaptiveAvgPool2d(self.kv_per_win)
        elif self.kv_downsample_mode == 'ada_maxpool':
            assert self.kv_per_win is not None
            self.kv_down = nn.AdaptiveMaxPool2d(self.kv_per_win)
        elif self.kv_downsample_mode == 'maxpool':
            assert self.kv_downsample_ratio is not None
            self.kv_down = nn.MaxPool2d(self.kv_downsample_ratio) if self.kv_downsample_ratio > 1 else nn.Identity()
        elif self.kv_downsample_mode == 'avgpool':
            assert self.kv_downsample_ratio is not None
            self.kv_down = nn.AvgPool2d(self.kv_downsample_ratio) if self.kv_downsample_ratio > 1 else nn.Identity()
        elif self.kv_downsample_mode == 'identity': # no kv downsampling
            self.kv_down = nn.Identity()
        elif self.kv_downsample_mode == 'fracpool':
            # assert self.kv_downsample_ratio is not None
            # assert self.kv_downsample_kenel is not None
            # TODO: fracpool
            # 1. kernel size should be input size dependent
            # 2. there is a random factor, need to avoid independent sampling for k and v 
            raise NotImplementedError('fracpool policy is not implemented yet!')
        elif kv_downsample_mode == 'conv':
            # TODO: need to consider the case where k != v so that need two downsample modules
            raise NotImplementedError('conv policy is not implemented yet!')
        else:
            raise ValueError(f'kv_down_sample_mode {self.kv_downsaple_mode} is not surpported!')

        # softmax for local attention
        self.attn_act = nn.Softmax(dim=-1)

        self.auto_pad=auto_pad

task与yaml配置

详见:https://blog.csdn.net/shangyanaf/article/details/139307690

相关文章
|
6月前
|
机器学习/深度学习 网络架构 计算机视觉
YOLOv5改进有效涨点系列->适合多种检测场景的BiFormer注意力机制(Bi-level Routing Attention)
YOLOv5改进有效涨点系列->适合多种检测场景的BiFormer注意力机制(Bi-level Routing Attention)
325 0
|
6月前
|
机器学习/深度学习 网络架构 计算机视觉
YOLOv8改进有效涨点系列->适合多种检测场景的BiFormer注意力机制(Bi-level Routing Attention)
YOLOv8改进有效涨点系列->适合多种检测场景的BiFormer注意力机制(Bi-level Routing Attention)
368 0
|
监控 安全 BI
Bi质押系统智能合约开发逻辑规则及代码示例
Bi质押系统智能合约开发逻辑规则及代码示例
|
机器学习/深度学习 存储 自然语言处理
Bi-SimCut: A Simple Strategy for Boosting Neural Machine Translation 论文笔记
Bi-SimCut: A Simple Strategy for Boosting Neural Machine Translation 论文笔记
|
机器学习/深度学习 传感器 算法
【GRU时序预测】基于双向门控循环单元Bi-GRU实现质量预测附matlab代码
【GRU时序预测】基于双向门控循环单元Bi-GRU实现质量预测附matlab代码
|
BI
《BI项目笔记》创建时间维度(2)
原文:《BI项目笔记》创建时间维度(2) 创建步骤:   序号 选择的属性 重命名后的名称 属性类别 1 Date...
803 0
|
SQL 存储 BI
《BI项目笔记》数据源视图设置
原文:《BI项目笔记》数据源视图设置 目的数据源视图是物理源数据库和分析维度与多维数据集之间的逻辑数据模型。在创建数据源视图时,需要在源数据库中指定包含创建维度和多维数据集所需要的数据表格和视图。BIDS与数据库连接,读取表格和视图定义,并在数据源视图中存储元数据。
1154 0
|
BI Go 存储
《BI项目笔记》创建时间维度(1)
原文:《BI项目笔记》创建时间维度(1) SSAS Date 维度基本上在所有的 Cube 设计过程中都存在,很难见到没有时间维度的 OLAP 数据库。但是根据不同的项目需求, Date 维度的设计可能不大相同,所以在设计时间维度的时候需要搞清楚几个问题: 你的业务涉及到的最低的细节级别是什么?比如按季度查看报表还是按月份,或者按周,或者再甚者按天。
883 0
|
BI 数据处理
《BI项目笔记》基于雪花模型的维度设计
原文:《BI项目笔记》基于雪花模型的维度设计 GBGradeCode 外键关系: 1 烟叶等级 T_GBGradeCode.
812 0
|
BI 数据库
《BI项目笔记》创建父子维度
原文:《BI项目笔记》创建父子维度 创建步骤: 而ParentOriginID其实就是对应的ParentOriginID,它的 Usage 必须是 Parent 才能表示这样的一个父子维度。 查看OriginID属性, Usage 是 Key。
1099 0

热门文章

最新文章