语义分割数据增强
常见的数据增强方式
查看pytorch
torchvision的transformer中的源代码,我们可以看到具有以下数据增强方式:
__all__ = ["Compose", "ToTensor", "PILToTensor", "ConvertImageDtype", "ToPILImage", "Normalize", "Resize", "Scale", "CenterCrop", "Pad", "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", "RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale", "RandomPerspective", "RandomErasing", "GaussianBlur", "InterpolationMode", "RandomInvert", "RandomPosterize", "RandomSolarize", "RandomAdjustSharpness", "RandomAutocontrast", "RandomEqualize"]
其中常见的数据增强方式包括:旋转、垂直翻转、水平翻转、放缩、剪裁、归一化等。
语义分割和图像分类的数据增强差异在于:语义分割是对图像的每个像素进行分类,所以在进行某些数据增强时,需要对标注图像(mask)进行同步操作,如旋转、剪裁、翻转等。
查遍网上的一些教程,但是没有发现一个能够直接使用的pytorch数据增强方式,所以想自己写一个方便后续使用。
具体实现代码
我们这里以一个细胞语义分割数据集为例,由于该数据集是灰度图像,所以相对于彩色图像数据增强有一些差距,代码中注释了灰度图像不能使用的数据增强方式,但是彩色图像可以使用的数据增强方式。具体代码如下所示:
import numpy as np import cv2 import torch from torch.utils.data import Dataset import os from PIL import Image from torchvision.transforms import functional as F import random class CellDataset(Dataset): def __init__(self, image_dir, mask_dir, names_list, image_size=224, isGray=False, augmentation=True): self.image_dir = image_dir self.mask_dir = mask_dir self.augmentation = augmentation self.names_list = names_list self.isGray = isGray self.image_size = image_size def __len__(self): return len(self.names_list) def augmentate(self, image, mask): # it is expected to be in [..., H, W] format image = torch.unsqueeze(torch.from_numpy( np.array(image, dtype=np.uint8)), dim=0) mask = torch.unsqueeze(torch.from_numpy( np.array(mask, dtype=np.uint8)), dim=0) image = F.resize(image, size=[self.image_size, self.image_size]) mask = F.resize(mask, size=[self.image_size, self.image_size]) # 彩色图可以进行以下数据增强,参数不太好调整 image = F.adjust_gamma(image, gamma=random.uniform(0.8, 1.2)) image = F.adjust_contrast( image, contrast_factor=random.uniform(0.8, 1.2)) image = F.adjust_brightness( image, brightness_factor=random.uniform(0.8, 1.2)) image = F.adjust_saturation( image, saturation_factor=random.uniform(0.8, 1.2)) image = F.adjust_hue(image, hue_factor=random.uniform(-0.2, 0.2)) # 让image和mask进行同步旋转和翻转数据增强 image_mask = torch.cat([image, mask], dim=0) if random.uniform(0, 1) > 0.5: image_mask = F.hflip(image_mask) if random.uniform(0, 1) > 0.5: image_mask = F.vflip(image_mask) if random.uniform(0, 1) > 0.5: image_mask = F.rotate(image_mask, angle=90) # 要看image和mask的维度 image = image_mask[0, ...] mask = image_mask[1, ...] # image = image / 255 # mask = mask / 255 # image = torch.unsqueeze(image, dim=0) # 标准化,彩色图像需要传三个值 # image = F.normalize(image, mean=[0.5], std=[1]) # mask = torch.unsqueeze(mask, dim=0) return image, mask def __getitem__(self, item): image_path = os.path.join(self.image_dir, self.names_list[item]) mask_path = os.path.join(self.mask_dir, self.names_list[item]) image = Image.open(image_path) if self.isGray: image = image.convert('L') mask = Image.open(mask_path) if self.augmentation: image, mask = self.augmentate(image, mask) return image, mask if __name__ == '__main__': cell_dataset = CellDataset(image_dir='./data/image', mask_dir='./data/label', names_list=['0.png']) index = 3 for image, mask in cell_dataset: print(image.shape, mask.shape) print(torch.max(image), torch.min(image)) image = np.array(image, dtype=np.uint8) mask = np.array(mask, dtype=np.uint8) # cv2.imshow('image', image) # cv2.imshow('mask', mask) # cv2.waitKey(0) cv2.imwrite(os.path.join( './data/augment/image', str(index)+'.png'), image) cv2.imwrite(os.path.join( './data/augment/label', str(index)+'.png'), mask)
我们以下图图像和图像标注掩码为例进行实验:
原始细胞图像
原始细胞掩码标签图像
我们更该index的值,进行重复执行程序,来生成多个不同对应的数据增强图像,我们重复执行了4次,得到了以下数据增强的细胞图像和对应的掩码标签图像。
如果有问题可以在评论区进行回复。如果对您有帮助的话可以帮忙点赞👍👍👍。