摘要
得益于在通道或空间位置之间构建相互依赖关系的能力,注意力机制在最近被广泛研究并广泛应用于各种计算机视觉任务中。在本文中,我们研究了轻量但有效的注意力机制,并提出了三重注意力,这是一种通过使用三分支结构捕获跨维度交互来计算注意力权重的新方法。对于输入张量,三重注意力通过旋转操作及后续的残差变换构建维度间依赖关系,并以可忽略的计算开销编码通道间和空间信息。我们的方法简单且高效,可以作为附加模块轻松插入经典骨干网络中。我们在各种具有挑战性的任务中证明了我们方法的有效性,包括 ImageNet-1k 上的图像分类以及 MSCOCO 和 PASCAL VOC 数据集上的目标检测。此外,我们通过可视化检查 GradCAM 和 GradCAM++ 结果,提供了对三重注意力性能的广泛见解。我们方法的实证评估支持了在计算注意力权重时捕捉跨维度依赖关系的重要性。本文的代码可在 https://github.com/LandskapeAI/triplet-attention 公开获取。
文章链接
论文地址:论文地址
代码地址:代码地址
基本原理
给定一个输入张量
,首先将其传递到Triplet Attention模块中的三个分支中。\ 在第1个分支中,在H维度和C维度之间建立了交互:
为了实现这一点,输入张量 \chi 沿H轴逆时针旋转90°。这个旋转张量 \hat{\chi }{1} 表示为的形状为 (W×H×C) ,再然后经过Z-Pool后的张量 \hat{\chi }{1}^{
} 的shape为 (2×H×C) ,然后,通过内核大小为 k×k 的标准卷积层,再通过批处理归一化层,提供维数 (1×H×C) 的中间输出。然后,通过将张量通过sigmoid来生成的注意力权值。在最后输出是沿着H轴进行顺时针旋转90°保持和输入的shape一致。\ 在第2个分支中,在C维度和W维度之间建立了交互:
为了实现这一点,输入张量 \chi 沿W轴逆时针旋转90°。这个旋转张量 \hat{\chi }{2} 表示为的形状为 (H×C×W) ,再然后经过Z-Pool后的张量 \hat{\chi }{2}^{
} 的shape为 (2×C×W ) ,然后,通过内核大小为 k×k 的标准卷积层,再通过批处理归一化层,提供维数 (1×C×W) 的中间输出。然后,通过将张量通过sigmoid来生成的注意力权值。在最后输出是沿着W轴进行顺时针旋转90°保持和输入的shape一致。\ 在第3个分支中,在H维度和W维度之间建立了交互:
输入张量
的通道通过Z-pool将变量简化为2。将这个形状的简化张量 (2×H×W) 简化后通过核大小 k×k 定义的标准卷积层,然后通过批处理归一化层。输出通过sigmoid激活层生成形状为(1×H×W)的注意权值,并将其应用于输入
,得到结果
。然后通过简单的平均将3个分支产生的精细张量 (C×H×W) 聚合在一起。 最终输出的Tensor:
核心代码
import torch
import torch.nn as nn
class BasicConv(nn.Module):
def __init__(
self,
in_planes,
out_planes,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
relu=True,
bn=True,
bias=False,
):
super(BasicConv, self).__init__()
self.out_channels = out_planes
self.conv = nn.Conv2d(
in_planes,
out_planes,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias,
)
self.bn = (
nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True)
if bn
else None
)
self.relu = nn.ReLU() if relu else None
def forward(self, x):
x = self.conv(x)
if self.bn is not None:
x = self.bn(x)
if self.relu is not None:
x = self.relu(x)
return x
class ChannelPool(nn.Module):
def forward(self, x):
return torch.cat(
(torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1
)
class SpatialGate(nn.Module):
def __init__(self):
super(SpatialGate, self).__init__()
kernel_size = 7
self.compress = ChannelPool()
self.spatial = BasicConv(
2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2, relu=False
)
def forward(self, x):
x_compress = self.compress(x)
x_out = self.spatial(x_compress)
scale = torch.sigmoid_(x_out)
return x * scale
class TripletAttention(nn.Module):
def __init__(
self,
gate_channels,
reduction_ratio=16,
pool_types=["avg", "max"],
no_spatial=False,
):
super(TripletAttention, self).__init__()
self.ChannelGateH = SpatialGate()
self.ChannelGateW = SpatialGate()
self.no_spatial = no_spatial
if not no_spatial:
self.SpatialGate = SpatialGate()
def forward(self, x):
x_perm1 = x.permute(0, 2, 1, 3).contiguous()
x_out1 = self.ChannelGateH(x_perm1)
x_out11 = x_out1.permute(0, 2, 1, 3).contiguous()
x_perm2 = x.permute(0, 3, 2, 1).contiguous()
x_out2 = self.ChannelGateW(x_perm2)
x_out21 = x_out2.permute(0, 3, 2, 1).contiguous()
if not self.no_spatial:
x_out = self.SpatialGate(x)
x_out = (1 / 3) * (x_out + x_out11 + x_out21)
else:
x_out = (1 / 2) * (x_out11 + x_out21)
return x_out
task与yaml配置
详见:https://blog.csdn.net/shangyanaf/article/details/139999693