【YOLOv8改进】HAT(Hybrid Attention Transformer,)混合注意力机制 (论文笔记+引入代码)

简介: YOLO目标检测专栏介绍了YOLO系列的改进方法和实战应用,包括卷积、主干网络、注意力机制和检测头的创新。提出的Hybrid Attention Transformer (HAT)结合通道注意力和窗口自注意力,激活更多像素以提升图像超分辨率效果。通过交叉窗口信息聚合和同任务预训练策略,HAT优化了Transformer在低级视觉任务中的性能。实验显示,HAT在图像超分辨率任务上显著优于现有方法。模型结构包含浅层和深层特征提取以及图像重建阶段。此外,提供了HAT模型的PyTorch实现代码。更多详细配置和任务说明可参考相关链接。

YOLO目标检测创新改进与实战案例专栏

专栏目录: YOLO有效改进系列及项目实战目录 包含卷积,主干 注意力,检测头等创新机制 以及 各种目标检测分割项目实战案例

专栏链接: YOLO基础解析+创新改进+实战案例

摘要

基于Transformer的方法在低级视觉任务中表现出色,例如图像超分辨率。然而,通过归因分析,我们发现这些网络只能利用输入信息的有限空间范围。这表明Transformer在现有网络中的潜力尚未完全发挥。为了激活更多的输入像素以获得更好的重建效果,我们提出了一种新颖的混合注意力Transformer(Hybrid Attention Transformer, HAT)。它结合了通道注意力和基于窗口的自注意力机制,从而利用了它们能够利用全局统计信息和强大的局部拟合能力的互补优势。此外,为了更好地聚合跨窗口信息,我们引入了一个重叠交叉注意模块,以增强相邻窗口特征之间的交互。在训练阶段,我们还采用了同任务预训练策略,以进一步挖掘模型的潜力。大量实验表明了所提模块的有效性,我们进一步扩大了模型规模,证明了该任务的性能可以大幅提高。我们的方法整体上显著优于最先进的方法,超过了1dB。

创新点

  1. 更多像素的激活:通过结合不同的注意力机制,HAT能够激活更多的输入像素,这在图像超分辨率领域尤为重要,因为它直接关系到重建图像的细节和质量。

  2. 交叉窗口信息的有效聚合:通过重叠交叉注意力模块,HAT模型能够更有效地聚合跨窗口的信息,避免了传统Transformer模型中窗口间信息隔离的问题。

  3. 针对图像超分辨率优化的预训练策略:HAT采用的同任务预训练策略针对性强,能够更有效地利用大规模数据预训练的优势,提高模型在特定超分辨率任务上的表现。

    HAT模型整体架构:

  4. 浅层特征提取(Shallow Feature Extraction)
    输入图像首先通过一个浅层特征提取模块,该模块使用卷积操作提取初始的低层次特征。

  5. 深层特征提取(Deep Feature Extraction)
    提取的特征输入到多个Residual Hybrid Attention Groups (RHAG)中进行深层特征提取。每个RHAG由多个Hybrid Attention Blocks (HAB)和一个Overlapping Cross-Attention Block (OCAB)组成。

  6. 图像重建(Image Reconstruction)
    最后,经过深层特征提取的特征通过一个图像重建模块,将高层次特征转化为输出的超分辨率图像。

yolov8 引入

