计算机视觉PyTorch迁移学习 - (一)

简介: 计算机视觉PyTorch迁移学习 - (一)

如何在只有6万张图像的MNIST训练数据集上训练模型。学术界当下使用最广泛的大规模图像数据集ImageNet,它有超过1,000万的图像和1,000类的物体。然而,我们平常接触到数据集的规模通常在这两者之间。假设我们想从图像中识别出不同种类的椅子,然后将购买链接推荐给用户。一种可能的方法是先找出100种常见的椅子,为每种椅子拍摄1,000张不同角度的图像,然后在收集到的图像数据集上训练一个分类模型。另外一种解决办法是应用迁移学习(transfer learning),将从源数据集学到的知识迁移到目标数据集上。例如,虽然ImageNet数据集的图像大多跟椅子无关,但在该数据集上训练的模型可以抽取较通用的图像特征,从而能够帮助识别边缘、纹理、形状和物体组成等。这些类似的特征对于识别椅子也可能同样有效


1.迁移学习原理与流程


图像迁移学习一共分两类:


1.1微调


选择使用Imagenet数据集训练好的模型,更新模型中所有参数


在源数据集(如ImageNet数据集)上预训练一个神经网络模型,即源模型。

创建一个新的神经网络模型,即目标模型。它复制了源模型上除了输出层外的所有模型设计及其参数。我们假设这些模型参数包含了源数据集上学习到的知识,且这些知识同样适用于目标数据集。我们还假设源模型的输出层跟源数据集的标签紧密相关,因此在目标模型中不予采用。

为目标模型添加一个输出大小为目标数据集类别个数的输出层,并随机初始化该层的模型参数。

在目标数据集(如椅子数据集)上训练目标模型。我们将从头训练输出层,而其余层的参数都是基于源模型的参数微调得到的。

3573200247134a9183b2e36947dff546.png

  1. 当目标数据集远小于源数据集时,微调有助于提升模型的泛化能力

1.2特征提取


选择使用Imagenet数据集训练好的模型,更新模型预测的最后一层的参数

流程与微调相似,只是更新参数的层不同


2.图像增强


图像增强(image augmentation)指通过剪切、旋转/反射/翻转变换、缩放变换、平移变换、尺度变换、对比度变换、噪声扰动、颜色变换等一种或多种组合数据增强变换的方式来增加数据集的大小。图像增强的意义是通过对训练图像做一系列随机改变,来产生相似但又不同的训练样本,从而扩大训练数据集的规模,而且随机改变训练样本可以降低模型对某些属性的依赖,从而提高模型的泛化能力

原始图像


import matplotlib.pyplot as plt
import cv2 as cv
from torchvision import transforms as transforms
image="data.jpg"
img=cv.imread(image)
b,g,r=cv.split(img)
img=cv.merge([r,g,b])
# 用来正常显示中文标签
plt.rcParams['font.sans-serif'] = ['Arial Unicode MS'] # 指定默认字体
plt.rcParams['axes.unicode_minus'] = False
plt.title('天气之子')
plt.axis('off')
plt.imshow(img)
plt.show()

da4b1873662f4e08ad5b39da39c4d442.png

2.1比例缩放


import matplotlib.pyplot as plt
from torchvision import transforms as transforms
import PIL.Image as Image
image="data.jpg"
img=Image.open(image)
# 用来正常显示中文标签
plt.rcParams['font.sans-serif'] = ['Arial Unicode MS'] # 指定默认字体
plt.rcParams['axes.unicode_minus'] = False
resize = transforms.Resize([125,125])
img = resize(img)
plt.title('天气之子')
plt.axis('off')
plt.imshow(img)
plt.show()

41f74eb5b5784d15ae8f18a86c106a1b.png


2.2位置裁剪


import matplotlib.pyplot as plt
from torchvision import transforms as transforms
import PIL.Image as Image
image="data.jpg"
img=Image.open(image)
# 用来正常显示中文标签
plt.rcParams['font.sans-serif'] = ['Arial Unicode MS'] # 指定默认字体
plt.rcParams['axes.unicode_minus'] = False
crop = transforms.RandomCrop([100,100])
img = crop(img)
plt.title('天气之子')
plt.axis('off')
plt.imshow(img)
plt.show()


04c188cf754248e2b9a3b3eb0e23edf5.png


2.3水平/垂直翻转


import matplotlib.pyplot as plt
from torchvision import transforms as transforms
import PIL.Image as Image
image="data.jpg"
img=Image.open(image)
# 用来正常显示中文标签
plt.rcParams['font.sans-serif'] = ['Arial Unicode MS'] # 指定默认字体
plt.rcParams['axes.unicode_minus'] = False
HF = transforms.RandomHorizontalFlip()
imgHF = HF(img)
VF = transforms.RandomVerticalFlip()
imgVF = VF(img)
title=['水平','垂直']
img=[imgHF,imgVF]
for i in range(2):
        plt.subplot(1, 2, i + 1), plt.imshow(img[i], 'gray')
        plt.title(title[i])
        plt.xticks([]), plt.yticks([])
plt.show()


c434335f64334b11898a51618f2e0c95.png

2.4角度旋转


import matplotlib.pyplot as plt
from torchvision import transforms as transforms
import PIL.Image as Image
image="data.jpg"
img=Image.open(image)
# 用来正常显示中文标签
plt.rcParams['font.sans-serif'] = ['Arial Unicode MS'] # 指定默认字体
plt.rcParams['axes.unicode_minus'] = False
rotation = transforms.RandomRotation(45)
img = rotation(img)
plt.title('天气之子')
plt.axis('off')
plt.imshow(img)
plt.show()

107d0ca14f324e6ab9aaa466f3fb61d1.png


