pytorch中的transforms.ToTensor和transforms.Normalize理解

简介: pytorch中的transforms.ToTensor和transforms.Normalize理解

pytorch中的transforms.ToTensor和transforms.Normalize理解🌴

transforms.ToTensor🌵

最近看pytorch时,遇到了对图像数据的归一化,如下图所示:1a897537d91ae638717214c5c592b59c.png

该怎么理解这串代码呢?我们一句一句的来看,先看transforms.ToTensor(),我们可以先转到官方给的定义,如下图所示:f3b92dc81868306ba1095add6dcee906.png

大概的意思就是说,transforms.ToTensor()可以将PIL和numpy格式的数据从[0,255]范围转换到[0,1] ,具体做法其实就是将原始数据除以255。另外原始数据的shape是(H x W x C),通过transforms.ToTensor()后shape会变为(C x H x W)。这样说我觉得大家应该也是能理解的,这部分并不难,但想着还是用一些例子来加深大家的映像🌽🌽🌽

  • 先导入一些包
import cv2
import numpy as np
import torch
from torchvision import transforms
  • 定义一个数组模型图片,注意数组数据类型需要时np.uint8【官方图示中给出】
data = np.array([
                [[1,1,1],[1,1,1],[1,1,1],[1,1,1],[1,1,1]],
                [[2,2,2],[2,2,2],[2,2,2],[2,2,2],[2,2,2]],
                [[3,3,3],[3,3,3],[3,3,3],[3,3,3],[3,3,3]],
                [[4,4,4],[4,4,4],[4,4,4],[4,4,4],[4,4,4]],
                [[5,5,5],[5,5,5],[5,5,5],[5,5,5],[5,5,5]]
        ],dtype='uint8')

这是可以看看data的shape,注意现在为(W H C)。

image.png


  • 使用transforms.ToTensor()将data进行转换
data = transforms.ToTensor()(data)

这时候我们来看看data中的数据及shape。

6669d9bd03b14903a23d884717bd17b4.png

很明显,数据现在都映射到了[0, 1]之间,并且data的shape发生了变换。

**注意:不知道大家是如何理解三维数组的,这里提供我的一个方法。**🥝🥝🥝


🌼原始的data的shape为(5,5,3),则其表示有5个(5 , 3)的二维数组,即我们把最外层的[]去掉就得到了5个五行三列的数据。


🌼同样的,变换后data的shape为(3,5,5),则其表示有3个(5 , 5)的二维数组,即我们把最外层的[]去掉就得到了3个五行五列的数据。


transforms.Normalize🌵

  相信通过前面的叙述大家应该对transforms.ToTensor有了一定的了解,下面将来说说这个transforms.Normalize🍹🍹🍹同样的,我们先给出官方的定义,如下图所示:308dd2831aa3c5376b60b927a958026c.png

可以看到这个函数的输出output[channel] = (input[channel] - mean[channel]) / std[channel]。这里[channel]的意思是指对特征图的每个通道都进行这样的操作。【mean为均值,std为标准差】接下来我们看第一张图片中的代码,即c9ac9390e825a69f6af3763baac8852e.png

这里的第一个参数(0.5,0.5,0.5)表示每个通道的均值都是0.5,第二个参数(0.5,0.5,0.5)表示每个通道的方差都为0.5。【因为图像一般是三个通道,所以这里的向量都是1x3的🍵🍵🍵】有了这两个参数后,当我们传入一个图像时,就会按照上面的公式对图像进行变换。【注意:这里说图像其实也不够准确,因为这个函数传入的格式不能为PIL Image,我们应该先将其转换为Tensor格式】

说了这么多,那么这个函数到底有什么用呢?我们通过前面的ToTensor已经将数据归一化到了0-1之间,现在又接上了一个Normalize函数有什么用呢?其实Normalize函数做的是将数据变换到了[-1,1]之间。之前的数据为0-1,当取0时,output =(0 - 0.5)/ 0.5 = -1;当取1时,output =(1 - 0.5)/ 0.5 = 1。这样就把数据统一到了[-1,1]之间了🌱🌱🌱那么问题又来了,数据统一到[-1,1]有什么好处呢?数据如果分布在(0,1)之间,可能实际的bias,就是神经网络的输入b会比较大,而模型初始化时b=0的,这样会导致神经网络收敛比较慢,经过Normalize后,可以加快模型的收敛速度。【这句话是再网络上找到最多的解释,自己也不确定其正确性】

读到这里大家是不是以为就完了呢?这里还想和大家唠上一唠🍓🍓🍓上面的两个参数(0.5,0.5,0.5)是怎么得来的呢?这是根据数据集中的数据计算出的均值和标准差,所以往往不同的数据集这两个值是不同的🍏🍏🍏这里再举一个例子帮助大家理解其计算过程。同样采用上文例子中提到的数据。

  • 上文已经得到了经ToTensor转换后的数据,现需要求出该数据每个通道的mean和std。【这一部分建议大家自己运行看看每一步的结果🌵🌵🌵】
