YOLOv5的Tricks | 【Trick8】图片采样策略——按数据集各类别权重采样

简介: 这篇文章用来记录一下yolov5在训练过程中提出的一个图片采样策略,简单来说,就是根据图片的权重来决定其采样顺序。

1. 图片采样策略想法


  • 图片采样策略想法

在我们训练数据集的时候,一般是对数据集随机采样几张图像然后构建成一个mini-batch来批量输入网络处理。个人猜想,一个可能的想法就是,这种随机的图像采集会不会过于随意,因为有些图像的目标是过少的,那么这种图像可能对网络来说比较简单;而有些图像的目标是比较多的,这种是比较困难的。而对于开始训练的初期就使用这种简答图像对网络的训练可能带来不了多大的学习提升。


所以,如果可以对数据集中的每张图像做一个权重的划分,在训练模型的时候依照图像的权重大小依次按难到易的大概顺序来进行训练,让模型从一开始的困难的样本较快的学习到潜在特征,到之后通过简单的图像样本来对参数进行微调,说不定是一个好的方法。


(以上内容是个人的思考猜测,可能是有误的,欢迎探讨。)


  • 图片采样策略思路

那么具体的实现思路就是,对整个数据集的图像目标做类别统计,然后类别的数目越大权重越小(成反比的关系)。然后再使用整个数据集的类别权重对每一张图像做类别权重的叠加。也就是根据每一张的图片的类别权重和来作为采样的权重,决定其采用的顺序。在代码的实现中是从大到小排序的。


2. 图片采样策略代码


yolov5参考代码

大概的注释都写在代码里了:


def train():
  ...
  model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc  # attach class weights
  ...
  for epoch in range(start_epoch, epochs): 
  model.train()
  # Update image weights (optional, single-GPU only)
        if opt.image_weights:
            # 根据数据集的类别数目构建每个类别的权重(类别权重与类别数目成反比)
            cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 / nc  # class weights
            # 对每张图片的目标计算其类别权重和作为图片的采集权重
            iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw)  # image weights
            # 再更具每张图片的采集权重来构建图片的采样顺序
            dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n)  # rand weighted idx
  ...
def labels_to_class_weights(labels, nc=80):
    # Get class weights (inverse frequency) from training labels
    # labels是当前数据集的训练集的所有图像: {list: 682}
    # 列表的每个对象格式是: (ndarray: (k, 5)) k表示当前图像的目表个数, 5是(class+xywh)
    if labels[0] is None:  # no labels loaded
        return torch.Tensor()
    # 把图像的标签列表直接转化为标签列表:{ndarray: (labels, 5)} labels表示全部图像的所有标签个数
    labels = np.concatenate(labels, 0)  # labels.shape = (866643, 5) for COCO
    # 提取类别 labels[:, 0] 数据来为每一类做统计 .astype(np.int): 取整
    classes = labels[:, 0].astype(np.int)  # labels = [class xywh]
    # weight: 统计每个类别出现的次数
    weights = np.bincount(classes, minlength=nc)  # occurrences per class
    # Prepend gridpoint count (for uCE training)
    # gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum()  # gridpoints per image
    # weights = np.hstack([gpi * len(labels)  - weights.sum() * 9, weights * 9]) ** 0.5  # prepend gridpoints to start
    # 将出现次数为0的类别权重全部取1
    weights[weights == 0] = 1  # replace empty bins with 1
    # 类别权重取类别出现次数的倒数, 也就是表示类别次数与权重成反比, 标签频率越高的类别权重越低, 因为越不罕见
    weights = 1 / weights  # number of targets per class
    # 归一化操作: 求出每一类别的占比
    weights /= weights.sum()  # normalize
    return torch.from_numpy(weights)  # numpy -> tensor
def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
    # Produces image weights based on class_weights and image contents
    # out:{ndarray: (682,3)} 统计每一张图片中类类别的数目 这里我用的是mask数据集有3个类别 每个位置存储图像中对应类别目标出现的个数
    class_counts = np.array([np.bincount(x[:, 0].astype(np.int), minlength=nc) for x in labels])
    # class_weights:[n_class] -> [1, n_class]
    # 每张图片的每个类别个数[label_nums, n_class] * 整个数据集每个类别的权重[1, n_class] = 每张图片的对应每个类别的权重[label_nums, n_class_weight]
    # 然后每个类别的权重加在一起等于当前这张图片的权重
    image_weights = (class_weights.reshape(1, nc) * class_counts).sum(1)
    # index = random.choices(range(n), weights=image_weights, k=1)  # weight image sample
    return image_weights


构造Dataset使用的地方

class LoadImagesAndLabels(Dataset):
  def __init__(self, img_size=640, batch_size=16, image_weights=False, ...):
  ...
  self.indices = range(n)
  def __len__(self):
        return len(self.img_files)
    def __getitem__(self, index):
      # 重点使用部分, 就是用权重采样策略替代了随机采样
      # 随机采样: index返回的是随机值(shuffle = True),所以注意到其实在
      # 权重采样: index是按顺序从0开始, 然后依次提取indices所指向的图像索引
        index = self.indices[index]  # linear, shuffled, or image_weights
        img, labels = load_mosaic(self, index)
        ...
        return torch.from_numpy(img), labels_out, self.img_files[index], shapes