class HAT(nn.Module):
    r"""混合注意力变换器 (Hybrid Attention Transformer)
        该PyTorch实现基于 `Activating More Pixels in Image Super-Resolution Transformer`。
        部分代码基于SwinIR。
    参数:
        img_size (int | tuple(int)): 输入图像大小。默认值64
        patch_size (int | tuple(int)): Patch大小。默认值1
        in_chans (int): 输入图像通道数。默认值3
        embed_dim (int): Patch嵌入维度。默认值96
        depths (tuple(int)): 每个Swin Transformer层的深度。
        num_heads (tuple(int)): 不同层的注意力头数量。
        window_size (int): 窗口大小。默认值7
        mlp_ratio (float): MLP隐藏层维度与嵌入维度的比例。默认值4
        qkv_bias (bool): 如果为True,为查询、键和值添加可学习的偏差。默认值True
        qk_scale (float): 如果设置,覆盖head_dim ** -0.5的默认qk缩放比例。默认值None
        drop_rate (float): Dropout率。默认值0
        attn_drop_rate (float): 注意力dropout率。默认值0
        drop_path_rate (float): 随机深度率。默认值0.1
        norm_layer (nn.Module): 规范化层。默认值nn.LayerNorm。
        ape (bool): 如果为True,为Patch嵌入添加绝对位置嵌入。默认值False
        patch_norm (bool): 如果为True,在Patch嵌入后添加规范化。默认值True
        use_checkpoint (bool): 是否使用checkpointing来节省内存。默认值False
        upscale: 上采样因子。用于图像SR的2/3/4/8,1用于降噪和压缩伪影减少
        img_range: 图像范围。1或255。
        upsampler: 重建模块。'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
        resi_connection: 残差连接前的卷积块。'1conv'/'3conv'
    """

    def __init__(self,
                 in_chans=3,
                 img_size=64,
                 patch_size=1,
                 embed_dim=96,
                 depths=(6, 6, 6, 6),
                 num_heads=(6, 6, 6, 6),
                 window_size=7,
                 compress_ratio=3,
                 squeeze_factor=30,
                 conv_scale=0.01,
                 overlap_ratio=0.5,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm,
                 ape=False,
                 patch_norm=True,
                 use_checkpoint=False,
                 upscale=2,
                 img_range=1.,
                 upsampler='',
                 resi_connection='1conv',
                 **kwargs):
        super(HAT, self).__init__()

        self.window_size = window_size
        self.shift_size = window_size // 2
        self.overlap_ratio = overlap_ratio

        num_in_ch = in_chans
        num_out_ch = in_chans
        num_feat = 64
        self.img_range = img_range
        if in_chans == 3:
            rgb_mean = (0.4488, 0.4371, 0.4040)
            self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
        else:
            self.mean = torch.zeros(1, 1, 1, 1)
        self.upscale = upscale
        self.upsampler = upsampler

        # 相对位置索引
        relative_position_index_SA = self.calculate_rpi_sa()
        relative_position_index_OCA = self.calculate_rpi_oca()
        self.register_buffer('relative_position_index_SA', relative_position_index_SA)
        self.register_buffer('relative_position_index_OCA', relative_position_index_OCA)

        # ------------------------- 1,浅层特征提取 ------------------------- #
        self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)

        # ------------------------- 2,深层特征提取 ------------------------- #
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.num_features = embed_dim
        self.mlp_ratio = mlp_ratio

        # 将图像分割为非重叠patch
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=embed_dim,
            embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # 将非重叠的patch合并为图像
        self.patch_unembed = PatchUnEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=embed_dim,
            embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)

        # 绝对位置嵌入
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # 随机深度
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # 随机深度衰减规则

        # 构建残差混合注意力组 (RHAG)
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = RHAG(
                dim=embed_dim,
                input_resolution=(patches_resolution[0], patches_resolution[1]),
                depth=depths[i_layer],
                num_heads=num_heads[i_layer],
                window_size=window_size,
                compress_ratio=compress_ratio,
                squeeze_factor=squeeze_factor,
                conv_scale=conv_scale,
                overlap_ratio=overlap_ratio,
                mlp_ratio=self.mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],  # 对SR结果无影响
                norm_layer=norm_layer,
                downsample=None,
                use_checkpoint=use_checkpoint,
                img_size=img_size,
                patch_size=patch_size,
                resi_connection=resi_connection)
            self.layers.append(layer)
        self.norm = norm_layer(self.num_features)

        # 构建深层特征提取中的最后一个卷积层
        if resi_connection == '1conv':
            self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
        elif resi_connection == 'identity':
            self.conv_after_body = nn.Identity()

        # ------------------------- 3,高质量图像重建 ------------------------- #
        if self.upsampler == 'pixelshuffle':
            # 用于经典SR
            self.conv_before_upsample = nn.Sequential(
                nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True))
            self.upsample = Upsample(upscale, num_feat)
            self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)

        self.apply(self._init_weights)

task与yaml配置

详见:https://blog.csdn.net/shangyanaf/article/details/139142532