# 需要对数据进行扩维,增加batch维度
data = torch.unsqueeze(data,0)    #在pytorch中一般都是(batch,C,H,W)
nb_samples = 0.
#创建3维的空列表
channel_mean = torch.zeros(3)
channel_std = torch.zeros(3)
N, C, H, W = data.shape[:4]
data = data.view(N, C, -1)  #将数据的H,W合并
#展平后,w,h属于第2维度,对他们求平均,sum(0)为将同一纬度的数据累加
channel_mean += data.mean(2).sum(0)  
#展平后,w,h属于第2维度,对他们求标准差,sum(0)为将同一纬度的数据累加
channel_std += data.std(2).sum(0)
#获取所有batch的数据,这里为1
nb_samples += N
#获取同一batch的均值和标准差
channel_mean /= nb_samples
channel_std /= nb_samples
print(channel_mean, channel_std)   #结果为tensor([0.0118, 0.0118, 0.0118]) tensor([0.0057, 0.0057, 0.0057])
  • 将上述得到的mean和std带入公式,计算输出。
for i in range(3):
    data[i] = (data[i] - channel_mean[i]) / channel_std[i]
print(data)

输出结果:

f94e003694ee4c5af69a26cdad55144e.png

从结果可以看出,我们计算的mean和std并不是0.5,且最后的结果也没有在[-1,1]之间。


最后我们再来看一个有意思的例子,我们得到了最终的结果,要是我们想要变回去怎么办,其实很简单啦,就是一个逆运算,即input = std*output + mean,然后再乘上255就可以得到原始的结果了。很多人获取吐槽了,这也叫有趣!!??哈哈哈这里我想说的是另外的一个事,如果我们对一张图像进行了归一化,这时候你用归一化后的数据显示这张图像的时候,会发现同样会是原图。感兴趣的大家可以去试试🥗🥗🥗🥗这里给出一个参考链接:https://blog.csdn.net/xjp_xujiping/article/details/102981117


参考链接1:https://zhuanlan.zhihu.com/p/414242338

参考链接2:https://blog.csdn.net/peacefairy/article/details/108020179






相关文章
|
机器学习/深度学习 编解码
ICCV 2023 超分辨率(Super-Resolution)论文汇总
ICCV 2023 超分辨率(Super-Resolution)论文汇总
835 0
|
数据采集 PyTorch 数据处理
Pytorch学习笔记(3):图像的预处理(transforms)
Pytorch学习笔记(3):图像的预处理(transforms)
2184 1
Pytorch学习笔记(3):图像的预处理(transforms)
|
Rust 算法 Go
【密码学】一文读懂MurMurHash3
本文应该是MurMurHash算法介绍的最后一篇,来一起看一下最新的MurMurHash算法的具体过程,对于最新的算法来说,整个流程和之前的其实也比较相似,这里从维基百科当中找到了伪代码,也就不贴出来Google官方给出的推荐代码了,先来看一下维基百科给出的伪代码,这里只有32位的伪代码。
2894 0
【密码学】一文读懂MurMurHash3
|
机器学习/深度学习 人工智能 自然语言处理
视觉 注意力机制——通道注意力、空间注意力、自注意力
本文介绍注意力机制的概念和基本原理,并站在计算机视觉CV角度,进一步介绍通道注意力、空间注意力、混合注意力、自注意力等。
12650 58
|
机器学习/深度学习 算法 计算机视觉
深度学习目标检测系列:一文弄懂YOLO算法|附Python源码
本文是目标检测系列文章——YOLO算法,介绍其基本原理及实现细节,并用python实现,方便读者上手体验目标检测的乐趣。
53163 0
|
算法 计算机视觉
Opencv学习笔记(六):cv2.resize函数的介绍
这篇文章介绍了OpenCV库中cv2.resize函数的使用方法,包括其参数、插值方式选择以及实际代码示例。
1947 1
Opencv学习笔记(六):cv2.resize函数的介绍
|
并行计算 异构计算
卸载原有的cuda,更新cuda
本文提供了一个更新CUDA版本的详细指南,包括如何查看当前CUDA版本、检查可安装的CUDA版本、卸载旧版本CUDA以及安装新版本的CUDA。
10520 3
卸载原有的cuda,更新cuda
|
机器学习/深度学习 PyTorch 算法框架/工具
Pytorch学习笔记(八):nn.ModuleList和nn.Sequential函数详解
PyTorch中的nn.ModuleList和nn.Sequential函数,包括它们的语法格式、参数解释和具体代码示例,展示了如何使用这些函数来构建和管理神经网络模型。
1977 1
|
机器学习/深度学习 人工智能 文字识别
ultralytics YOLO11 全新发布!(原理介绍+代码详见+结构框图)
本文详细介绍YOLO11,包括其全新特性、代码实现及结构框图,并提供如何使用NEU-DET数据集进行训练的指南。YOLO11在前代基础上引入了新功能和改进,如C3k2、C2PSA模块和更轻量级的分类检测头,显著提升了模型的性能和灵活性。文中还对比了YOLO11与YOLOv8的区别,并展示了训练过程和结果的可视化
19406 0
|
机器学习/深度学习 并行计算 PyTorch
PyTorch与DistributedDataParallel:分布式训练入门指南
【8月更文第27天】随着深度学习模型变得越来越复杂,单一GPU已经无法满足训练大规模模型的需求。分布式训练成为了加速模型训练的关键技术之一。PyTorch 提供了多种工具来支持分布式训练,其中 DistributedDataParallel (DDP) 是一个非常受欢迎且易用的选择。本文将详细介绍如何使用 PyTorch 的 DDP 模块来进行分布式训练,并通过一个简单的示例来演示其使用方法。
2066 2