计算机视觉PyTorch迁移学习 - (一)

简介: 计算机视觉PyTorch迁移学习 - (一)

如何在只有6万张图像的MNIST训练数据集上训练模型。学术界当下使用最广泛的大规模图像数据集ImageNet,它有超过1,000万的图像和1,000类的物体。然而,我们平常接触到数据集的规模通常在这两者之间。假设我们想从图像中识别出不同种类的椅子,然后将购买链接推荐给用户。一种可能的方法是先找出100种常见的椅子,为每种椅子拍摄1,000张不同角度的图像,然后在收集到的图像数据集上训练一个分类模型。另外一种解决办法是应用迁移学习(transfer learning),将从源数据集学到的知识迁移到目标数据集上。例如,虽然ImageNet数据集的图像大多跟椅子无关,但在该数据集上训练的模型可以抽取较通用的图像特征,从而能够帮助识别边缘、纹理、形状和物体组成等。这些类似的特征对于识别椅子也可能同样有效


1.迁移学习原理与流程


图像迁移学习一共分两类:


1.1微调


选择使用Imagenet数据集训练好的模型,更新模型中所有参数


在源数据集(如ImageNet数据集)上预训练一个神经网络模型,即源模型。

创建一个新的神经网络模型,即目标模型。它复制了源模型上除了输出层外的所有模型设计及其参数。我们假设这些模型参数包含了源数据集上学习到的知识,且这些知识同样适用于目标数据集。我们还假设源模型的输出层跟源数据集的标签紧密相关,因此在目标模型中不予采用。

为目标模型添加一个输出大小为目标数据集类别个数的输出层,并随机初始化该层的模型参数。

在目标数据集(如椅子数据集)上训练目标模型。我们将从头训练输出层,而其余层的参数都是基于源模型的参数微调得到的。

3573200247134a9183b2e36947dff546.png

  1. 当目标数据集远小于源数据集时,微调有助于提升模型的泛化能力

1.2特征提取


选择使用Imagenet数据集训练好的模型,更新模型预测的最后一层的参数

流程与微调相似,只是更新参数的层不同


2.图像增强


图像增强(image augmentation)指通过剪切、旋转/反射/翻转变换、缩放变换、平移变换、尺度变换、对比度变换、噪声扰动、颜色变换等一种或多种组合数据增强变换的方式来增加数据集的大小。图像增强的意义是通过对训练图像做一系列随机改变,来产生相似但又不同的训练样本,从而扩大训练数据集的规模,而且随机改变训练样本可以降低模型对某些属性的依赖,从而提高模型的泛化能力

原始图像


import matplotlib.pyplot as plt
import cv2 as cv
from torchvision import transforms as transforms
image="data.jpg"
img=cv.imread(image)
b,g,r=cv.split(img)
img=cv.merge([r,g,b])
# 用来正常显示中文标签
plt.rcParams['font.sans-serif'] = ['Arial Unicode MS'] # 指定默认字体
plt.rcParams['axes.unicode_minus'] = False
plt.title('天气之子')
plt.axis('off')
plt.imshow(img)
plt.show()

da4b1873662f4e08ad5b39da39c4d442.png

2.1比例缩放


import matplotlib.pyplot as plt
from torchvision import transforms as transforms
import PIL.Image as Image
image="data.jpg"
img=Image.open(image)
# 用来正常显示中文标签
plt.rcParams['font.sans-serif'] = ['Arial Unicode MS'] # 指定默认字体
plt.rcParams['axes.unicode_minus'] = False
resize = transforms.Resize([125,125])
img = resize(img)
plt.title('天气之子')
plt.axis('off')
plt.imshow(img)
plt.show()

41f74eb5b5784d15ae8f18a86c106a1b.png


2.2位置裁剪


import matplotlib.pyplot as plt
from torchvision import transforms as transforms
import PIL.Image as Image
image="data.jpg"
img=Image.open(image)
# 用来正常显示中文标签
plt.rcParams['font.sans-serif'] = ['Arial Unicode MS'] # 指定默认字体
plt.rcParams['axes.unicode_minus'] = False
crop = transforms.RandomCrop([100,100])
img = crop(img)
plt.title('天气之子')
plt.axis('off')
plt.imshow(img)
plt.show()