# 因为可以注意到, 构建dataloader的时候yolov5代码中是没有使用shuffle=True这个随机采样的参数的
def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=None, augment=False, cache=False, pad=0.0,
                      rect=False, rank=-1, workers=8, image_weights=False, quad=False, prefix=''):
    # Make sure only the first process in DDP process the dataset first, and the following others can use the cache
    with torch_distributed_zero_first(rank):
        dataset = LoadImagesAndLabels(path, imgsz, batch_size,
                                      augment=augment,  # augment images
                                      hyp=hyp,  # augmentation hyperparameters
                                      rect=rect,  # rectangular training
                                      cache_images=cache,
                                      single_cls=single_cls,
                                      stride=int(stride),
                                      pad=pad,
                                      image_weights=image_weights,
                                      prefix=prefix)
    batch_size = min(batch_size, len(dataset))
    # 这里对num_worker进行更改
    # nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, workers])  # number of workers
    nw = 0  # 可以适当提高这个参数0, 2, 4, 8, 16…
    sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
    loader = torch.utils.data.DataLoader if image_weights else InfiniteDataLoader
    # Use torch.utils.data.DataLoader() if dataset.properties will update during training else InfiniteDataLoader()
    # 没有使用 shuffle=True 这个参数
  dataloader = loader(dataset,
                        batch_size=batch_size,
                        num_workers=nw,
                        sampler=sampler,
                        pin_memory=True,
                        collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn)
    return dataloader, dataset


所以从代码中可以看见,如果不使用图像采样策略,这里也不会使用随机的选择策略,而且index从0开始提取,验证如下:


第一次断点调试:index从0开始,想法验证成功

image.png


参考资料:

1. 【YOLOV5-5.x 源码解读】general.py


目录
相关文章
|
算法 计算机视觉 异构计算
目标检测的Tricks | 【Trick7】数据增强——Mosaic(马赛克)
目标检测的Tricks | 【Trick7】数据增强——Mosaic(马赛克)
2143 0
目标检测的Tricks | 【Trick7】数据增强——Mosaic(马赛克)
|
1月前
|
机器学习/深度学习 JSON 数据可视化
YOLO11-pose关键点检测:训练实战篇 | 自己数据集从labelme标注到生成yolo格式的关键点数据以及训练教程
本文介绍了如何将个人数据集转换为YOLO11-pose所需的数据格式,并详细讲解了手部关键点检测的训练过程。内容涵盖数据集标注、格式转换、配置文件修改及训练参数设置,最终展示了训练结果和预测效果。适用于需要进行关键点检测的研究人员和开发者。
201 0
|
6月前
|
机器学习/深度学习 存储 数据可视化
R语言混合效应逻辑回归(mixed effects logistic)模型分析肺癌数据
R语言混合效应逻辑回归(mixed effects logistic)模型分析肺癌数据
|
6月前
|
编解码 算法 知识图谱
ICCV 2023 | DAT:利用双重聚合的Transformer进行图像超分
ICCV 2023 | DAT:利用双重聚合的Transformer进行图像超分
153 0
|
6月前
|
机器学习/深度学习 缓存 测试技术
Nice Trick | 不想标注数据了!有伪标签何必呢,Mixup+Mosaic让DINO方法再继续涨点
Nice Trick | 不想标注数据了!有伪标签何必呢,Mixup+Mosaic让DINO方法再继续涨点
209 0
|
6月前
|
机器学习/深度学习 5G 知识图谱
视觉Backbone怎么使用1/8的FLOPs实现比Baseline更高的精度?
视觉Backbone怎么使用1/8的FLOPs实现比Baseline更高的精度?
70 0
|
6月前
|
存储 数据可视化 计算机视觉
基于YOLOv8的自定义数据姿势估计
基于YOLOv8的自定义数据姿势估计
|
机器学习/深度学习 Serverless 计算机视觉
NeRF 模型评价指标PSNR,MS-SSIM, LPIPS 详解和python实现
NeRF 模型评价指标PSNR,MS-SSIM, LPIPS 详解和python实现
2460 0
|
算法 数据挖掘
简单涨点 | Flow-Mixup: 对含有损坏标签的多标签医学图像进行分类(优于Mixup和Maniflod Mixup)(二)
简单涨点 | Flow-Mixup: 对含有损坏标签的多标签医学图像进行分类(优于Mixup和Maniflod Mixup)(二)
166 1
|
机器学习/深度学习 算法 前端开发
简单涨点 | Flow-Mixup: 对含有损坏标签的多标签医学图像进行分类(优于Mixup和Maniflod Mixup)(一)
简单涨点 | Flow-Mixup: 对含有损坏标签的多标签医学图像进行分类(优于Mixup和Maniflod Mixup)(一)
236 1