【PyTorch】按照论文思想实现通道和空间两种注意力机制

简介: 【PyTorch】按照论文思想实现通道和空间两种注意力机制
import torch
from torch import nn
class ChannelAttention(nn.Module):
    # ratio表示MLP中,中间层in_planes缩小的比例
    def __init__(self, in_plances, ratio=16) -> None:
        super().__init__()
        self.max_pool = nn.AdaptiveMaxPool2d((1,1))
        self.avg_pool = nn.AdaptiveAvgPool2d((1,1))
        '''
        (1) in_plances / ratio, 其结果为小数,导致模型报错;
        (2) in_plances // ratio, 向下取整;
        (3) Conv2d中bias为False主要是为了模拟MLP多层感知机的功能;
        '''
        self.mlp = nn.Sequential( # 此处没有中括号
            nn.Conv2d(in_plances, in_plances // ratio, 1, bias=False), # 此处为什么卷积不需要偏置,是为了模拟FC
            nn.ReLU(),
            nn.Conv2d(in_plances // ratio, in_plances, 1, bias=False) # python 中/与//的区别
        )
        self.sigmoid = nn.Sigmoid()
    def forward(self, x):
        x1 = self.max_pool(x)
        x1 = self.mlp(x1)
        x2 = self.avg_pool(x)
        x2 = self.mlp(x2)
        # 此处直接相加,而不是拼接
        # torch.cat(x1, x2)
        out = x1 + x2
        out = self.sigmoid(out)
        return out
class SpatialAttention(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv2d = nn.Conv2d(2, 1, 7, padding=3, bias=False)
        self.sigmoid = nn.Sigmoid()        
    def forward(self, x):
        '''
        注意,此处并不是简单的最大值和均值池化操作,而是cross channel的;
        '''
        avg_pool = torch.mean(x, dim=1, keepdim=True) # Bx1xHxW
        max_pool, _ = torch.max(x, dim=1, keepdim=True) # Bx1xHxW, 此处非常容易出错,少_
        # Bx2xHxW
        out = torch.cat([avg_pool, max_pool], dim=1)
        out = self.conv2d(out)
        out = self.sigmoid(out)
        return out
if __name__ == '__main__':
    from torchinfo import summary
    # import hiddenlayer as h
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    print('Channel Attention')
    layer = ChannelAttention(32).to(device)
    summary(layer, (1, 32, 224, 224))
    print('Spatial Attention')
    layer = SpatialAttention().to(device)
    summary(layer, (1, 32, 224, 224))
    # graph = h.build_graph(layer, torch.zeros([1, 32, 224, 224]))
    # graph.theme = h.graph.THEMES['blue'].copy()
    # graph.save('test.png')
    print('done!')
目录
相关文章
|
2月前
|
机器学习/深度学习 数据可视化 PyTorch
PyTorch FlexAttention技术实践:基于BlockMask实现因果注意力与变长序列处理
本文介绍了如何使用PyTorch 2.5及以上版本中的FlexAttention和BlockMask功能,实现因果注意力机制与填充输入的处理。通过attention-gym仓库安装相关工具,并详细展示了MultiheadFlexAttention类的实现,包括前向传播函数、因果掩码和填充掩码的生成方法。实验设置部分演示了如何组合这两种掩码并应用于多头注意力模块,最终通过可视化工具验证了实现的正确性。该方法适用于处理变长序列和屏蔽未来信息的任务。
116 17
|
5月前
|
机器学习/深度学习 PyTorch 算法框架/工具
聊一聊计算机视觉中常用的注意力机制以及Pytorch代码实现
本文介绍了几种常用的计算机视觉注意力机制及其PyTorch实现,包括SENet、CBAM、BAM、ECA-Net、SA-Net、Polarized Self-Attention、Spatial Group-wise Enhance和Coordinate Attention等,每种方法都附有详细的网络结构说明和实验结果分析。通过这些注意力机制的应用,可以有效提升模型在目标检测任务上的性能。此外,作者还提供了实验数据集的基本情况及baseline模型的选择与实验结果,方便读者理解和复现。
329 0
聊一聊计算机视觉中常用的注意力机制以及Pytorch代码实现
|
5月前
|
机器学习/深度学习 自然语言处理 数据建模
三种Transformer模型中的注意力机制介绍及Pytorch实现:从自注意力到因果自注意力
本文深入探讨了Transformer模型中的三种关键注意力机制:自注意力、交叉注意力和因果自注意力,这些机制是GPT-4、Llama等大型语言模型的核心。文章不仅讲解了理论概念,还通过Python和PyTorch从零开始实现这些机制,帮助读者深入理解其内部工作原理。自注意力机制通过整合上下文信息增强了输入嵌入,多头注意力则通过多个并行的注意力头捕捉不同类型的依赖关系。交叉注意力则允许模型在两个不同输入序列间传递信息,适用于机器翻译和图像描述等任务。因果自注意力确保模型在生成文本时仅考虑先前的上下文,适用于解码器风格的模型。通过本文的详细解析和代码实现,读者可以全面掌握这些机制的应用潜力。
346 3
三种Transformer模型中的注意力机制介绍及Pytorch实现:从自注意力到因果自注意力
|
5月前
|
机器学习/深度学习 存储 数据可视化
以pytorch的forward hook为例探究hook机制
【10月更文挑战第10天】PyTorch 的 Hook 机制允许用户在不修改模型代码的情况下介入前向和反向传播过程,适用于模型可视化、特征提取及梯度分析等任务。通过注册 `forward hook`,可以在模型前向传播过程中插入自定义操作,如记录中间层输出。使用时需注意输入输出格式及计算资源占用。
113 1
|
5月前
|
机器学习/深度学习 监控 PyTorch
以pytorch的forward hook为例探究hook机制
【10月更文挑战第9天】PyTorch中的Hook机制允许在不修改模型代码的情况下,获取和修改模型中间层的信息,如输入和输出等,适用于模型可视化、特征提取及梯度计算。Forward Hook在前向传播后触发,可用于特征提取和模型监控。实现上,需定义接收模块、输入和输出参数的Hook函数,并将其注册到目标层。与Backward Hook相比,前者关注前向传播,后者侧重反向传播和梯度处理,两者共同增强了对模型内部运行情况的理解和控制。
100 3
|
6月前
|
机器学习/深度学习 PyTorch 算法框架/工具
CNN中的注意力机制综合指南:从理论到Pytorch代码实现
注意力机制已成为深度学习模型的关键组件,尤其在卷积神经网络(CNN)中发挥了重要作用。通过使模型关注输入数据中最相关的部分,注意力机制显著提升了CNN在图像分类、目标检测和语义分割等任务中的表现。本文将详细介绍CNN中的注意力机制,包括其基本概念、不同类型(如通道注意力、空间注意力和混合注意力)以及实际实现方法。此外,还将探讨注意力机制在多个计算机视觉任务中的应用效果及其面临的挑战。无论是图像分类还是医学图像分析,注意力机制都能显著提升模型性能,并在不断发展的深度学习领域中扮演重要角色。
229 10
|
8月前
|
机器学习/深度学习 数据采集 自然语言处理
注意力机制中三种掩码技术详解和Pytorch实现
**注意力机制中的掩码在深度学习中至关重要,如Transformer模型所用。掩码类型包括:填充掩码(忽略填充数据)、序列掩码(控制信息流)和前瞻掩码(自回归模型防止窥视未来信息)。通过创建不同掩码,如上三角矩阵,模型能正确处理变长序列并保持序列依赖性。在注意力计算中,掩码修改得分,确保模型学习的有效性。这些技术在现代NLP和序列任务中是核心组件。**
436 12
|
10月前
|
机器学习/深度学习 关系型数据库 MySQL
大模型中常用的注意力机制GQA详解以及Pytorch代码实现
GQA是一种结合MQA和MHA优点的注意力机制,旨在保持MQA的速度并提供MHA的精度。它将查询头分成组,每组共享键和值。通过Pytorch和einops库,可以简洁实现这一概念。GQA在保持高效性的同时接近MHA的性能,是高负载系统优化的有力工具。相关论文和非官方Pytorch实现可进一步探究。
1058 4
|
10月前
|
存储 并行计算 Java
一文读懂 PyTorch 显存管理机制
一文读懂 PyTorch 显存管理机制
541 1
|
10月前
|
机器学习/深度学习 自然语言处理 PyTorch
Pytorch图像处理注意力机制SENet CBAM ECA模块解读
注意力机制最初是为了解决自然语言处理(NLP)任务中的问题而提出的,它使得模型能够在处理序列数据时动态地关注不同位置的信息。随后,注意力机制被引入到图像处理任务中,为深度学习模型提供了更加灵活和有效的信息提取能力。注意力机制的核心思想是根据输入数据的不同部分,动态地调整模型的注意力,从而更加关注对当前任务有用的信息。
518 0