YOLOv8改进算法之添加CA注意力机制

简介: CA(Coordinate Attention)注意力机制是一种用于加强深度学习模型对输入数据的空间结构理解的注意力机制。CA 注意力机制的核心思想是引入坐标信息,以便模型可以更好地理解不同位置之间的关系

1. CA注意力机制

CA(Coordinate Attention)注意力机制是一种用于加强深度学习模型对输入数据的空间结构理解的注意力机制。CA 注意力机制的核心思想是引入坐标信息,以便模型可以更好地理解不同位置之间的关系。如下图:



1. 输入特征: CA 注意力机制的输入通常是一个特征图,它通常是卷积神经网络(CNN)中的某一层的输出,具有以下形状:[C, H, W],其中:


C 是通道数,表示特征图中的不同特征通道。

H 是高度,表示特征图的垂直维度。

W 是宽度,表示特征图的水平维度。


2. 全局平均池化: CA 注意力机制首先对输入特征图进行两次全局平均池化,一次在宽度方向上,一次在高度方向上。这两次操作分别得到两个特征映射:


在宽度方向上的平均池化得到特征映射 [C, H, 1]。

在高度方向上的平均池化得到特征映射 [C, 1, W]。

这两个特征映射分别捕捉了在宽度和高度方向上的全局特征。


3. 合并宽高特征: 将上述两个特征映射合并,通常通过简单的堆叠操作,得到一个新的特征层,形状为 [C, 1, H + W],其中 H + W 表示在宽度和高度两个方向上的维度合并在一起。


4. 卷积+标准化+激活函数: 对合并后的特征层进行卷积操作,通常是 1x1 卷积,以捕捉宽度和高度维度之间的关系。然后,通常会应用标准化(如批量标准化)和激活函数(如ReLU)来进一步处理特征,得到一个更加丰富的表示。


5. 再次分开: 分别从上述特征层中分离出宽度和高度方向的特征:


一个分支得到特征层 [C, 1, H]。

另一个分支得到特征层 [C, 1, W]。


6. 转置: 对分开的两个特征层进行转置操作,以恢复宽度和高度的维度,得到两个特征层分别为 [C, H, 1] 和 [C, 1, W]。


7. 通道调整和 Sigmoid: 对两个分开的特征层分别应用 1x1 卷积,以调整通道数,使其适应注意力计算。然后,应用 Sigmoid 激活函数,得到在宽度和高度维度上的注意力分数。这些分数用于指示不同位置的重要性。


8. 应用注意力: 将原始输入特征图与宽度和高度方向上的注意力分数相乘,得到 CA 注意力机制的输出。



2. YOLOv8添加CA注意力机制

加入注意力机制,在ultralytics包中的nn包的modules里添加CA注意力模块,我这里选择在conv.py文件中添加CA注意力机制。



CA注意力机制代码如下:

import torch
import torch.nn as nn
import torch.nn.functional as F
class h_sigmoid(nn.Module):
    def __init__(self, inplace=True):
        super(h_sigmoid, self).__init__()
        self.relu = nn.ReLU6(inplace=inplace)
    def forward(self, x):
        return self.relu(x + 3) / 6
class h_swish(nn.Module):
    def __init__(self, inplace=True):
        super(h_swish, self).__init__()
        self.sigmoid = h_sigmoid(inplace=inplace)
    def forward(self, x):
        return x * self.sigmoid(x)
class CoordAtt(nn.Module):
    def __init__(self, inp, reduction=32):
        super(CoordAtt, self).__init__()
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))
        mip = max(8, inp // reduction)
        self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(mip)
        self.act = h_swish()
        self.conv_h = nn.Conv2d(mip, inp, kernel_size=1, stride=1, padding=0)
        self.conv_w = nn.Conv2d(mip, inp, kernel_size=1, stride=1, padding=0)
    def forward(self, x):
        identity = x
        n, c, h, w = x.size()
        x_h = self.pool_h(x)
        x_w = self.pool_w(x).permute(0, 1, 3, 2)
        y = torch.cat([x_h, x_w], dim=2)
        y = self.conv1(y)
        y = self.bn1(y)
        y = self.act(y)
        x_h, x_w = torch.split(y, [h, w], dim=2)
        x_w = x_w.permute(0, 1, 3, 2)
        a_h = self.conv_h(x_h).sigmoid()
        a_w = self.conv_w(x_w).sigmoid()
        out = identity * a_w * a_h
        return out



CA注意力机制的注册和引用如下:


ultralytics/nn/modules/_init_.py文件中:



 ultralytics/nn/tasks.py文件夹中:



在tasks.py中的parse_model中添加如下代码:

elif m in {CoordAtt}:
            args=[ch[f],*args]


新建相应的yolov8s-CA.yaml文件,代码如下:


# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
  s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
  m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
  l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs
# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1,1,CoordAtt,[]]
  - [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1,1,CoordAtt,[]]
  - [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1,1,CoordAtt,[]]
  - [-1, 1, SPPF, [1024, 5]]  # 9
# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 8], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 12
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 5], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f, [256]]  # 15 (P3/8-small)
  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 15], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f, [512]]  # 18 (P4/16-medium)
  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 12], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f, [1024]]  # 21 (P5/32-large)
  - [[18, 21, 24], 1, Detect, [nc]]  # Detect(P3, P4, P5)


在main.py文件中进行训练:


