YOLO目标检测创新改进与实战案例专栏
专栏目录: YOLO有效改进系列及项目实战目录 包含卷积,主干 注意力,检测头等创新机制 以及 各种目标检测分割项目实战案例
专栏链接: YOLO基础解析+创新改进+实战案例
摘要
基于Transformer的方法在低级视觉任务中表现出色,例如图像超分辨率。然而,通过归因分析,我们发现这些网络只能利用输入信息的有限空间范围。这表明Transformer在现有网络中的潜力尚未完全发挥。为了激活更多的输入像素以获得更好的重建效果,我们提出了一种新颖的混合注意力Transformer(Hybrid Attention Transformer, HAT)。它结合了通道注意力和基于窗口的自注意力机制,从而利用了它们能够利用全局统计信息和强大的局部拟合能力的互补优势。此外,为了更好地聚合跨窗口信息,我们引入了一个重叠交叉注意模块,以增强相邻窗口特征之间的交互。在训练阶段,我们还采用了同任务预训练策略,以进一步挖掘模型的潜力。大量实验表明了所提模块的有效性,我们进一步扩大了模型规模,证明了该任务的性能可以大幅提高。我们的方法整体上显著优于最先进的方法,超过了1dB。
创新点
更多像素的激活:通过结合不同的注意力机制,HAT能够激活更多的输入像素,这在图像超分辨率领域尤为重要,因为它直接关系到重建图像的细节和质量。
交叉窗口信息的有效聚合:通过重叠交叉注意力模块,HAT模型能够更有效地聚合跨窗口的信息,避免了传统Transformer模型中窗口间信息隔离的问题。
针对图像超分辨率优化的预训练策略:HAT采用的同任务预训练策略针对性强,能够更有效地利用大规模数据预训练的优势,提高模型在特定超分辨率任务上的表现。
HAT模型整体架构:
浅层特征提取(Shallow Feature Extraction):
输入图像首先通过一个浅层特征提取模块,该模块使用卷积操作提取初始的低层次特征。深层特征提取(Deep Feature Extraction):
提取的特征输入到多个Residual Hybrid Attention Groups (RHAG)中进行深层特征提取。每个RHAG由多个Hybrid Attention Blocks (HAB)和一个Overlapping Cross-Attention Block (OCAB)组成。图像重建(Image Reconstruction):
最后,经过深层特征提取的特征通过一个图像重建模块,将高层次特征转化为输出的超分辨率图像。
yolov8 引入
class HAT(nn.Module):
r"""混合注意力变换器 (Hybrid Attention Transformer)
该PyTorch实现基于 `Activating More Pixels in Image Super-Resolution Transformer`。
部分代码基于SwinIR。
参数:
img_size (int | tuple(int)): 输入图像大小。默认值64
patch_size (int | tuple(int)): Patch大小。默认值1
in_chans (int): 输入图像通道数。默认值3
embed_dim (int): Patch嵌入维度。默认值96
depths (tuple(int)): 每个Swin Transformer层的深度。
num_heads (tuple(int)): 不同层的注意力头数量。
window_size (int): 窗口大小。默认值7
mlp_ratio (float): MLP隐藏层维度与嵌入维度的比例。默认值4
qkv_bias (bool): 如果为True,为查询、键和值添加可学习的偏差。默认值True
qk_scale (float): 如果设置,覆盖head_dim ** -0.5的默认qk缩放比例。默认值None
drop_rate (float): Dropout率。默认值0
attn_drop_rate (float): 注意力dropout率。默认值0
drop_path_rate (float): 随机深度率。默认值0.1
norm_layer (nn.Module): 规范化层。默认值nn.LayerNorm。
ape (bool): 如果为True,为Patch嵌入添加绝对位置嵌入。默认值False
patch_norm (bool): 如果为True,在Patch嵌入后添加规范化。默认值True
use_checkpoint (bool): 是否使用checkpointing来节省内存。默认值False
upscale: 上采样因子。用于图像SR的2/3/4/8,1用于降噪和压缩伪影减少
img_range: 图像范围。1或255。
upsampler: 重建模块。'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
resi_connection: 残差连接前的卷积块。'1conv'/'3conv'
"""
def __init__(self,
in_chans=3,
img_size=64,
patch_size=1,
embed_dim=96,
depths=(6, 6, 6, 6),
num_heads=(6, 6, 6, 6),
window_size=7,
compress_ratio=3,
squeeze_factor=30,
conv_scale=0.01,
overlap_ratio=0.5,
mlp_ratio=4.,
qkv_bias=True,
qk_scale=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.1,
norm_layer=nn.LayerNorm,
ape=False,
patch_norm=True,
use_checkpoint=False,
upscale=2,
img_range=1.,
upsampler='',
resi_connection='1conv',
**kwargs):
super(HAT, self).__init__()
self.window_size = window_size
self.shift_size = window_size // 2
self.overlap_ratio = overlap_ratio
num_in_ch = in_chans
num_out_ch = in_chans
num_feat = 64
self.img_range = img_range
if in_chans == 3:
rgb_mean = (0.4488, 0.4371, 0.4040)
self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
else:
self.mean = torch.zeros(1, 1, 1, 1)
self.upscale = upscale
self.upsampler = upsampler
# 相对位置索引
relative_position_index_SA = self.calculate_rpi_sa()
relative_position_index_OCA = self.calculate_rpi_oca()
self.register_buffer('relative_position_index_SA', relative_position_index_SA)
self.register_buffer('relative_position_index_OCA', relative_position_index_OCA)
# ------------------------- 1,浅层特征提取 ------------------------- #
self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
# ------------------------- 2,深层特征提取 ------------------------- #
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.ape = ape
self.patch_norm = patch_norm
self.num_features = embed_dim
self.mlp_ratio = mlp_ratio
# 将图像分割为非重叠patch
self.patch_embed = PatchEmbed(
img_size=img_size,
patch_size=patch_size,
in_chans=embed_dim,
embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None)
num_patches = self.patch_embed.num_patches
patches_resolution = self.patch_embed.patches_resolution
self.patches_resolution = patches_resolution
# 将非重叠的patch合并为图像
self.patch_unembed = PatchUnEmbed(
img_size=img_size,
patch_size=patch_size,
in_chans=embed_dim,
embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None)
# 绝对位置嵌入
if self.ape:
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
trunc_normal_(self.absolute_pos_embed, std=.02)
self.pos_drop = nn.Dropout(p=drop_rate)
# 随机深度
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # 随机深度衰减规则
# 构建残差混合注意力组 (RHAG)
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
layer = RHAG(
dim=embed_dim,
input_resolution=(patches_resolution[0], patches_resolution[1]),
depth=depths[i_layer],
num_heads=num_heads[i_layer],
window_size=window_size,
compress_ratio=compress_ratio,
squeeze_factor=squeeze_factor,
conv_scale=conv_scale,
overlap_ratio=overlap_ratio,
mlp_ratio=self.mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # 对SR结果无影响
norm_layer=norm_layer,
downsample=None,
use_checkpoint=use_checkpoint,
img_size=img_size,
patch_size=patch_size,
resi_connection=resi_connection)
self.layers.append(layer)
self.norm = norm_layer(self.num_features)
# 构建深层特征提取中的最后一个卷积层
if resi_connection == '1conv':
self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
elif resi_connection == 'identity':
self.conv_after_body = nn.Identity()
# ------------------------- 3,高质量图像重建 ------------------------- #
if self.upsampler == 'pixelshuffle':
# 用于经典SR
self.conv_before_upsample = nn.Sequential(
nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True))
self.upsample = Upsample(upscale, num_feat)
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
self.apply(self._init_weights)
task与yaml配置
详见:https://blog.csdn.net/shangyanaf/article/details/139142532