语义分割数据增强——图像和标注同步增强

简介: 其中常见的数据增强方式包括:旋转、垂直翻转、水平翻转、放缩、剪裁、归一化等。

语义分数据增强


常见的数据增强方式


查看pytorchtorchvisiontransformer中的源代码,我们可以看到具有以下数据增强方式:


__all__ = ["Compose", "ToTensor", "PILToTensor", "ConvertImageDtype", "ToPILImage", "Normalize", "Resize", "Scale",
           "CenterCrop", "Pad", "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop",
           "RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop",
           "LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale",
           "RandomPerspective", "RandomErasing", "GaussianBlur", "InterpolationMode", "RandomInvert", "RandomPosterize",
           "RandomSolarize", "RandomAdjustSharpness", "RandomAutocontrast", "RandomEqualize"]


其中常见的数据增强方式包括:旋转、垂直翻转、水平翻转、放缩、剪裁、归一化等。


语义分割和图像分类的数据增强差异在于:语义分割是对图像的每个像素进行分类,所以在进行某些数据增强时,需要对标注图像(mask)进行同步操作,如旋转、剪裁、翻转等。


查遍网上的一些教程,但是没有发现一个能够直接使用的pytorch数据增强方式,所以想自己写一个方便后续使用。


具体实现代码


我们这里以一个细胞语义分割数据集为例,由于该数据集是灰度图像,所以相对于彩色图像数据增强有一些差距,代码中注释了灰度图像不能使用的数据增强方式,但是彩色图像可以使用的数据增强方式。具体代码如下所示:


import numpy as np
import cv2
import torch
from torch.utils.data import Dataset
import os
from PIL import Image
from torchvision.transforms import functional as F
import random
class CellDataset(Dataset):
    def __init__(self, image_dir, mask_dir, names_list, image_size=224, isGray=False, augmentation=True):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.augmentation = augmentation
        self.names_list = names_list
        self.isGray = isGray
        self.image_size = image_size
    def __len__(self):
        return len(self.names_list)
    def augmentate(self, image, mask):
        # it is expected to be in [..., H, W] format
        image = torch.unsqueeze(torch.from_numpy(
            np.array(image, dtype=np.uint8)), dim=0)
        mask = torch.unsqueeze(torch.from_numpy(
            np.array(mask, dtype=np.uint8)), dim=0)
        image = F.resize(image, size=[self.image_size, self.image_size])
        mask = F.resize(mask, size=[self.image_size, self.image_size])
        # 彩色图可以进行以下数据增强,参数不太好调整
        image = F.adjust_gamma(image, gamma=random.uniform(0.8, 1.2))
        image = F.adjust_contrast(
            image, contrast_factor=random.uniform(0.8, 1.2))
        image = F.adjust_brightness(
            image, brightness_factor=random.uniform(0.8, 1.2))
        image = F.adjust_saturation(
            image, saturation_factor=random.uniform(0.8, 1.2))
        image = F.adjust_hue(image, hue_factor=random.uniform(-0.2, 0.2))
        # 让image和mask进行同步旋转和翻转数据增强
        image_mask = torch.cat([image, mask], dim=0)
        if random.uniform(0, 1) > 0.5:
            image_mask = F.hflip(image_mask)
        if random.uniform(0, 1) > 0.5:
            image_mask = F.vflip(image_mask)
        if random.uniform(0, 1) > 0.5:
            image_mask = F.rotate(image_mask, angle=90)
        # 要看image和mask的维度
        image = image_mask[0, ...]
        mask = image_mask[1, ...]
        # image = image / 255
        # mask = mask / 255
        # image = torch.unsqueeze(image, dim=0)
        # 标准化,彩色图像需要传三个值
        # image = F.normalize(image, mean=[0.5], std=[1])
        # mask = torch.unsqueeze(mask, dim=0)
        return image, mask
    def __getitem__(self, item):
        image_path = os.path.join(self.image_dir, self.names_list[item])
        mask_path = os.path.join(self.mask_dir, self.names_list[item])
        image = Image.open(image_path)
        if self.isGray:
            image = image.convert('L')
        mask = Image.open(mask_path)
        if self.augmentation:
            image, mask = self.augmentate(image, mask)
        return image, mask
if __name__ == '__main__':
    cell_dataset = CellDataset(image_dir='./data/image', mask_dir='./data/label',
                               names_list=['0.png'])
    index = 3
    for image, mask in cell_dataset:
        print(image.shape, mask.shape)
        print(torch.max(image), torch.min(image))
        image = np.array(image, dtype=np.uint8)
        mask = np.array(mask, dtype=np.uint8)
        # cv2.imshow('image', image)
        # cv2.imshow('mask', mask)
        # cv2.waitKey(0)
        cv2.imwrite(os.path.join(
            './data/augment/image', str(index)+'.png'), image)
        cv2.imwrite(os.path.join(
            './data/augment/label', str(index)+'.png'), mask)