04c188cf754248e2b9a3b3eb0e23edf5.png


2.3水平/垂直翻转


import matplotlib.pyplot as plt
from torchvision import transforms as transforms
import PIL.Image as Image
image="data.jpg"
img=Image.open(image)
# 用来正常显示中文标签
plt.rcParams['font.sans-serif'] = ['Arial Unicode MS'] # 指定默认字体
plt.rcParams['axes.unicode_minus'] = False
HF = transforms.RandomHorizontalFlip()
imgHF = HF(img)
VF = transforms.RandomVerticalFlip()
imgVF = VF(img)
title=['水平','垂直']
img=[imgHF,imgVF]
for i in range(2):
        plt.subplot(1, 2, i + 1), plt.imshow(img[i], 'gray')
        plt.title(title[i])
        plt.xticks([]), plt.yticks([])
plt.show()


c434335f64334b11898a51618f2e0c95.png

2.4角度旋转


import matplotlib.pyplot as plt
from torchvision import transforms as transforms
import PIL.Image as Image
image="data.jpg"
img=Image.open(image)
# 用来正常显示中文标签
plt.rcParams['font.sans-serif'] = ['Arial Unicode MS'] # 指定默认字体
plt.rcParams['axes.unicode_minus'] = False
rotation = transforms.RandomRotation(45)
img = rotation(img)
plt.title('天气之子')
plt.axis('off')
plt.imshow(img)
plt.show()

107d0ca14f324e6ab9aaa466f3fb61d1.png


2.5色度、亮度、饱和度、对比度

import matplotlib.pyplot as plt
from torchvision import transforms as transforms
import PIL.Image as Image
image="data.jpg"
img=Image.open(image)
# 用来正常显示中文标签
plt.rcParams['font.sans-serif'] = ['Arial Unicode MS'] # 指定默认字体
plt.rcParams['axes.unicode_minus'] = False
#色度
transform_1=transforms.ColorJitter(brightness=1)
img_1=transform_1(img)
#亮度
transform_2=transforms.ColorJitter(contrast=1)
img_2=transform_2(img)
#饱和度
transform_3=transforms.ColorJitter(saturation=0.5)
img_3=transform_3(img)
#对比度
transform_4=transforms.ColorJitter(hue=0.5)
img_4=transform_4(img)
title=['色度','亮度','饱和度','对比度']
img=[img_1,img_2,img_3,img_4]
for i in range(2):
        plt.subplot(1, 2, i + 1), plt.imshow(img[i], 'gray')
        plt.title(title[i])
        plt.xticks([]), plt.yticks([])
for i in range(2):
        plt.subplot(2, 2, i + 3), plt.imshow(img[i+2], 'gray')
        plt.title(title[i+2])
        plt.xticks([]), plt.yticks([])        
plt.show()


8d7825bd24b34ed489bcf3155167488c.png

2.6灰度化


import matplotlib.pyplot as plt
from torchvision import transforms as transforms
import PIL.Image as Image
image="data.jpg"
img=Image.open(image)
# 用来正常显示中文标签
plt.rcParams['font.sans-serif'] = ['Arial Unicode MS'] # 指定默认字体
plt.rcParams['axes.unicode_minus'] = False
gray = transforms.RandomGrayscale(p=0.5)
img = gray(img)
plt.title('天气之子')
plt.axis('off')
plt.imshow(img)
plt.show()

df6bf2d3ec464d1e9993aabaad25e832.png

2.7Padding


