【基础实操】借用torch自带网络进行训练自己的图像数据

简介: 【基础实操】借用torch自带网络进行训练自己的图像数据

前言

  在本文里将为大家带来如何进行使用pytorch中的自带的深度网络进行训练自己的数据。本文讲解可分两部分,第一部分为大家介绍如何进行目录式的读取自己的数据;第二部分为大家介绍如何进行更改为其他网络进行调试。(alexnet\densenet\mnasnet\moblienet\resnet\shufflenet\squeezenet\vgg)

目录式读取

  由于大家在做图像分类的时候,一般是往把搜集到的同类图像放置在同一个文件夹内,因此我们在这里采用目录式读取自己制作的数据集进行训练网络。

数据组成:

image.png

  在这里我采用鲜花数据集为基础数据集并对此数据集进行修改。在鲜花数据集中我们确定总数据类别为5类,在训练集中每一类中的图像数为500涨,在测试集中的每一类的图像数为150张,对训练测试内的图像进行修改大小为224x224x3。

  参考pytorch官网示例,我们可以将示例修改进行如下修改:

ini

复制代码

import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
train_root = './datas/train/'
test_root = './datas/test/'
# 将文件夹的内容载入dataset
train_dataset = torchvision.datasets.ImageFolder(root=train_root, transform=torchvision.transforms.ToTensor())
test_dataset = torchvision.datasets.ImageFolder(root=test_root, transform=torchvision.transforms.ToTensor())
train_dataloader = DataLoader(train_dataset, batch_size=64)
test_dataloader = DataLoader(test_dataset, batch_size=64)

  这这一步的时候我们可捎带的将超参数进行设置一下,如下设置:

ini

复制代码

learning_rate = 0.1
batch_size = 64
epochs = 100
# Initialize the loss function
loss_fn = nn.CrossEntropyLoss()

  由于训练集是需要训练而测试集不需要进行训练,那么可参考官网的示例分别对训练集的操作和测试集的操作保持不变如下:

scss

复制代码

def train_loop(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    for batch, (X, y) in enumerate(dataloader):
        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, y)
        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
