【单点知识】基于实例讲解PyTorch中的transforms类

简介: 【单点知识】基于实例讲解PyTorch中的transforms类

0. 前言

按照国际惯例,首先声明:本文只是我自己学习的理解,虽然参考了他人的宝贵见解及成果,但是内容可能存在不准确的地方。如果发现文中错误,希望批评指正,共同进步。


PyTorch框架中,torchvision.transforms 模块提供了一系列用于图像预处理数据增强的方法。这个模块主要用于对计算机视觉任务中的图像数据进行标准化、转换和增强操作,以满足深度学习模型训练和验证的需求。


本文将基于实例详细介绍torchvision.transforms 模块,实例介绍均使用下面400×300图像。


1. 基本用法

1.1 转换为Tensor
  • ToTensor(): 最基本的、也是必用的方法,将PIL Image或者numpy数组转换为Tensor,并将数据类型转换为float且范围调整至[0, 1]。
import torchvision
import PIL

img = PIL.Image.open('car.png')
transform = torchvision.transforms.ToTensor()

print(transform(img))
  • 输出:
tensor([[[0.1647, 0.1647, 0.1647,  ..., 0.2824, 0.2824, 0.2824],
         [0.1647, 0.1647, 0.1647,  ..., 0.2824, 0.2824, 0.2824],
         [0.1686, 0.1686, 0.1647,  ..., 0.2824, 0.2824, 0.2824],
         ...,
         [0.1647, 0.1647, 0.1647,  ..., 0.2039, 0.2000, 0.1961],
         [0.1608, 0.1608, 0.1608,  ..., 0.2000, 0.1961, 0.1961],
         [0.1569, 0.1569, 0.1569,  ..., 0.1961, 0.1922, 0.1922]],

        [[0.2078, 0.2078, 0.2078,  ..., 0.2627, 0.2627, 0.2627],
         [0.2078, 0.2078, 0.2078,  ..., 0.2627, 0.2627, 0.2627],
         [0.2078, 0.2078, 0.2078,  ..., 0.2627, 0.2627, 0.2627],
         ...,
         [0.2039, 0.2078, 0.2078,  ..., 0.2235, 0.2235, 0.2235],
         [0.2000, 0.2039, 0.2000,  ..., 0.2196, 0.2196, 0.2196],
         [0.2000, 0.1961, 0.1961,  ..., 0.2196, 0.2196, 0.2196]],

        [[0.2431, 0.2431, 0.2431,  ..., 0.2510, 0.2510, 0.2510],
         [0.2431, 0.2431, 0.2431,  ..., 0.2510, 0.2510, 0.2471],
         [0.2431, 0.2431, 0.2431,  ..., 0.2510, 0.2510, 0.2510],
         ...,
         [0.2392, 0.2392, 0.2392,  ..., 0.2392, 0.2392, 0.2392],
         [0.2353, 0.2353, 0.2353,  ..., 0.2353, 0.2392, 0.2392],
         [0.2314, 0.2314, 0.2314,  ..., 0.2353, 0.2353, 0.2353]],

        [[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
         ...,
         [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]]])

这里可能有人会疑惑,为什么输出的tensor有4个通道?这是因为原图格式为RGBA:<PIL.PngImagePlugin.PngImageFile image mode=RGBA size=400x300 at 0x25BC96CA950>,最后一个通道为透明度(全为1.0000的那个通道)。

1.2 图像大小调整
  • Resize(size): 调整图像到指定尺寸。
import torchvision
import PIL

img = PIL.Image.open('car.png')
transform = torchvision.transforms.Resize((100,200))

transform(img).show()
  • 输出:

这里需要注意.Resize()输入元组为(height, width)。

1.3 随机裁剪
  • RandomCrop(size): 随机裁剪图像为给定尺寸。
import torchvision
import PIL

img = PIL.Image.open('car.png')
transform = torchvision.transforms.RandomCrop((200,300))

transform(img).show()
  • 输出:

    会有多种随机输出,数据增强的主要手段之一。
1.4 中心裁剪
  • CenterCrop(size): 从图像中心裁剪出指定尺寸的区域。
import torchvision
import PIL

img = PIL.Image.open('car.png')
transform = torchvision.transforms.CenterCrop((200,200)) #也可以简写为transform = torchvision.transforms.CenterCrop(200)

transform(img).show()
  • 输出:

    不同于随机裁剪,中心裁剪只有一个确定的输出。
1.5 随机翻转
  • RandomHorizontalFlip(p): 水平方向上以概率p进行随机翻转。
import torchvision
import PIL

img = PIL.Image.open('car.png')
transform = torchvision.transforms.RandomHorizontalFlip(p=1)

transform(img).show()
  • 输出:
1.6 随机旋转
  • RandomRotation(degrees): 随机旋转图像。
import torchvision
import PIL

img = PIL.Image.open('car.png')
transform = torchvision.transforms.RandomRotation(90)

transform(img).show()
  • 输出:
1.7 填充
  • Pad(padding, fill, padding_mode): 在图像周围添加指定宽度的填充。
import torchvision
import PIL

img = PIL.Image.open('car.png')
transform = torchvision.transforms.Pad((100,100,200,200),padding_mode='edge')

transform(img).show()

Pad的参数说明:

  • padding:这是一个表示填充大小的元组。它可以是单个整数值(在所有边都应用相同的填充)或者一个包含四个整数的元组 (padding_left, padding_right, padding_top, padding_bottom)分别表示左、右、上、下的填充大小。
  • fill:填充像素的颜色值,默认为0,即黑色(对于灰度图和RGB图,分别代表灰度值和RGB三通道颜色)。这个值可以是整数(如0-255之间的数字)、浮点数(在归一化到[0, 1]范围的图像中使用)或者是元组(在RGB图像中,每个元素分别代表R、G、B通道的填充颜色)。
  • padding_mode:定义填充的方式,可选选项包括:
    - constant:用给定的常数值填充。
    - edge:复制图像边缘的像素值进行填充。
    - reflect:以镜像的方式从图像边缘反射像素来填充。
    - replicate:与’edge’类似,但是不考虑镜像对称,简单地重复最接近边界的像素值。
  • 输出:
1.8 组合变换
  • Compose(transforms): 将多个transform操作有序地组合在一起执行。
import torchvision
from torchvision.transforms import CenterCrop, RandomRotation, RandomHorizontalFlip
import PIL

img = PIL.Image.open('car.png')
transform = torchvision.transforms.Compose([CenterCrop(200),RandomRotation(90), RandomHorizontalFlip(p=1)])

transform(img).show()
  • 输出:

2. 进阶用法

2.1 归一化
  • Normalize(mean, std): 将图像按照指定均值和标准差进行归一化。具体的处理方法为normalized_image = (original_image - mean) / std
import PIL
from torchvision.transforms import Compose,Normalize,ToTensor

img = PIL.Image.open('car.png')
transform = Compose([ToTensor(),Normalize(mean=[0.1,0.1,0.1,0],std=[0.3,0.3,0.3,1])])
print(transform(img))
  • 输出:
tensor([[[0.2157, 0.2157, 0.2157,  ..., 0.6078, 0.6078, 0.6078],
         [0.2157, 0.2157, 0.2157,  ..., 0.6078, 0.6078, 0.6078],
         [0.2288, 0.2288, 0.2157,  ..., 0.6078, 0.6078, 0.6078],
         ...,
         [0.2157, 0.2157, 0.2157,  ..., 0.3464, 0.3333, 0.3203],
         [0.2026, 0.2026, 0.2026,  ..., 0.3333, 0.3203, 0.3203],
         [0.1895, 0.1895, 0.1895,  ..., 0.3203, 0.3072, 0.3072]],

        [[0.3595, 0.3595, 0.3595,  ..., 0.5425, 0.5425, 0.5425],
         [0.3595, 0.3595, 0.3595,  ..., 0.5425, 0.5425, 0.5425],
         [0.3595, 0.3595, 0.3595,  ..., 0.5425, 0.5425, 0.5425],
         ...,
         [0.3464, 0.3595, 0.3595,  ..., 0.4118, 0.4118, 0.4118],
         [0.3333, 0.3464, 0.3333,  ..., 0.3987, 0.3987, 0.3987],
         [0.3333, 0.3203, 0.3203,  ..., 0.3987, 0.3987, 0.3987]],

        [[0.4771, 0.4771, 0.4771,  ..., 0.5033, 0.5033, 0.5033],
         [0.4771, 0.4771, 0.4771,  ..., 0.5033, 0.5033, 0.4902],
         [0.4771, 0.4771, 0.4771,  ..., 0.5033, 0.5033, 0.5033],
         ...,
         [0.4641, 0.4641, 0.4641,  ..., 0.4641, 0.4641, 0.4641],
         [0.4510, 0.4510, 0.4510,  ..., 0.4510, 0.4641, 0.4641],
         [0.4379, 0.4379, 0.4379,  ..., 0.4510, 0.4510, 0.4510]],

        [[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
         ...,
         [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]]])

然后我们再把它重新转换回图像:

import PIL
from torchvision.transforms import Compose,Normalize,ToTensor, ToPILImage

img = PIL.Image.open('car.png')
transform = Compose([ToTensor(),Normalize(mean=[0.1,0.1,0.1,0],std=[0.3,0.3,0.3,1])])
tensor = transform(img)
pil_image = ToPILImage()
pil_image(tensor).show()
  • 输出图像为:
2.2 色彩空间转换
  • Grayscale(num_output_channels=1): 将图像转换为灰度图。
import PIL
import torchvision

img = PIL.Image.open('car.png')
transform = torchvision.transforms.Grayscale()
transform(img).show()
  • 输出:
2.3 颜色抖动
  • ColorJitter(brightness, contrast, saturation, hue) :用于对图像的颜色属性进行随机抖动,具体包括亮度、对比度、饱和度以及色调(hue)的变化。
import PIL
import torchvision

img = PIL.Image.open('car.png')
transform = torchvision.transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.6, hue=0.4)
transform(img).show()

ColorJitter的参数说明:

  • brightness: 亮度调整因子,默认值为0(不改变)。给定一个浮点数(例如0.2),会在 [max(1 - brightness, 0), 1 + brightness] 范围内随机选择一个比例来调整图像整体的亮度。
  • contrast: 对比度调整因子,默认值也为0(不改变)。给定一个浮点数(例如0.2),会在相应范围内随机选择一个比例来调整图像的整体对比度。对比度的变化会影响图像中所有像素的相对亮度差异。
  • saturation: 饱和度调整因子,同样默认0。例如0.6,意味着饱和度将在原图的基础上乘以一个范围在 [1 - saturation, 1 + saturation] 内的随机系数。饱和度越高,颜色越鲜艳;反之则趋向于灰色调。
  • hue: 色调调整因子,默认也是0(即不改变色调)。当设置为非零值如0.4时,会随机改变图像的色调(色彩的色相角度)。这对于模拟光照条件变化或色彩偏移非常有用。
  • 因为ColorJitter的调整参数是根据输入随机选择,因此输出也不唯一:

2.4 随机仿射
  • RandomAffine(degrees, translate, scale, shear):它实现了对输入图像进行随机的仿射变换,包括旋转、缩放、剪切和平移等操作。与前面介绍的方法有部分重复,不再详细说明。这里仅说明其参数:
  • degrees: 表示图像随机旋转的角度范围,可以是单个数值表示固定角度或者一个元组来指定随机选择的角度区间。
  • translate: 指定水平和垂直方向上的随机平移幅度,以图像宽度或高度的百分比形式给出。
  • scale: 指定随机缩放的比例范围,输入的是一个包含最小和最大缩放因子的元组。
  • shear: 控制图像在两个坐标轴之间的随机剪切角度范围。
2.5 透视变换
  • RandomPerspective(distortion_scale, p, interpolation, fill): 透视变换能够模拟相机位置、视角或物体距离变化导致的三维空间到二维图像投影的变化,从而增加模型对这类几何变换的鲁棒性。
import PIL
import torchvision

img = PIL.Image.open('car.png')
transform = torchvision.transforms.RandomPerspective(distortion_scale=0.5, p=1.0, interpolation=2, fill=0)
transform(img).show()

RandomPerspective参数说明:

  • distortion_scale:控制透视变换的强度,数值越大,图像扭曲程度越强。
  • p:概率参数,表示该变换应用于每个样本的概率,默认值是1.0,即总会应用透视变换。
  • interpolation:插值方式,用于确定如何从原始像素生成新像素。默认是2,对应于 PIL.Image.BILINEAR 双线性插值。
  • fill:当图像边界因变换而扩大时填充的颜色,默认是0。

透视变换的具体效果会随机产生,并且不会改变图像的尺寸大小,但可能会造成图像某些部分的拉伸、压缩或者移位。输出为:

2.6 自定义变换
  • transforms.Lambda():允许用户直接传入一个函数作为变换操作。
import torchvision

img = PIL.Image.open('car.png')

def image_operation(image):
    return torchvision.transforms.ToTensor()(image)**0.5  #对像素值进行0.5次方

transform = torchvision.transforms.Lambda(image_operation)
torchvision.transforms.ToPILImage()(transform(img)).show()
  • 输出:

以上罗列了torchvision.transforms的各种应用方法,实际使用时应根据具体问题的需求灵活选择和组合这些变换方法。


相关文章
|
8月前
|
机器学习/深度学习 PyTorch 算法框架/工具
【单点知识】基于实例详解PyTorch中的DataLoader类
【单点知识】基于实例详解PyTorch中的DataLoader类
700 2
|
8月前
|
机器学习/深度学习 算法 PyTorch
【PyTorch实战演练】自调整学习率实例应用(附代码)
【PyTorch实战演练】自调整学习率实例应用(附代码)
261 0
|
8月前
|
PyTorch 算法框架/工具
Pytorch中最大池化层Maxpool的作用说明及实例使用(附代码)
Pytorch中最大池化层Maxpool的作用说明及实例使用(附代码)
767 0
|
数据采集 PyTorch 数据处理
Pytorch学习笔记(3):图像的预处理(transforms)
Pytorch学习笔记(3):图像的预处理(transforms)
1522 1
Pytorch学习笔记(3):图像的预处理(transforms)
|
8月前
|
机器学习/深度学习 PyTorch 算法框架/工具
PyTorch基础之网络模块torch.nn中函数和模板类的使用详解(附源码)
PyTorch基础之网络模块torch.nn中函数和模板类的使用详解(附源码)
700 0
|
8月前
|
机器学习/深度学习 PyTorch 算法框架/工具
基于Pytorch通过实例详细剖析CNN
基于Pytorch通过实例详细剖析CNN
88 1
基于Pytorch通过实例详细剖析CNN
|
8月前
|
机器学习/深度学习 算法 大数据
基于PyTorch对凸函数采用SGD算法优化实例(附源码)
基于PyTorch对凸函数采用SGD算法优化实例(附源码)
119 3
|
8月前
|
机器学习/深度学习 算法 PyTorch
基于Pytorch用GAN生成手写数字实例(附代码)
基于Pytorch用GAN生成手写数字实例(附代码)
206 0
|
8月前
|
机器学习/深度学习 算法 PyTorch
基于Pytorch的机器学习Regression问题实例(附源码)
基于Pytorch的机器学习Regression问题实例(附源码)
98 1
|
8月前
|
机器学习/深度学习 自然语言处理 算法
PyTorch实例:简单线性回归的训练和反向传播解析
PyTorch实例:简单线性回归的训练和反向传播解析
PyTorch实例:简单线性回归的训练和反向传播解析