YOLOv5的Tricks | 【Trick3】Test Time Augmentation(TTA)

简介: 一句话简单的介绍Test Time Augmentation(TTA)就是测试过程中也使用数据增强,官方教程介绍:Test-Time Augmentation (TTA) Tutorial

1. TTA概念介绍


在训练过程中数据增强是非常常用的一种手段,目的是为了提高模型的泛化能力,以免出现大小不一样,图像选择一下就分辨不出来的尴尬。那么TTA就是想在推理阶段也进行数据增强。不过不会太复杂,因为会增加额外的计算量,在打比赛的时候可能会用到,因为打比赛不在意你的推理时长是多久,所以可以尽情瞎造;但是在实际部署的情况下,因为推理速度减慢很可能会达不到实时监测的效果,所以实际是没有必要在推理也进行数据增强的,会降低监测速度。


2. TTA代码实现


知道了其原理是在推理阶段使用数据增强,那么很明显,其将在model中的前向传播过程中实现。在yolov5中,TTA 自动集成到所有YOLOv5 PyTorch Hub模型中。具体的解析我已经写在了注释中。


yolov5实现代码:


class Model(nn.Module):
    def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, anchors=None):  # model, input channels, number of classes
        super().__init__()
        ...
        # 如果直接传入的是dict则无需处理; 如果不是则使用yaml.safe_load加载yaml文件
        with open(cfg, errors='ignore') as f:
            self.yaml = yaml.safe_load(f)  # model dict
  ...
        # 创建网络模型
        # self.model: 初始化的整个网络模型(包括Detect层结构)
        # self.save: 所有层结构中from不等于-1的序号,并排好序  [4, 6, 10, 14, 17, 20, 23]
        self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch])  # model, savelist
        ...
     def forward(self, x, augment=False, profile=False, visualize=False):  # debug同样需要第三次才能正常跳进来
        if augment:     # use Test Time Augmentation(TTA), 如果打开会对图片进行scale和flip
            return self._forward_augment(x)  # augmented inference, None
        return self._forward_once(x, profile, visualize)  # single-scale inference, train
  # 使用TTA进行推理(当然还是会调用普通推理实现前向传播)
    def _forward_augment(self, x):
        img_size = x.shape[-2:]  # height, width
        s = [1, 0.83, 0.67]  # scales
        f = [None, 3, None]  # flips (2-ud上下flip, 3-lr左右flip)
        y = []  # outputs
        # 这里相当于对输入x进行3次不同参数的测试数据增强推理, 每次的推理结构都保存在列表y中
        for si, fi in zip(s, f):
            # scale_img缩放图片尺寸
            # 通过普通的双线性插值实现,根据ratio来控制图片的缩放比例,最后通过pad 0补齐到原图的尺寸
            xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max()))
            yi = self._forward_once(xi)[0]  # forward:torch.Size([1, 25200, 25])
            # cv2.imwrite(f'img_{si}.jpg', 255 * xi[0].cpu().numpy().transpose((1, 2, 0))[:, :, ::-1])  # save
            # _descale_pred将推理结果恢复到相对原图图片尺寸, 只对坐标xywh:yi[..., :4]进行恢复
            # 如果f=2,进行上下翻转; 如果f=3,进行左右翻转
            yi = self._descale_pred(yi, fi, si, img_size)
            y.append(yi)    # [b, 25200, 25] / [b, 18207, 25] / [b, 12348, 25]
        # 把第一层的后面一部分的预测结果去掉, 也把最后一层的前面一部分的预测结果去掉
        # [b, 24000, 25] / [b, 18207, 25] / [b, 2940, 25]
        # 筛除的可能是重复的部分吧, 提高运行速度(有了解的朋友请告诉我一下)
        y = self._clip_augmented(y)  # clip augmented tails
        return torch.cat(y, 1), None  # augmented inference, train
  # 普通推理
    def _forward_once(self, x, profile=False, visualize=False):
        # y列表用来保存中间特征图; dt用来记录每个模块执行10次的平均时长
        y, dt = [], []  # outputs
        # 对sequence模型进行遍历操作, 不断地对输入x进行处理, 中间结果需要保存的时候另外存储到列表y中
        for m in self.model:
            # 如果只是对前一个模块的输出进行操作, 则需要提取直接保存的中间特征图进行操作,
            # 一般是concat处理, 对当前层与之前曾进行一个concat再卷积; detect模块也需要提取3个特征层来处理
            if m.f != -1:  # if not from previous layer
                x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers
            # profile参数打开会记录每个模块的平均执行10次的时长和flops用于分析模型的瓶颈, 提高模型的执行速度和降低显存占用
            if profile:
                self._profile_one_layer(m, x, dt)
            # 使用当前模块对特征图进行处理
            # 如果是concat模块: 则x是一个特征图列表, 则对其进行拼接处理, 再交给下一个卷积模块;
            # 如果是C3, Conv等普通的模块: 则x是单一特征图
            # 如果是detct模块: 则x是3个特征图的列表 (训练与推理返回的内容不一样)
            x = m(x)  # run
            # self.save: 把所有层结构中from不是-1的值记下并排序 [4, 6, 10, 14, 17, 20, 23]
            y.append(x if m.i in self.save else None)  # save output
            # 特征可视化
            if visualize:
                feature_visualization(x, m.type, m.i, save_dir=visualize)
        return x
  # 翻转数据增强
  def _descale_pred(self, p, flips, scale, img_size):
        # de-scale predictions following augmented inference (inverse operation)
        if self.inplace:
            p[..., :4] /= scale  # de-scale xywh坐标缩放回原来大小
            # f=2,进行上下翻转
            if flips == 2:
                p[..., 1] = img_size[0] - p[..., 1]  # de-flip ud
            # f=3,进行左右翻转
            elif flips == 3:
                p[..., 0] = img_size[1] - p[..., 0]  # de-flip lr
        else:
            x, y, wh = p[..., 0:1] / scale, p[..., 1:2] / scale, p[..., 2:4] / scale  # de-scale
            if flips == 2:
                y = img_size[0] - y  # de-flip ud
            elif flips == 3:
                x = img_size[1] - x  # de-flip lr
            p = torch.cat((x, y, wh, p[..., 4:]), -1)
        return p
    # 这里y的一个包含3个子列表的列表, 通过对输入图像x进行了3次不同尺度的变换, 所以得到了3个inference结构
    # 这里看不太懂, 不过大概做的事情就是对第一个列表与最后一个列表的结果做一些过滤处理
    # 把第一层的后面一部分的预测结果去掉, 也把最后一层的前面一部分的预测结果去掉, 然后剩下的concat为一个部分
    def _clip_augmented(self, y):
        # Clip YOLOv5 augmented inference tails
        nl = self.model[-1].nl  # Detect(): number of detection layers (P3-P5)
        g = sum(4 ** x for x in range(nl))  # grid points
        e = 1  # exclude layer count
        i = (y[0].shape[1] // g) * sum(4 ** x for x in range(e))  # indices: (25200 // 21) * 1 = 1200
        y[0] = y[0][:, :-i]  # large: (1,25200,25) -> (1,24000,25)
        i = (y[-1].shape[1] // g) * sum(4 ** (nl - 1 - x) for x in range(e))  # indices: (12348 // 21) * 16 = 9408
        y[-1] = y[-1][:, i:]  # small: (1,12348,25) -> (1,2940,25)
        return y
  ...


这里部分的函数还会调用torch_utils来实现,比如scale_img通过双线性插值来实现图像的缩放(在通过pad0来补齐到原图的尺寸),这里额外贴上scale_img这个辅助函数:


scale_img函数:

# 通过普通的双线性插值实现,根据ratio来控制图片的缩放比例,最后通过pad 0补齐到原图的尺寸
def scale_img(img, ratio=1.0, same_shape=False, gs=32):  # img(16,3,256,416)
    # scales img(bs,3,y,x) by ratio constrained to gs-multiple
    if ratio == 1.0:
        return img
    else:
        h, w = img.shape[2:]
        s = (int(h * ratio), int(w * ratio))  # new size
        img = F.interpolate(img, size=s, mode='bilinear', align_corners=False)  # resize
        if not same_shape:  # pad/crop img
            h, w = [math.ceil(x * ratio / gs) * gs for x in (h, w)]
        return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447)  # value = imagenet mean


代码+注释有点长,具体就是通过双线性插值对图像进行缩放然后打上补丁到原图大小,推理完获取结果后再将推理结果恢复到相对原图图片尺寸, 只对坐标xywh:yi[…, :4]进行恢复。


3. TTA使用方法


官方教程链接:https://github.com/ultralytics/yolov5/issues/303


Test with TTA

python val.py --weights yolov5x.pt --data coco.yaml --img 832 --augment


Inference with TTA

python detect.py --weights yolov5s.pt --img 832 --source data/images --augment


官方文档指出,启用 TTA 的推理通常需要大约 2-3 倍的正常推理时间,因为图像正在左右翻转并以 3 种不同的分辨率进行处理,输出在 NMS 之前合并。同时这里为了避免太多的冗余结果,还已经通过了_clip_augmented函数去除了部分结果,然后再对3个不同分辨率处理后输出的结果进行合并再给nms进行后处理。速度下降的部分原因仅仅是图像尺寸较大(832 对 640),而部分原因是实际的 TTA 操作。因为,很显然,本来一个输入得多一个推理结果,而使用了TTA不仅使用了数据增强上下左右翻转(在源码中只对中间尺度进行左右翻转),同时还进行了多个尺度的输入处理获取了多个尺度的输出结果,这显然增加了3倍的工作量,所以启用 TTA 的推理通常需要大约 2-3 倍的正常推理时间,这是可以理解的。


此外,可以看见,其实使用TTA这个功能只需要设置augment这个参数,而这个参数在forward中进行控制。测试代码如下所示:


import torch
# Model
model = torch.hub.load('ultralytics/yolov5', 'yolov5s')  # or yolov5m, yolov5x, custom
# Images
img = 'https://ultralytics.com/images/zidane.jpg'  # or file, PIL, OpenCV, numpy, multiple
# Inference
results = model(img, augment=True)  # <--- TTA inference
# Results
results.print()  # or .show(), .save(), .crop(), .pandas(), etc.


此外,既然了解了原理,那么也可以自定义自己的TTA,可以执行设置自己推理阶段的数据增强,也就是需要改写_forward_augment函数即可:


def _forward_augment(self, x):
  img_size = x.shape[-2:]  # height, width
  s = [1, 0.83, 0.67]  # scales
  f = [None, 3, None]  # flips (2-ud上下flip, 3-lr左右flip)
  y = []  # outputs
  # 这里相当于对输入x进行3次不同参数的测试数据增强推理, 每次的推理结构都保存在列表y中
  for si, fi in zip(s, f):
     # scale_img缩放图片尺寸
     # 通过普通的双线性插值实现,根据ratio来控制图片的缩放比例,最后通过pad 0补齐到原图的尺寸
     xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max()))
     yi = self._forward_once(xi)[0]  # forward:torch.Size([1, 25200, 25])
     # cv2.imwrite(f'img_{si}.jpg', 255 * xi[0].cpu().numpy().transpose((1, 2, 0))[:, :, ::-1])  # save
     # _descale_pred将推理结果恢复到相对原图图片尺寸, 只对坐标xywh:yi[..., :4]进行恢复
     # 如果f=2,进行上下翻转; 如果f=3,进行左右翻转
     yi = self._descale_pred(yi, fi, si, img_size)
     y.append(yi)    # [b, 25200, 25] / [b, 18207, 25] / [b, 12348, 25]
  y = self._clip_augmented(y)  # clip augmented tails
  return torch.cat(y, 1), None  # augmented inference, train


