计算机视觉 - Attention机制(附代码)

简介: 计算机视觉 - Attention机制(附代码)

1.Attention简介


Attention中文意思为注意力,这个机制放到计算机视觉里,类似于给我们看一张美女帅哥的图片,我们第一眼首先关注的地方是这个人的哪里呢😏

你们第一眼看的是哪里呢😏

最早attention机制就应用到计算机视觉中,这里说的机制,其实就是神经网络中一个模块,类似于U-Net加上attention机制的变化。


85d68f1184ab499295f7d82bea973b7b.png


3e365869a0114eb6849a9f815b5b9179.png


看出什么变化了吗,其实就是在原始的网络结构增加一些结构模块。

随着NLP领域的发展,也开始应用了atteniton机制,除了这个还有循环神经网络(RNNs)和门控循环单元(GRUs)﹑长短期记忆(LSTMs)﹑序列对序列(Seq2Seq)﹑记忆网络(Memory Networks)等。这些都是Encoder-Decoder的不同框架。


但是attention是可以脱离Encoder-Decoder,被其他模型框架使用的。


2.Attention原理


在平时最常用的淘宝,得物的照片识别,其实算法都是使用attention这个机制的。

Attention 原理的3步分解

b75a736969c94ae7a7c2d0f309cc7ebd.png


第一步: query 和 key 进行相矩阵相乘

第二步:将矩阵相乘得到的结果根据不同权重进行归一化

第三步:将结果和 value 再进行一次矩阵相乘


这里步骤中提到的query、key、value其实就是我们的feature maps分别跟1x1的卷积核卷积得到的三个向量。如下图所示。

8566f172803a47b0921045c3b09cf955.png

整体步骤,可以这样理解:


第一步我们先生成一个包含像素特征的图像value

第二步,我们生成出我们需要找的特征图像query,比如说得物,我们需要找到图像中的鞋的细节(鞋底,鞋带。。。)。

第三步,我们给图像中所有的特征都做一个编号key。

我们的方式就是通过query去查找到图像中key,提取我们需要的key,并与value结合,利用权重,得到实际我们想要查找的图像关键区域。


🤩说白了,在attention机制就是一种特征图的权重分布,把有用的特征权重加大,没有的特征权重加小,再用学出来的权重施加在原特征图之上最后进行加权求和。


3.Attention的不同类型


目前attention已经应用到计算机视觉,自然语言处理等多个领域,这些不同领域的应用,虽然attention的结构不变,但是其中的query、key、value的计算方式是不同的。计算区域也不同(一个卷积核乘积,不是所有的feature maps都做乘积)。

e399114be9ac4f00a948bbf103eabe39.png

前面attention原理,介绍的是attention的通用版本。这里我只提计算机视觉方面的attention,在计算机视觉中,主要有三种attention,分别为:


spatial attention:对于卷积神经网络,CNN每一层都会输出一个C x H x W的特征图,C就是通道,同时也代表卷积核的数量,亦为特征的数量,H 和W就是原始图片经过压缩后的图的高度和宽度,spatial attention就是对于所有的通道,在二维平面上,对H x W尺寸的特征图学习到一个权重,对每个像素都会学习到一个权重。你可以想象成一个像素是C维的一个向量,深度是C,在C个维度上,权重都是一样的,但是在平面上,权重不一样。

channel attention:就是对每个C(通道),在channel维度上,学习到不同的权重,平面维度上权重相同。所以基于通道域的注意力通常是对一个通道内的信息直接全局平均池化,而忽略每一个通道内的局部信息。SENet算法就是使用的channel attention。

spatial attention与channel attention融合:CBAM(Convolutional Block Attention Module)[5] 是其中的代表性网络,结构如下:


8985b2ec1a3e4f9da7177c54cf73a6ec.png

其中Channel Attention Module模块:

d9fd8d89497f472396e596013cdc69d0.png

同时使用最大 pooling 和均值 pooling 算法,然后经过几个 MLP 层获得变换结果,最后分别应用于两个通道,使用 sigmoid 函数得到通道的 attention 结果。

其中Spatial Attention Module模块:

image.png

首先将通道本身进行降维,分别获取最大池化和均值池化结果,然后拼接成一个特征图,再使用一个卷积层进行学习。

这两种机制,分别学习了通道的重要性和空间的重要性,还可以很容易地嵌入到任何已知的框架中。


4.CBAM实现(Pytorch)


CBAM模块详细:

4214c1cfdeb04f15b45342f3ddc52cde.png


其中Channel Attention模块:

其中Spatial Attention模块:

代码如下:


