【基础实操】借用torch自带网络进行训练自己的图像数据

简介: 【基础实操】借用torch自带网络进行训练自己的图像数据

前言

  在本文里将为大家带来如何进行使用pytorch中的自带的深度网络进行训练自己的数据。本文讲解可分两部分,第一部分为大家介绍如何进行目录式的读取自己的数据;第二部分为大家介绍如何进行更改为其他网络进行调试。(alexnet\densenet\mnasnet\moblienet\resnet\shufflenet\squeezenet\vgg)

目录式读取

  由于大家在做图像分类的时候,一般是往把搜集到的同类图像放置在同一个文件夹内,因此我们在这里采用目录式读取自己制作的数据集进行训练网络。

数据组成:

image.png

  在这里我采用鲜花数据集为基础数据集并对此数据集进行修改。在鲜花数据集中我们确定总数据类别为5类,在训练集中每一类中的图像数为500涨,在测试集中的每一类的图像数为150张,对训练测试内的图像进行修改大小为224x224x3。

  参考pytorch官网示例,我们可以将示例修改进行如下修改:

ini

复制代码

import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
train_root = './datas/train/'
test_root = './datas/test/'
# 将文件夹的内容载入dataset
train_dataset = torchvision.datasets.ImageFolder(root=train_root, transform=torchvision.transforms.ToTensor())
test_dataset = torchvision.datasets.ImageFolder(root=test_root, transform=torchvision.transforms.ToTensor())
train_dataloader = DataLoader(train_dataset, batch_size=64)
test_dataloader = DataLoader(test_dataset, batch_size=64)

  这这一步的时候我们可捎带的将超参数进行设置一下,如下设置:

ini

复制代码

learning_rate = 0.1
batch_size = 64
epochs = 100
# Initialize the loss function
loss_fn = nn.CrossEntropyLoss()

  由于训练集是需要训练而测试集不需要进行训练,那么可参考官网的示例分别对训练集的操作和测试集的操作保持不变如下:

scss

复制代码

def train_loop(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    for batch, (X, y) in enumerate(dataloader):
        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, y)
        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
def test_loop(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

更改网络

  在pytorch的自带网络大家可与自行将复制出来,这样可与避免因自己的更改导致后续因为再次使用出现不必要的BUG,pytorch自带的models路径为:

   envs\pytorch\Lib\site-packages\torchvision\models

image.png

  大家可与设置model为自己需要调用的网络,在这里我们以vgg网络系列的vgg11为例子为大家介绍如何进行训练网络。我们依旧保持官网示例中的SGD训练函数作为optimizer,然后将各参数导入到train和test中进行训练自己的数据。由于网络比较多,我在这里就不一一为大家介绍了。

  大家可与移步我的Github

image.png

css

复制代码

from myNets.vgg import vgg11  # 可更换为其他
if __name__ == "__main__":
    model = vgg11(num_classes=5)  # 可更换为其他
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
    for t in range(epochs):
        print(f"Epoch {t + 1}\n-------------------------------")
        train_loop(train_dataloader, model, loss_fn, optimizer)
        test_loop(test_dataloader, model, loss_fn)
    print("Done!")

拓展

  在上一期:【实操】涨点神器你还不会,快点进来学习Label Smooth我们介绍了Label Smooth操作,大家可与尝试自行更改进行综合比对测试使用不同的Label Smooth操作对结果的影响,也可更换其他的学习率的衰减函数进行测试。


相关文章
|
4天前
|
网络协议 算法 C语言
C语言在网络编程中如何实现数据完整性
C语言在网络编程中如何实现数据完整性
12 0
|
5天前
|
安全 网络安全 数据安全/隐私保护
网络安全与信息安全:保护你的数据,保护你的未来
【5月更文挑战第30天】在数字化的世界中,网络安全和信息安全是每个人都需要关注的问题。本文将深入探讨网络安全漏洞、加密技术以及安全意识等方面的问题,帮助读者了解如何保护自己的数据,防止网络攻击。
|
5天前
|
安全 网络安全 量子技术
网络安全与信息安全:保护数据的关键策略
【5月更文挑战第30天】 在数字化时代,网络安全和信息安全已成为维护个人隐私、企业资产和国家安全不可或缺的一环。本文将深入探讨网络安全漏洞的概念、加密技术的最新进展以及提升安全意识的重要性。通过分析当前的网络威胁和挑战,我们展示了如何利用多层次的安全措施来防御潜在的攻击。文章不仅提供了对现有安全技术的深刻见解,还强调了教育和个人责任在构建坚固防线中的作用。
16 3
|
4天前
|
机器学习/深度学习
简单通用:视觉基础网络最高3倍无损训练加速,清华EfficientTrain++入选TPAMI 2024
【5月更文挑战第30天】清华大学研究团队提出的EfficientTrain++是一种新型训练方法,旨在加速视觉基础网络(如ResNet、ConvNeXt、DeiT)的训练,最高可达3倍速度提升,同时保持模型准确性。该方法基于傅里叶谱裁剪和动态数据增强,实现了课程学习的创新应用。在ImageNet-1K/22K数据集上,EfficientTrain++能有效减少多种模型的训练时间,且在自监督学习任务中表现出色。尽管面临适应性与稳定性的挑战,EfficientTrain++为深度学习模型的高效训练开辟了新途径,对学术和工业界具有重要意义。
12 4
|
5天前
|
监控 安全 算法
保护数据:网络安全与信息安全探究
网络安全和信息安全日益成为当今社会关注的焦点话题。本文从网络安全漏洞、加密技术和安全意识等方面出发,深入探讨了如何保护个人和组织的数据安全。通过对不同类型的网络威胁和最新的安全技术进行分析,提出了有效的防范措施和加强安全意识的建议,以期为读者提供全面的网络安全知识,并帮助他们更好地保护自己的信息资产。
15 1
|
5天前
|
安全 网络安全 数据安全/隐私保护
网络安全与信息安全:保护您的数据和隐私
【5月更文挑战第30天】在数字化时代,网络安全和信息安全已成为我们生活中不可或缺的一部分。了解网络安全漏洞、加密技术和安全意识等方面的知识,对于保护我们的在线数据和隐私至关重要。本文将探讨这些主题,并提供有关如何保护自己免受网络攻击的建议。
|
5天前
|
存储 安全 网络安全
网络安全与信息安全:保护数据的关键策略
【5月更文挑战第30天】 在数字化时代,网络安全和信息安全已成为维护数据完整性、确保信息传输安全性的重要议题。本文旨在探讨网络安全漏洞的概念、加密技术的应用以及提升安全意识的重要性。通过对这些关键点的深入分析,我们旨在为读者提供一套综合性的策略,以增强个人和组织在面对日益复杂的网络威胁时的防护能力。
|
15天前
|
消息中间件 Java Linux
2024年最全BATJ真题突击:Java基础+JVM+分布式高并发+网络编程+Linux(1),2024年最新意外的惊喜
2024年最全BATJ真题突击:Java基础+JVM+分布式高并发+网络编程+Linux(1),2024年最新意外的惊喜
|
13天前
|
JSON 安全 网络协议
【Linux 网络】网络基础(二)(应用层协议:HTTP、HTTPS)-- 详解
【Linux 网络】网络基础(二)(应用层协议:HTTP、HTTPS)-- 详解
|
13天前
|
存储 网络协议 Unix
【Linux 网络】网络编程套接字 -- 详解
【Linux 网络】网络编程套接字 -- 详解