2.5色度、亮度、饱和度、对比度

import matplotlib.pyplot as plt
from torchvision import transforms as transforms
import PIL.Image as Image
image="data.jpg"
img=Image.open(image)
# 用来正常显示中文标签
plt.rcParams['font.sans-serif'] = ['Arial Unicode MS'] # 指定默认字体
plt.rcParams['axes.unicode_minus'] = False
#色度
transform_1=transforms.ColorJitter(brightness=1)
img_1=transform_1(img)
#亮度
transform_2=transforms.ColorJitter(contrast=1)
img_2=transform_2(img)
#饱和度
transform_3=transforms.ColorJitter(saturation=0.5)
img_3=transform_3(img)
#对比度
transform_4=transforms.ColorJitter(hue=0.5)
img_4=transform_4(img)
title=['色度','亮度','饱和度','对比度']
img=[img_1,img_2,img_3,img_4]
for i in range(2):
        plt.subplot(1, 2, i + 1), plt.imshow(img[i], 'gray')
        plt.title(title[i])
        plt.xticks([]), plt.yticks([])
for i in range(2):
        plt.subplot(2, 2, i + 3), plt.imshow(img[i+2], 'gray')
        plt.title(title[i+2])
        plt.xticks([]), plt.yticks([])        
plt.show()


8d7825bd24b34ed489bcf3155167488c.png

2.6灰度化


import matplotlib.pyplot as plt
from torchvision import transforms as transforms
import PIL.Image as Image
image="data.jpg"
img=Image.open(image)
# 用来正常显示中文标签
plt.rcParams['font.sans-serif'] = ['Arial Unicode MS'] # 指定默认字体
plt.rcParams['axes.unicode_minus'] = False
gray = transforms.RandomGrayscale(p=0.5)
img = gray(img)
plt.title('天气之子')
plt.axis('off')
plt.imshow(img)
plt.show()

df6bf2d3ec464d1e9993aabaad25e832.png

2.7Padding


import matplotlib.pyplot as plt
from torchvision import transforms as transforms
import PIL.Image as Image
image="data.jpg"
img=Image.open(image)
# 用来正常显示中文标签
plt.rcParams['font.sans-serif'] = ['Arial Unicode MS'] # 指定默认字体
plt.rcParams['axes.unicode_minus'] = False
pad = transforms.Pad((0,(img.size[0]-img.size[1])//2))
img = pad(img)
plt.title('天气之子')
plt.axis('off')
plt.imshow(img)
plt.show()

9540a694731d4de0b6639ed003c375db.png

2.8模型中图像增强数据预处理


datatrian=transforms.Compose([
                    transforms.RandomHorizontalFlip(),
                    transforms.Resize(image_size),
                    transforms.CenterCrop(image_size),
                    transforms.ToTensor(),
                    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])


相关文章
|
5天前
|
机器学习/深度学习 算法 PyTorch
【PyTorch实战演练】自调整学习率实例应用(附代码)
【PyTorch实战演练】自调整学习率实例应用(附代码)
72 0
|
5天前
|
机器学习/深度学习 存储 数据库
Python3 OpenCV4 计算机视觉学习手册:6~11(5)
Python3 OpenCV4 计算机视觉学习手册:6~11(5)
58 0
|
5天前
|
机器学习/深度学习 算法 数据可视化
计算机视觉+深度学习+机器学习+opencv+目标检测跟踪+一站式学习(代码+视频+PPT)-2
计算机视觉+深度学习+机器学习+opencv+目标检测跟踪+一站式学习(代码+视频+PPT)
107 0
|
9月前
|
机器学习/深度学习 缓存 监控
Pytorch学习笔记(7):优化器、学习率及调整策略、动量
Pytorch学习笔记(7):优化器、学习率及调整策略、动量
426 0
Pytorch学习笔记(7):优化器、学习率及调整策略、动量
|
5天前
|
机器学习/深度学习 Ubuntu Linux
计算机视觉+深度学习+机器学习+opencv+目标检测跟踪+一站式学习(代码+视频+PPT)-1
计算机视觉+深度学习+机器学习+opencv+目标检测跟踪+一站式学习(代码+视频+PPT)
62 1
|
5天前
|
机器学习/深度学习 算法 数据挖掘
Python3 OpenCV4 计算机视觉学习手册:6~11(2)
Python3 OpenCV4 计算机视觉学习手册:6~11(2)
78 0
|
5天前
|
算法 计算机视觉 索引
Python3 OpenCV4 计算机视觉学习手册:1~5
Python3 OpenCV4 计算机视觉学习手册:1~5
55 0
|
5天前
|
机器学习/深度学习 PyTorch 算法框架/工具
PyTorch与迁移学习:利用预训练模型提升性能
【4月更文挑战第18天】PyTorch支持迁移学习,助力提升深度学习性能。预训练模型(如ResNet、VGG)在大规模数据集(如ImageNet)训练后,可在新任务中加速训练,提高准确率。通过选择模型、加载预训练权重、修改结构和微调,可适应不同任务需求。迁移学习节省资源,但也需考虑源任务与目标任务的相似度及超参数选择。实践案例显示,预训练模型能有效提升小数据集上的图像分类任务性能。未来,迁移学习将继续在深度学习领域发挥重要作用。
|
5天前
|
机器学习/深度学习 PyTorch 算法框架/工具
通过实例学习Pytorch加载权重.load_state_dict()与保存权重.save()
通过实例学习Pytorch加载权重.load_state_dict()与保存权重.save()
23 0
|
5天前
|
机器学习/深度学习 算法 PyTorch
PyTorch使用Tricks:学习率衰减 !!
PyTorch使用Tricks:学习率衰减 !!
59 0