if __name__ == '__main__':
    # 使用yaml配置文件来创建模型,并导入预训练权重.
    model = YOLO('ultralytics/cfg/models/v8/yolov8s-CA.yaml')
    # model.load('yolov8n.pt')
    model.train(**{'cfg': 'ultralytics/cfg/default.yaml', 'data': 'dataset/data.yaml'})


相关文章
|
机器学习/深度学习 监控 算法
yolov8+多算法多目标追踪+实例分割+目标检测+姿态估计(代码+教程)
yolov8+多算法多目标追踪+实例分割+目标检测+姿态估计(代码+教程)
|
机器学习/深度学习 数据采集 算法
解码癌症预测的密码:可解释性机器学习算法SHAP揭示XGBoost模型的预测机制
解码癌症预测的密码:可解释性机器学习算法SHAP揭示XGBoost模型的预测机制
1374 0
|
6月前
|
机器学习/深度学习 算法 数据可视化
从另一个视角看Transformer:注意力机制就是可微分的k-NN算法
注意力机制可理解为一种“软k-NN”:查询向量通过缩放点积计算与各键的相似度,softmax归一化为权重,对值向量加权平均。1/√d缩放防止高维饱和,掩码控制信息流动(如因果、填充)。不同相似度函数(点积、余弦、RBF)对应不同归纳偏置,多头则在多个子空间并行该过程。
453 6
|
10月前
|
存储 监控 算法
基于 C++ 哈希表算法实现局域网监控电脑屏幕的数据加速机制研究
企业网络安全与办公管理需求日益复杂的学术语境下,局域网监控电脑屏幕作为保障信息安全、规范员工操作的重要手段,已然成为网络安全领域的关键研究对象。其作用类似网络空间中的 “电子眼”,实时捕获每台电脑屏幕上的操作动态。然而,面对海量监控数据,实现高效数据存储与快速检索,已成为提升监控系统性能的核心挑战。本文聚焦于 C++ 语言中的哈希表算法,深入探究其如何成为局域网监控电脑屏幕数据处理的 “加速引擎”,并通过详尽的代码示例,展现其强大功能与应用价值。
212 2
|
12月前
|
存储 监控 算法
基于 PHP 二叉搜索树算法的内网行为管理机制探究
在当今数字化网络环境中,内网行为管理对于企业网络安全及高效运营具有至关重要的意义。它涵盖对企业内部网络中各类行为的监测、分析与管控。在内网行为管理技术体系里,算法与数据结构扮演着核心角色。本文将深入探究 PHP 语言中的二叉搜索树算法于内网行为管理中的应用。
167 4
|
11月前
|
存储 算法 物联网
解析局域网内控制电脑机制:基于 Go 语言链表算法的隐秘通信技术探究
数字化办公与物联网蓬勃发展的时代背景下,局域网内计算机控制已成为提升工作效率、达成设备协同管理的重要途径。无论是企业远程办公时的设备统一调度,还是智能家居系统中多设备间的联动控制,高效的数据传输与管理机制均构成实现局域网内计算机控制功能的核心要素。本文将深入探究 Go 语言中的链表数据结构,剖析其在局域网内计算机控制过程中,如何达成数据的有序存储与高效传输,并通过完整的 Go 语言代码示例展示其应用流程。
217 0
|
存储 监控 算法
公司监控上网软件架构:基于 C++ 链表算法的数据关联机制探讨
在数字化办公时代,公司监控上网软件成为企业管理网络资源和保障信息安全的关键工具。本文深入剖析C++中的链表数据结构及其在该软件中的应用。链表通过节点存储网络访问记录,具备高效插入、删除操作及节省内存的优势,助力企业实时追踪员工上网行为,提升运营效率并降低安全风险。示例代码展示了如何用C++实现链表记录上网行为,并模拟发送至服务器。链表为公司监控上网软件提供了灵活高效的数据管理方式,但实际开发还需考虑安全性、隐私保护等多方面因素。
251 0
公司监控上网软件架构:基于 C++ 链表算法的数据关联机制探讨
|
机器学习/深度学习 算法 计算机视觉
[YOLOv8/YOLOv7/YOLOv5系列算法改进NO.5]改进特征融合网络PANET为BIFPN(更新添加小目标检测层yaml)
本文介绍了改进YOLOv5以解决处理复杂背景时可能出现的错漏检问题。
641 5
|
机器学习/深度学习 人工智能 文字识别
一种基于YOLOv8改进的高精度红外小目标检测算法 (原创自研)
【7月更文挑战第2天】 💡💡💡创新点: 1)SPD-Conv特别是在处理低分辨率图像和小物体等更困难的任务时优势明显; 2)引入Wasserstein Distance Loss提升小目标检测能力; 3)YOLOv8中的Conv用cvpr2024中的DynamicConv代替;
1627 4
|
算法 安全
金石原创 |【分布式技术专题】「分布式技术架构」一文带你厘清分布式事务协议及分布式一致性协议的算法原理和核心流程机制(Paxos篇)
金石原创 |【分布式技术专题】「分布式技术架构」一文带你厘清分布式事务协议及分布式一致性协议的算法原理和核心流程机制(Paxos篇)
789 1
金石原创 |【分布式技术专题】「分布式技术架构」一文带你厘清分布式事务协议及分布式一致性协议的算法原理和核心流程机制(Paxos篇)

热门文章

最新文章