我们以下图图像和图像标注掩码为例进行实验:


b732abd841334d4d8fc8a20799ba0a0e.png


原始细胞图像


05d151dfed7048228c0c101814e7fc52.png


原始细胞掩码标签图像


我们更该index的值,进行重复执行程序,来生成多个不同对应的数据增强图像,我们重复执行了4次,得到了以下数据增强的细胞图像和对应的掩码标签图像。


d6dbef4a943c49cb88412d7224da4d93.png

7960d8730d094668897fe8216e0ff771.png


如果有问题可以在评论区进行回复。如果对您有帮助的话可以帮忙点赞👍👍👍。

目录
相关文章
|
Linux 网络安全 开发工具
校外网络连接校园网内的linux服务器方法(使用frp实现内网穿透)
平常在校园里连接校内实验室的linux服务器可以直接使用ssh直接链接私有ip地址,一旦本地移动到了校园网外部(如:使用手机流量wifi,或着暑假回家使用家庭wifi)便无法在使用ssh连接校内的服务器。本文提供一个实现校外也能访问校内服务器的方法
7003 0
校外网络连接校园网内的linux服务器方法(使用frp实现内网穿透)
|
机器学习/深度学习 计算机视觉
Mobile-Unet网络综述
Mobile-Unet网络综述
3073 0
Mobile-Unet网络综述
|
自然语言处理 算法 数据挖掘
自蒸馏:一种简单高效的优化方式
背景知识蒸馏(knowledge distillation)指的是将预训练好的教师模型的知识通过蒸馏的方式迁移至学生模型,一般来说,教师模型会比学生模型网络容量更大,模型结构更复杂。对于学生而言,主要增益信息来自于更强的模型产出的带有更多可信信息的soft_label。例如下右图中,两个“2”对应的hard_label都是一样的,即0-9分类中,仅“2”类别对应概率为1.0,而soft_label
自蒸馏:一种简单高效的优化方式
|
PyTorch 算法框架/工具
MMsegmentation教程 4: 自定义模型
MMsegmentation教程 4: 自定义模型
1037 0
|
机器学习/深度学习 并行计算 异构计算
NVIDIA CUDA/cuDNN历代版本下载地址
NVIDIA CUDA/cuDNN历代版本下载地址
5419 0
NVIDIA CUDA/cuDNN历代版本下载地址
|
机器学习/深度学习 编解码 计算机视觉
YOLOv11改进策略【Head】| ASFF 自适应空间特征融合模块,改进检测头Detect_ASFF
YOLOv11改进策略【Head】| ASFF 自适应空间特征融合模块,改进检测头Detect_ASFF
2382 13
YOLOv11改进策略【Head】| ASFF 自适应空间特征融合模块,改进检测头Detect_ASFF
|
机器学习/深度学习 JSON 算法
语义分割笔记(二):DeepLab V3对图像进行分割(自定义数据集从零到一进行训练、验证和测试)
本文介绍了DeepLab V3在语义分割中的应用,包括数据集准备、模型训练、测试和评估,提供了代码和资源链接。
4579 0
语义分割笔记(二):DeepLab V3对图像进行分割(自定义数据集从零到一进行训练、验证和测试)
|
机器学习/深度学习 自然语言处理 算法
深度学习基础知识:介绍深度学习的发展历程、基本概念和主要应用
深度学习基础知识:介绍深度学习的发展历程、基本概念和主要应用
7399 0
|
机器学习/深度学习 人工智能 编解码
【AI系统】MobileNet 系列
本文详细介绍 MobileNet 系列模型,重点探讨其轻量化设计原则。从 MobileNetV1 开始,通过深度可分离卷积和宽度乘数减少参数量,实现低延迟、低功耗。后续版本 V2、V3、V4 逐步引入线性瓶颈、逆残差、Squeeze-and-Excitation 模块、新型激活函数 h-swish、NAS 搜索等技术,持续优化性能。特别是 MobileNetV4,通过通用倒瓶颈(UIB)和 Mobile MQA 技术,大幅提升模型效率,达到硬件无关的 Pareto 最优。文章结合最新深度学习技术,全面解析各版本的改进与设计思路。
5025 8
|
机器学习/深度学习 数据可视化 自动驾驶
YOLO11-seg分割如何训练自己的数据集(道路缺陷)
本文介绍了如何使用自己的道路缺陷数据集训练YOLOv11-seg模型,涵盖数据集准备、模型配置、训练过程及结果可视化。数据集包含4029张图像,分为训练、验证和测试集。训练后,模型在Mask mAP50指标上达到0.673,展示了良好的分割性能。
6776 4