总结yolov5实现的TTA手段:

  • 增大输入图像大小30%,如640 vs 832
  • left-right flipped
  • 3 different resolutions
  • the outputs merged before NMS


结果:和正常比慢 2-3X


参考资料:

1. Test-Time Augmentation (TTA) Tutorial

2. YOLOv5 Test-Time Augmentation (TTA) 教程:这篇是将Tutorial搬运下来的,所以本质上两个内容一样


目录
相关文章
|
8月前
|
机器学习/深度学习 编解码 自然语言处理
DeIT:Training data-efficient image transformers & distillation through attention论文解读
最近,基于注意力的神经网络被证明可以解决图像理解任务,如图像分类。这些高性能的vision transformer使用大量的计算资源来预训练了数亿张图像,从而限制了它们的应用。
237 0
|
8月前
|
机器学习/深度学习 自然语言处理 算法
SS-AGA:Multilingual Knowledge Graph Completion with Self-Supervised Adaptive Graph Alignment 论文解读
预测知识图(KG)中缺失的事实是至关重要的,因为现代知识图远未补全。由于劳动密集型的人类标签,当处理以各种语言表示的知识时,这种现象会恶化。
57 0
|
计算机视觉 索引
YOLOv5的Tricks | 【Trick14】YOLOv5的val.py脚本的解析
YOLOv5的Tricks | 【Trick14】YOLOv5的val.py脚本的解析
1042 0
YOLOv5的Tricks | 【Trick14】YOLOv5的val.py脚本的解析
|
10月前
|
数据可视化 数据挖掘
【论文解读】Dual Contrastive Learning:Text Classification via Label-Aware Data Augmentation
北航出了一篇比较有意思的文章,使用标签感知的数据增强方式,将对比学习放置在有监督的环境中 ,下游任务为多类文本分类,在低资源环境中进行实验取得了不错的效果
242 0
|
11月前
|
机器学习/深度学习 存储 机器人
LF-YOLO: A Lighter and Faster YOLO for Weld Defect Detection of X-ray Image
高效的特征提取EFE模块作为主干单元,它可以用很少的参数和低计算量提取有意义的特征,有效地学习表征。大大减少了特征提取的消耗
94 0
|
机器学习/深度学习 编解码 固态存储
Single Shot MultiBox Detector论文翻译【修改】
Single Shot MultiBox Detector论文翻译【修改】
74 0
Single Shot MultiBox Detector论文翻译【修改】
|
机器学习/深度学习 算法 数据挖掘
【论文泛读】 Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift
【论文泛读】 Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift
【论文泛读】 Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift
|
计算机视觉
目标检测的Tricks | 【Trick1】Label Smoothing
目标检测的Tricks | 【Trick1】Label Smoothing
133 0
|
计算机视觉
目标检测的Tricks | 【Trick4】Multi-scale training与Multi-scale testing
目标检测的Tricks | 【Trick4】Multi-scale training与Multi-scale testing
207 0
《Investigation of Transformer based Spelling Correction Model for CTC-based End-to-End Mandarin Speech Recognition》电子版地址
Investigation of Transformer based Spelling Correction Model for CTC-based End-to-End Mandarin Speech Recognition
76 0
《Investigation of Transformer based Spelling Correction Model for CTC-based End-to-End Mandarin Speech Recognition》电子版地址