import matplotlib.pyplot as plt
from torchvision import transforms as transforms
import PIL.Image as Image
image="data.jpg"
img=Image.open(image)
# 用来正常显示中文标签
plt.rcParams['font.sans-serif'] = ['Arial Unicode MS'] # 指定默认字体
plt.rcParams['axes.unicode_minus'] = False
pad = transforms.Pad((0,(img.size[0]-img.size[1])//2))
img = pad(img)
plt.title('天气之子')
plt.axis('off')
plt.imshow(img)
plt.show()

9540a694731d4de0b6639ed003c375db.png

2.8模型中图像增强数据预处理


datatrian=transforms.Compose([
                    transforms.RandomHorizontalFlip(),
                    transforms.Resize(image_size),
                    transforms.CenterCrop(image_size),
                    transforms.ToTensor(),
                    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])


相关文章
|
3月前
|
PyTorch Linux 算法框架/工具
pytorch学习一:Anaconda下载、安装、配置环境变量。anaconda创建多版本python环境。安装 pytorch。
这篇文章是关于如何使用Anaconda进行Python环境管理,包括下载、安装、配置环境变量、创建多版本Python环境、安装PyTorch以及使用Jupyter Notebook的详细指南。
413 1
pytorch学习一:Anaconda下载、安装、配置环境变量。anaconda创建多版本python环境。安装 pytorch。
|
7月前
|
机器学习/深度学习 自然语言处理 算法
【从零开始学习深度学习】49.Pytorch_NLP项目实战:文本情感分类---使用循环神经网络RNN
【从零开始学习深度学习】49.Pytorch_NLP项目实战:文本情感分类---使用循环神经网络RNN
|
3月前
|
机器学习/深度学习 缓存 PyTorch
pytorch学习一(扩展篇):miniconda下载、安装、配置环境变量。miniconda创建多版本python环境。整理常用命令(亲测ok)
这篇文章是关于如何下载、安装和配置Miniconda,以及如何使用Miniconda创建和管理Python环境的详细指南。
687 0
pytorch学习一(扩展篇):miniconda下载、安装、配置环境变量。miniconda创建多版本python环境。整理常用命令(亲测ok)
|
3月前
|
机器学习/深度学习 PyTorch 算法框架/工具
聊一聊计算机视觉中常用的注意力机制以及Pytorch代码实现
本文介绍了几种常用的计算机视觉注意力机制及其PyTorch实现,包括SENet、CBAM、BAM、ECA-Net、SA-Net、Polarized Self-Attention、Spatial Group-wise Enhance和Coordinate Attention等,每种方法都附有详细的网络结构说明和实验结果分析。通过这些注意力机制的应用,可以有效提升模型在目标检测任务上的性能。此外,作者还提供了实验数据集的基本情况及baseline模型的选择与实验结果,方便读者理解和复现。
136 0
聊一聊计算机视觉中常用的注意力机制以及Pytorch代码实现
|
7月前
|
机器学习/深度学习 算法 PyTorch
【从零开始学习深度学习】43. 算法优化之Adam算法【RMSProp算法与动量法的结合】介绍及其Pytorch实现
【从零开始学习深度学习】43. 算法优化之Adam算法【RMSProp算法与动量法的结合】介绍及其Pytorch实现
|
3月前
|
机器学习/深度学习 人工智能 TensorFlow
浅谈计算机视觉新手的学习路径
浅谈计算机视觉新手的学习路径
33 0
|
5月前
|
存储 PyTorch API
Pytorch入门—Tensors张量的学习
Pytorch入门—Tensors张量的学习
44 0
|
7月前
|
机器学习/深度学习 PyTorch 算法框架/工具
【从零开始学习深度学习】47. Pytorch图片样式迁移实战:将一张图片样式迁移至另一张图片,创作自己喜欢风格的图片【含完整源码】
【从零开始学习深度学习】47. Pytorch图片样式迁移实战:将一张图片样式迁移至另一张图片,创作自己喜欢风格的图片【含完整源码】
|
7月前
|
机器学习/深度学习 资源调度 PyTorch
【从零开始学习深度学习】15. Pytorch实战Kaggle比赛:房价预测案例【含数据集与源码】
【从零开始学习深度学习】15. Pytorch实战Kaggle比赛:房价预测案例【含数据集与源码】
|
7月前
|
机器学习/深度学习 算法 PyTorch
【从零开始学习深度学习】50.Pytorch_NLP项目实战:卷积神经网络textCNN在文本情感分类的运用
【从零开始学习深度学习】50.Pytorch_NLP项目实战:卷积神经网络textCNN在文本情感分类的运用

热门文章

最新文章