def test_loop(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

更改网络

  在pytorch的自带网络大家可与自行将复制出来,这样可与避免因自己的更改导致后续因为再次使用出现不必要的BUG,pytorch自带的models路径为:

   envs\pytorch\Lib\site-packages\torchvision\models

image.png

  大家可与设置model为自己需要调用的网络,在这里我们以vgg网络系列的vgg11为例子为大家介绍如何进行训练网络。我们依旧保持官网示例中的SGD训练函数作为optimizer,然后将各参数导入到train和test中进行训练自己的数据。由于网络比较多,我在这里就不一一为大家介绍了。

  大家可与移步我的Github

image.png

css

复制代码

from myNets.vgg import vgg11  # 可更换为其他
if __name__ == "__main__":
    model = vgg11(num_classes=5)  # 可更换为其他
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
    for t in range(epochs):
        print(f"Epoch {t + 1}\n-------------------------------")
        train_loop(train_dataloader, model, loss_fn, optimizer)
        test_loop(test_dataloader, model, loss_fn)
    print("Done!")

拓展

  在上一期:【实操】涨点神器你还不会,快点进来学习Label Smooth我们介绍了Label Smooth操作,大家可与尝试自行更改进行综合比对测试使用不同的Label Smooth操作对结果的影响,也可更换其他的学习率的衰减函数进行测试。


相关文章
|
24天前
|
机器学习/深度学习 自然语言处理 语音技术
Python在深度学习领域的应用,重点讲解了神经网络的基础概念、基本结构、训练过程及优化技巧
本文介绍了Python在深度学习领域的应用,重点讲解了神经网络的基础概念、基本结构、训练过程及优化技巧,并通过TensorFlow和PyTorch等库展示了实现神经网络的具体示例,涵盖图像识别、语音识别等多个应用场景。
48 8
|
1月前
|
安全 算法 网络安全
量子计算与网络安全:保护数据的新方法
量子计算的崛起为网络安全带来了新的挑战和机遇。本文介绍了量子计算的基本原理,重点探讨了量子加密技术,如量子密钥分发(QKD)和量子签名,这些技术利用量子物理的特性,提供更高的安全性和可扩展性。未来,量子加密将在金融、政府通信等领域发挥重要作用,但仍需克服量子硬件不稳定性和算法优化等挑战。
|
1月前
|
机器学习/深度学习 TensorFlow 算法框架/工具
利用Python和TensorFlow构建简单神经网络进行图像分类
利用Python和TensorFlow构建简单神经网络进行图像分类
57 3
|
1月前
|
存储 安全 网络安全
云计算与网络安全:保护数据的新策略
【10月更文挑战第28天】随着云计算的广泛应用,网络安全问题日益突出。本文将深入探讨云计算环境下的网络安全挑战,并提出有效的安全策略和措施。我们将分析云服务中的安全风险,探讨如何通过技术和管理措施来提升信息安全水平,包括加密技术、访问控制、安全审计等。此外,文章还将分享一些实用的代码示例,帮助读者更好地理解和应用这些安全策略。
|
21天前
|
弹性计算 安全 容灾
阿里云DTS踩坑经验分享系列|使用VPC数据通道解决网络冲突问题
阿里云DTS作为数据世界高速传输通道的建造者,每周为您分享一个避坑技巧,助力数据之旅更加快捷、便利、安全。本文介绍如何使用VPC数据通道解决网络冲突问题。
76 0
|
1月前
|
安全 网络安全 数据安全/隐私保护
网络安全与信息安全:从漏洞到加密,保护数据的关键步骤
【10月更文挑战第24天】在数字化时代,网络安全和信息安全是维护个人隐私和企业资产的前线防线。本文将探讨网络安全中的常见漏洞、加密技术的重要性以及如何通过提高安全意识来防范潜在的网络威胁。我们将深入理解网络安全的基本概念,学习如何识别和应对安全威胁,并掌握保护信息不被非法访问的策略。无论你是IT专业人士还是日常互联网用户,这篇文章都将为你提供宝贵的知识和技能,帮助你在网络世界中更安全地航行。
|
2月前
|
存储 安全 网络安全
云计算与网络安全:如何保护您的数据
【10月更文挑战第21天】在这篇文章中,我们将探讨云计算和网络安全的关系。随着云计算的普及,网络安全问题日益突出。我们将介绍云服务的基本概念,以及如何通过网络安全措施来保护您的数据。最后,我们将提供一些代码示例,帮助您更好地理解这些概念。
|
1月前
|
机器学习/深度学习 人工智能 自动驾驶
深度学习的奇迹:如何用神经网络识别图像
【10月更文挑战第33天】在这篇文章中,我们将探索深度学习的奇妙世界,特别是卷积神经网络(CNN)在图像识别中的应用。我们将通过一个简单的代码示例,展示如何使用Python和Keras库构建一个能够识别手写数字的神经网络。这不仅是对深度学习概念的直观介绍,也是对技术实践的一次尝试。让我们一起踏上这段探索之旅,看看数据、模型和代码是如何交织在一起,创造出令人惊叹的结果。
33 0
|
2月前
|
机器学习/深度学习 数据采集 算法
目标分类笔记(一): 利用包含多个网络多种训练策略的框架来完成多目标分类任务(从数据准备到训练测试部署的完整流程)
这篇博客文章介绍了如何使用包含多个网络和多种训练策略的框架来完成多目标分类任务,涵盖了从数据准备到训练、测试和部署的完整流程,并提供了相关代码和配置文件。
64 0
目标分类笔记(一): 利用包含多个网络多种训练策略的框架来完成多目标分类任务(从数据准备到训练测试部署的完整流程)
|
4天前
|
SQL 安全 网络安全
网络安全与信息安全:知识分享####
【10月更文挑战第21天】 随着数字化时代的快速发展,网络安全和信息安全已成为个人和企业不可忽视的关键问题。本文将探讨网络安全漏洞、加密技术以及安全意识的重要性,并提供一些实用的建议,帮助读者提高自身的网络安全防护能力。 ####
41 17
下一篇
DataWorks