import torch 
import torch.nn as nn
import torchvision
#ratio 为通道数
class ChannelAttention(nn.Moudel):
  def __init__(self, channel, ratio=16):
        super(ChannelAttentionModule, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.shared_MLP = nn.Sequential(
            nn.Conv2d(channel, channel // ratio, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(channel // ratio, channel, 1, bias=False)
        )
        self.sigmoid = nn.Sigmoid()
    def forward(self, x):
        avgout = self.shared_MLP(self.avg_pool(x))
        print(avgout.shape)
        maxout = self.shared_MLP(self.max_pool(x))
        return self.sigmoid(avgout + maxout)
class SpatialAttentionModule(nn.Module):
    def __init__(self):
        super(SpatialAttentionModule, self).__init__()
        self.conv2d = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=7, stride=1, padding=3)
        self.sigmoid = nn.Sigmoid()
    def forward(self, x):
        avgout = torch.mean(x, dim=1, keepdim=True)
        maxout, _ = torch.max(x, dim=1, keepdim=True)
        out = torch.cat([avgout, maxout], dim=1)
        out = self.sigmoid(self.conv2d(out))
        return out
class CBAM(nn.Module):
    def __init__(self, channel):
        super(CBAM, self).__init__()
        self.channel_attention = ChannelAttentionModule(channel)
        self.spatial_attention = SpatialAttentionModule()
    def forward(self, x):
        out = self.channel_attention(x) * x
        print('outchannels:{}'.format(out.shape))
        out = self.spatial_attention(out) * out
        return out
相关文章
|
机器学习/深度学习 人工智能 算法
【AAAI 2024】再创佳绩!阿里云人工智能平台PAI多篇论文入选
阿里云人工智能平台PAI发表的多篇论文在AAAI-2024上正式亮相发表。AAAI是由国际人工智能促进协会主办的年会,是人工智能领域中历史最悠久、涵盖内容最广泛的国际顶级学术会议之一,也是中国计算机学会(CCF)推荐的A类国际学术会议。论文成果是阿里云与浙江大学、华南理工大学联合培养项目等共同研发,深耕以通用人工智能(AGI)为目标的一系列基础科学与工程问题,包括多模态理解模型、小样本类增量学习、深度表格学习和文档版面此次入选意味着阿里云人工智能平台PAI自研的深度学习算法达到了全球业界先进水平,获得了国际学者的认可,展现了阿里云人工智能技术创新在国际上的竞争力。
|
机器学习/深度学习 PyTorch TensorFlow
Pytorch学习笔记(二):nn.Conv2d()函数详解
这篇文章是关于PyTorch中nn.Conv2d函数的详解,包括其函数语法、参数解释、具体代码示例以及与其他维度卷积函数的区别。
2767 0
Pytorch学习笔记(二):nn.Conv2d()函数详解
|
机器学习/深度学习 监控 算法
基于计算机视觉(opencv)的运动计数(运动辅助)系统-源码+注释+报告
基于计算机视觉(opencv)的运动计数(运动辅助)系统-源码+注释+报告
380 3
|
8月前
|
开发框架 JavaScript 前端开发
鸿蒙开发:什么是ArkTs?
本小结主要简单介绍了ArkTs语言的相关知识,都是一些概念性质的内容,大家作为一个了解即可
483 61
|
Python
Python教程:@符号的用法
@ 符号在 Python 中最常见的使用情况是在装饰器中。一个装饰器可以让你改变一个函数或类的行为。 @ 符号也可以作为一个数学运算符使用,因为它可以在Python中乘以矩阵。本教程将教你如何使用 Python 的@ 符号。
1423 0
|
Docker Windows 容器
在Docker中的Neo4j导入CSV文件报错:Couldn‘t load the external resource at: file:/...解决办法
在Docker中的Neo4j导入CSV文件报错:Couldn‘t load the external resource at: file:/...解决办法
1013 0
在Docker中的Neo4j导入CSV文件报错:Couldn‘t load the external resource at: file:/...解决办法
|
机器学习/深度学习 自然语言处理 并行计算
【深度学习】Attention的原理、分类及实现
文章详细介绍了注意力机制(Attention)的原理、不同类型的分类以及如何在Keras中实现Attention。文章涵盖了Attention的基本概念、计算区域、所用信息、结构层次等方面,并提供了实现示例。
1926 0
|
机器学习/深度学习 运维 数据可视化
深度学习之热力图
热力图(Heatmap)在深度学习中是用于可视化数据、模型预测结果或特征的重要工具。它通过颜色的变化来表示数值的大小,便于直观地理解数据的分布、模型的关注区域以及特征的重要性。以下是深度学习中热力图的主要应用和特点。
997 0
|
人工智能 安全 API
发现一款提高工作效率的利器——ONLYOFFICE办公软件
发现一款提高工作效率的利器——ONLYOFFICE办公软件
814 1
|
SQL 人工智能 关系型数据库
【开源项目推荐】-支持GPT的智能数据库客户端与报表工具——Chat2DB
【开源项目推荐】-支持GPT的智能数据库客户端与报表工具——Chat2DB
885 3