语义分割数据增强——图像和标注同步增强

简介: 其中常见的数据增强方式包括:旋转、垂直翻转、水平翻转、放缩、剪裁、归一化等。

语义分数据增强


常见的数据增强方式


查看pytorchtorchvisiontransformer中的源代码,我们可以看到具有以下数据增强方式:


__all__ = ["Compose", "ToTensor", "PILToTensor", "ConvertImageDtype", "ToPILImage", "Normalize", "Resize", "Scale",
           "CenterCrop", "Pad", "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop",
           "RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop",
           "LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale",
           "RandomPerspective", "RandomErasing", "GaussianBlur", "InterpolationMode", "RandomInvert", "RandomPosterize",
           "RandomSolarize", "RandomAdjustSharpness", "RandomAutocontrast", "RandomEqualize"]


其中常见的数据增强方式包括:旋转、垂直翻转、水平翻转、放缩、剪裁、归一化等。


语义分割和图像分类的数据增强差异在于:语义分割是对图像的每个像素进行分类,所以在进行某些数据增强时,需要对标注图像(mask)进行同步操作,如旋转、剪裁、翻转等。


查遍网上的一些教程,但是没有发现一个能够直接使用的pytorch数据增强方式,所以想自己写一个方便后续使用。


具体实现代码


我们这里以一个细胞语义分割数据集为例,由于该数据集是灰度图像,所以相对于彩色图像数据增强有一些差距,代码中注释了灰度图像不能使用的数据增强方式,但是彩色图像可以使用的数据增强方式。具体代码如下所示:


import numpy as np
import cv2
import torch
from torch.utils.data import Dataset
import os
from PIL import Image
from torchvision.transforms import functional as F
import random
class CellDataset(Dataset):
    def __init__(self, image_dir, mask_dir, names_list, image_size=224, isGray=False, augmentation=True):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.augmentation = augmentation
        self.names_list = names_list
        self.isGray = isGray
        self.image_size = image_size
    def __len__(self):
        return len(self.names_list)
    def augmentate(self, image, mask):
        # it is expected to be in [..., H, W] format
        image = torch.unsqueeze(torch.from_numpy(
            np.array(image, dtype=np.uint8)), dim=0)
        mask = torch.unsqueeze(torch.from_numpy(
            np.array(mask, dtype=np.uint8)), dim=0)
        image = F.resize(image, size=[self.image_size, self.image_size])
        mask = F.resize(mask, size=[self.image_size, self.image_size])
        # 彩色图可以进行以下数据增强,参数不太好调整
        image = F.adjust_gamma(image, gamma=random.uniform(0.8, 1.2))
        image = F.adjust_contrast(
            image, contrast_factor=random.uniform(0.8, 1.2))
        image = F.adjust_brightness(
            image, brightness_factor=random.uniform(0.8, 1.2))
        image = F.adjust_saturation(
            image, saturation_factor=random.uniform(0.8, 1.2))
        image = F.adjust_hue(image, hue_factor=random.uniform(-0.2, 0.2))
        # 让image和mask进行同步旋转和翻转数据增强
        image_mask = torch.cat([image, mask], dim=0)
        if random.uniform(0, 1) > 0.5:
            image_mask = F.hflip(image_mask)
        if random.uniform(0, 1) > 0.5:
            image_mask = F.vflip(image_mask)
        if random.uniform(0, 1) > 0.5:
            image_mask = F.rotate(image_mask, angle=90)
        # 要看image和mask的维度
        image = image_mask[0, ...]
        mask = image_mask[1, ...]
        # image = image / 255
        # mask = mask / 255
        # image = torch.unsqueeze(image, dim=0)
        # 标准化,彩色图像需要传三个值
        # image = F.normalize(image, mean=[0.5], std=[1])
        # mask = torch.unsqueeze(mask, dim=0)
        return image, mask
    def __getitem__(self, item):
        image_path = os.path.join(self.image_dir, self.names_list[item])
        mask_path = os.path.join(self.mask_dir, self.names_list[item])
        image = Image.open(image_path)
        if self.isGray:
            image = image.convert('L')
        mask = Image.open(mask_path)
        if self.augmentation:
            image, mask = self.augmentate(image, mask)
        return image, mask
if __name__ == '__main__':
    cell_dataset = CellDataset(image_dir='./data/image', mask_dir='./data/label',
                               names_list=['0.png'])
    index = 3
    for image, mask in cell_dataset:
        print(image.shape, mask.shape)
        print(torch.max(image), torch.min(image))
        image = np.array(image, dtype=np.uint8)
        mask = np.array(mask, dtype=np.uint8)
        # cv2.imshow('image', image)
        # cv2.imshow('mask', mask)
        # cv2.waitKey(0)
        cv2.imwrite(os.path.join(
            './data/augment/image', str(index)+'.png'), image)
        cv2.imwrite(os.path.join(
            './data/augment/label', str(index)+'.png'), mask)