相关文章
|
29天前
|
机器学习/深度学习 算法 计算机视觉
【YOLOv8改进】CPCA(Channel prior convolutional attention)中的通道注意力,增强特征表征能力 (论文笔记+引入代码)
该专栏聚焦YOLO目标检测的创新改进与实战,介绍了一种针对医学图像分割的通道优先卷积注意力(CPCA)方法。CPCA结合通道和空间注意力,通过多尺度深度卷积提升性能。提出的CPCANet网络在有限计算资源下,于多个数据集上展现优越分割效果。代码已开源。了解更多详情,请访问提供的专栏链接。
|
29天前
|
机器学习/深度学习 编解码 算法
【YOLOv8改进】Polarized Self-Attention: 极化自注意力 (论文笔记+引入代码)
该专栏专注于YOLO目标检测算法的创新改进和实战应用,包括卷积、主干网络、注意力机制和检测头的改进。作者提出了一种名为极化自注意(PSA)块,结合极化过滤和增强功能,提高像素级回归任务的性能,如关键点估计和分割。PSA通过保持高分辨率和利用通道及空间注意力,减少了信息损失并适应非线性输出分布。实验证明,PSA能提升标准基线和最新技术1-4个百分点。代码示例展示了如何在YOLOv8中实现PSA模块。更多详细信息和配置可在提供的链接中找到。
|
29天前
|
机器学习/深度学习 计算机视觉 知识图谱
【YOLOv8改进】ACmix(Mixed Self-Attention and Convolution) (论文笔记+引入代码)
YOLO目标检测专栏探讨了YOLO的改进,包括卷积和自注意力机制的创新结合。研究发现两者在计算上存在关联,卷积可分解为1×1卷积,自注意力也可视为1×1卷积的变形。由此提出ACmix模型,它整合两种范式,降低计算开销,同时提升图像识别和下游任务的性能。ACmix优化了移位操作,采用模块化设计,实现两种技术优势的高效融合。代码和预训练模型可在相关GitHub和MindSpore模型库找到。 yolov8中引入了ACmix模块,详细配置参见指定链接。
|
29天前
|
机器学习/深度学习 编解码 计算机视觉
【YOLOv8改进】D-LKA Attention:可变形大核注意力 (论文笔记+引入代码)
YOLO目标检测专栏探讨了Transformer在医学图像分割的进展,但计算需求限制了模型的深度和分辨率。为此,提出了可变形大核注意力(D-LKA Attention),它使用大卷积核捕捉上下文信息,通过可变形卷积适应数据模式变化。D-LKA Net结合2D和3D版本的D-LKA Attention,提升了医学分割性能。YOLOv8引入了可变形卷积层以增强目标检测的准确性。相关代码和任务配置可在作者博客找到。
|
29天前
|
机器学习/深度学习 计算机视觉
【YOLOv8改进】EMA(Efficient Multi-Scale Attention):基于跨空间学习的高效多尺度注意力 (论文笔记+引入代码)
YOLO目标检测专栏介绍了创新的多尺度注意力模块EMA,它强化通道和空间信息处理,同时降低计算负担。EMA模块通过通道重塑和并行子网络优化特征表示,增强长距离依赖建模,在保持效率的同时提升模型性能。适用于图像分类和目标检测任务,尤其在YOLOv8中表现出色。代码实现和详细配置可在文中链接找到。
|
29天前
|
机器学习/深度学习 测试技术 计算机视觉
【YOLOv8改进】DAT(Deformable Attention):可变性注意力 (论文笔记+引入代码)
YOLO目标检测创新改进与实战案例专栏探讨了YOLO的有效改进,包括卷积、主干、注意力和检测头等机制的创新,以及目标检测分割项目的实践。专栏介绍了Deformable Attention Transformer,它解决了Transformer全局感受野带来的问题,通过数据依赖的位置选择、灵活的偏移学习和全局键共享,聚焦相关区域并捕获更多特征。模型在多个基准测试中表现优秀,代码可在GitHub获取。此外,文章还展示了如何在YOLOv8中应用Deformable Attention。
|
30天前
|
机器学习/深度学习 测试技术 网络架构
【YOLOv8改进】MSCA: 多尺度卷积注意力 (论文笔记+引入代码).md
SegNeXt是提出的一种新的卷积网络架构,专注于语义分割任务,它证明了卷积注意力在编码上下文信息上优于自注意力机制。该模型通过结合深度卷积、多分支深度卷积和1x1逐点卷积实现高效性能提升。在多个基准测试中,SegNeXt超越了现有最佳方法,如在Pascal VOC 2012上达到90.6%的mIoU,参数量仅为EfficientNet-L2 w/ NAS-FPN的1/10。此外,它在ADE20K数据集上的mIoU平均提高了2.0%,同时保持相同的计算量。YOLOv8中引入了名为MSCAAttention的模块,以利用这种多尺度卷积注意力机制。更多详情和配置可参考相关链接。
|
6月前
|
机器学习/深度学习 人工智能 关系型数据库
简化版Transformer :Simplifying Transformer Block论文详解
在这篇文章中我将深入探讨来自苏黎世联邦理工学院计算机科学系的Bobby He和Thomas Hofmann在他们的论文“Simplifying Transformer Blocks”中介绍的Transformer技术的进化步骤。这是自Transformer 开始以来,我看到的最好的改进。
71 0
|
9月前
|
机器学习/深度学习 人工智能 自然语言处理
解码注意力Attention机制:从技术解析到PyTorch实战
解码注意力Attention机制:从技术解析到PyTorch实战
246 0
|
机器学习/深度学习 编解码 自然语言处理
MoA-Transformer | Swin-Transformer应该如何更好地引入全局信息?(一)
MoA-Transformer | Swin-Transformer应该如何更好地引入全局信息?(一)
135 0