使用Python实现深度学习模型:迁移学习与预训练模型

本文涉及的产品
实时计算 Flink 版,1000CU*H 3个月
实时数仓Hologres,5000CU*H 100GB 3个月
智能开放搜索 OpenSearch行业算法版,1GB 20LCU 1个月
简介: 使用Python实现深度学习模型:迁移学习与预训练模型

迁移学习是一种将已经在一个任务上训练好的模型应用到另一个相关任务上的方法。通过使用预训练模型,迁移学习可以显著减少训练时间并提高模型性能。在本文中,我们将详细介绍如何使用Python和PyTorch进行迁移学习,并展示其在图像分类任务中的应用。

什么是迁移学习?

迁移学习的基本思想是利用在大规模数据集(如ImageNet)上训练好的模型,将其知识迁移到特定的目标任务中。迁移学习通常包括以下步骤:

  • 加载预训练模型:使用已经在大规模数据集上训练好的模型。
  • 微调模型:根据目标任务的数据集对模型进行微调。

实现步骤

步骤 1:导入所需库

首先,我们需要导入所需的Python库:PyTorch用于构建和训练深度学习模型,Torchvision用于加载预训练模型和数据处理。

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
import numpy as np
import matplotlib.pyplot as plt

步骤 2:准备数据

我们将使用CIFAR-10数据集作为示例数据集。CIFAR-10是一个常用于图像分类任务的基准数据集,包含10个类别的60000张32x32彩色图像。

# 数据预处理
transform = transforms.Compose([
    transforms.Resize(224),  # 调整图像大小以适应预训练模型的输入要求
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# 下载并加载训练和测试数据
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)

test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=32, shuffle=False)

步骤 3:加载预训练模型

我们将使用在ImageNet数据集上预训练的ResNet-18模型,并对其进行微调以适应CIFAR-10数据集。

# 加载预训练的ResNet-18模型
model = models.resnet18(pretrained=True)

# 修改模型的最后一层以适应CIFAR-10数据集
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 10)

# 将模型移动到GPU(如果可用)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)

步骤 4:定义损失函数和优化器

我们选择交叉熵损失函数(Cross Entropy Loss)作为模型训练的损失函数,并使用Adam优化器进行优化。

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

步骤 5:训练模型

我们使用定义的预训练模型对CIFAR-10数据集进行训练。

num_epochs = 10

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 100 == 99:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {running_loss / 100:.4f}')
            running_loss = 0.0

print('Finished Training')

步骤 6:评估模型

训练完成后,我们可以在测试数据集上评估模型的性能。

model.eval()
correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy of the model on the test images: {100 * correct / total:.2f}%')

可视化一些预测结果

我们可以可视化一些模型的预测结果。

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

def imshow(img):
    img = img / 2 + 0.5  # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

# 获取一些随机测试图像
dataiter = iter(test_loader)
images, labels = dataiter.next()

# 打印图像
imshow(torchvision.utils.make_grid(images))

# 打印标签
print('GroundTruth: ', ' '.join(f'{classes[labels[j]]}' for j in range(4)))

# 打印预测结果
outputs = model(images.to(device))
_, predicted = torch.max(outputs, 1)

print('Predicted: ', ' '.join(f'{classes[predicted[j]]}' for j in range(4)))

总结

通过本教程,你学会了如何使用Python和PyTorch进行迁移学习,并在CIFAR-10数据集上应用预训练的ResNet-18模型进行图像分类。迁移学习是一种强大的技术,能够显著减少训练时间并提高模型性能,广泛应用于各种深度学习任务中。希望本教程能够帮助你理解迁移学习的基本原理和实现方法,并启发你在实际应用中使用迁移学习解决各种问题。