我们以下图图像和图像标注掩码为例进行实验:


b732abd841334d4d8fc8a20799ba0a0e.png


原始细胞图像


05d151dfed7048228c0c101814e7fc52.png


原始细胞掩码标签图像


我们更该index的值,进行重复执行程序,来生成多个不同对应的数据增强图像,我们重复执行了4次,得到了以下数据增强的细胞图像和对应的掩码标签图像。


d6dbef4a943c49cb88412d7224da4d93.png

7960d8730d094668897fe8216e0ff771.png


如果有问题可以在评论区进行回复。如果对您有帮助的话可以帮忙点赞👍👍👍。

目录
打赏
0
0
0
0
3
分享
相关文章
浅述几种文本和图像数据增强的方法
在现实场景中,我们往往收集不到太多的数据,那么为了扩大数据集,可以采用数据增强手段来增加样本,那么平常我们应该怎么做数据增强的呢? 什么是数据增强 数据增强也叫数据扩增,意思是在不实质性的增加数据的情况下,让有限的数据产生等价于更多数据的价值。
【论文速递】TMM2023 - FECANet:用特征增强的上下文感知网络增强小样本语义分割
【论文速递】TMM2023 - FECANet:用特征增强的上下文感知网络增强小样本语义分割
SPRIGHT:提升文本到图像模型空间一致性的数据集
SPRIGHT 是一个专注于空间关系的大型视觉-语言数据集,通过重新描述600万张图像,显著提升文本到图像模型的空间一致性。
70 18
SPRIGHT:提升文本到图像模型空间一致性的数据集
图像数据的特征提取与预处理方法,涵盖图像数据的特点、主要的特征提取技术
本文深入探讨了图像数据的特征提取与预处理方法,涵盖图像数据的特点、主要的特征提取技术(如颜色、纹理、形状特征)及预处理步骤(如图像增强、去噪、分割)。同时介绍了Python中常用的OpenCV和Scikit-image库,并提供了代码示例,强调了预处理的重要性及其在提升模型性能中的作用。
513 5
语义分割笔记(二):DeepLab V3对图像进行分割(自定义数据集从零到一进行训练、验证和测试)
本文介绍了DeepLab V3在语义分割中的应用,包括数据集准备、模型训练、测试和评估,提供了代码和资源链接。
800 0
语义分割笔记(二):DeepLab V3对图像进行分割(自定义数据集从零到一进行训练、验证和测试)
|
5月前
|
遥感语义分割数据集中的切图策略
该脚本用于遥感图像的切图处理,支持大尺寸图像按指定大小和步长切割为多个小图,适用于语义分割任务的数据预处理。通过设置剪裁尺寸(cs)和步长(ss),可灵活调整输出图像的数量和大小。此外,脚本还支持标签图像的转换,便于后续模型训练使用。
45 0
使用Python实现深度学习模型:图像语义分割与对象检测
【7月更文挑战第15天】 使用Python实现深度学习模型:图像语义分割与对象检测
139 2
将图像自动文本化,图像描述质量更高、更准确了
【7月更文挑战第11天】AI研究提升图像文本化准确性:新框架IT融合多模态大模型与视觉专家,生成详细无幻觉的图像描述。通过三个阶段—全局文本化、视觉细节提取和重描述,实现更高质量的图像转文本。研究人员建立DID-Bench、D2I-Bench和LIN-Bench基准,展示描述质量显著提升。尽管有进步,仍面临幻觉、细节缺失及大规模处理挑战。[论文链接](https://arxiv.org/pdf/2406.07502v1)**
61 1
TabR:检索增强能否让深度学习在表格数据上超过梯度增强模型?
这是一篇7月新发布的论文,他提出了使用自然语言处理的检索增强*Retrieval Augmented*技术,目的是让深度学习在表格数据上超过梯度增强模型。
168 0
MaskFormer:将语义分割和实例分割作为同一任务进行训练
目标检测和实例分割是计算机视觉的基本任务,在从自动驾驶到医学成像的无数应用中发挥着关键作用。目标检测的传统方法中通常利用边界框技术进行对象定位,然后利用逐像素分类为这些本地化实例分配类。但是当处理同一类的重叠对象时,或者在每个图像的对象数量不同的情况下,这些方法通常会出现问题。
4897 0