目录
相关文章
|
1月前
|
存储 Java 数据处理
(numpy)Python做数据处理必备框架!(一):认识numpy;从概念层面开始学习ndarray数组:形状、数组转置、数值范围、矩阵...
Numpy是什么? numpy是Python中科学计算的基础包。 它是一个Python库,提供多维数组对象、各种派生对象(例如掩码数组和矩阵)以及用于对数组进行快速操作的各种方法,包括数学、逻辑、形状操作、排序、选择、I/0 、离散傅里叶变换、基本线性代数、基本统计运算、随机模拟等等。 Numpy能做什么? numpy的部分功能如下: ndarray,一个具有矢量算术运算和复杂广播能力的快速且节省空间的多维数组 用于对整组数据进行快速运算的标准数学函数(无需编写循环)。 用于读写磁盘数据的工具以及用于操作内存映射文件的工具。 线性代数、随机数生成以及傅里叶变换功能。 用于集成由C、C++
302 0
|
1月前
|
存储 JavaScript Java
(Python基础)新时代语言!一起学习Python吧!(四):dict字典和set类型;切片类型、列表生成式;map和reduce迭代器;filter过滤函数、sorted排序函数;lambda函数
dict字典 Python内置了字典:dict的支持,dict全称dictionary,在其他语言中也称为map,使用键-值(key-value)存储,具有极快的查找速度。 我们可以通过声明JS对象一样的方式声明dict
169 1
|
1月前
|
算法 Java Docker
(Python基础)新时代语言!一起学习Python吧!(三):IF条件判断和match匹配;Python中的循环:for...in、while循环;循环操作关键字;Python函数使用方法
IF 条件判断 使用if语句,对条件进行判断 true则执行代码块缩进语句 false则不执行代码块缩进语句,如果有else 或 elif 则进入相应的规则中执行
259 1
|
1月前
|
存储 Java 索引
(Python基础)新时代语言!一起学习Python吧!(二):字符编码由来;Python字符串、字符串格式化;list集合和tuple元组区别
字符编码 我们要清楚,计算机最开始的表达都是由二进制而来 我们要想通过二进制来表示我们熟知的字符看看以下的变化 例如: 1 的二进制编码为 0000 0001 我们通过A这个字符,让其在计算机内部存储(现如今,A 字符在地址通常表示为65) 现在拿A举例: 在计算机内部 A字符,它本身表示为 65这个数,在计算机底层会转为二进制码 也意味着A字符在底层表示为 1000001 通过这样的字符表示进行转换,逐步发展为拥有127个字符的编码存储到计算机中,这个编码表也被称为ASCII编码。 但随时代变迁,ASCII编码逐渐暴露短板,全球有上百种语言,光是ASCII编码并不能够满足需求
146 4
|
1月前
|
机器学习/深度学习 数据采集 人工智能
深度学习实战指南:从神经网络基础到模型优化的完整攻略
🌟 蒋星熠Jaxonic,AI探索者。深耕深度学习,从神经网络到Transformer,用代码践行智能革命。分享实战经验,助你构建CV、NLP模型,共赴二进制星辰大海。
|
2月前
|
JavaScript Java 大数据
基于python的网络课程在线学习交流系统
本研究聚焦网络课程在线学习交流系统,从社会、技术、教育三方面探讨其发展背景与意义。系统借助Java、Spring Boot、MySQL、Vue等技术实现,融合云计算、大数据与人工智能,推动教育公平与教学模式创新,具有重要理论价值与实践意义。
|
10月前
|
机器学习/深度学习 运维 安全
深度学习在安全事件检测中的应用:守护数字世界的利器
深度学习在安全事件检测中的应用:守护数字世界的利器
413 22
|
7月前
|
机器学习/深度学习 编解码 人工智能
计算机视觉五大技术——深度学习在图像处理中的应用
深度学习利用多层神经网络实现人工智能,计算机视觉是其重要应用之一。图像分类通过卷积神经网络(CNN)判断图片类别,如“猫”或“狗”。目标检测不仅识别物体,还确定其位置,R-CNN系列模型逐步优化检测速度与精度。语义分割对图像每个像素分类,FCN开创像素级分类范式,DeepLab等进一步提升细节表现。实例分割结合目标检测与语义分割,Mask R-CNN实现精准实例区分。关键点检测用于人体姿态估计、人脸特征识别等,OpenPose和HRNet等技术推动该领域发展。这些方法在效率与准确性上不断进步,广泛应用于实际场景。
1041 64
计算机视觉五大技术——深度学习在图像处理中的应用
|
11月前
|
机器学习/深度学习 传感器 数据采集
深度学习在故障检测中的应用:从理论到实践
深度学习在故障检测中的应用:从理论到实践
975 6

推荐镜像

更多