【基础实操】借用torch自带网络进行训练自己的图像数据

简介: 【基础实操】借用torch自带网络进行训练自己的图像数据

前言

  在本文里将为大家带来如何进行使用pytorch中的自带的深度网络进行训练自己的数据。本文讲解可分两部分,第一部分为大家介绍如何进行目录式的读取自己的数据;第二部分为大家介绍如何进行更改为其他网络进行调试。(alexnet\densenet\mnasnet\moblienet\resnet\shufflenet\squeezenet\vgg)

目录式读取

  由于大家在做图像分类的时候,一般是往把搜集到的同类图像放置在同一个文件夹内,因此我们在这里采用目录式读取自己制作的数据集进行训练网络。

数据组成:

image.png

  在这里我采用鲜花数据集为基础数据集并对此数据集进行修改。在鲜花数据集中我们确定总数据类别为5类,在训练集中每一类中的图像数为500涨,在测试集中的每一类的图像数为150张,对训练测试内的图像进行修改大小为224x224x3。

  参考pytorch官网示例,我们可以将示例修改进行如下修改:

ini

复制代码

import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
train_root = './datas/train/'
test_root = './datas/test/'
# 将文件夹的内容载入dataset
train_dataset = torchvision.datasets.ImageFolder(root=train_root, transform=torchvision.transforms.ToTensor())
test_dataset = torchvision.datasets.ImageFolder(root=test_root, transform=torchvision.transforms.ToTensor())
train_dataloader = DataLoader(train_dataset, batch_size=64)
test_dataloader = DataLoader(test_dataset, batch_size=64)

  这这一步的时候我们可捎带的将超参数进行设置一下,如下设置:

ini

复制代码

learning_rate = 0.1
batch_size = 64
epochs = 100
# Initialize the loss function
loss_fn = nn.CrossEntropyLoss()

  由于训练集是需要训练而测试集不需要进行训练,那么可参考官网的示例分别对训练集的操作和测试集的操作保持不变如下:

scss

复制代码

def train_loop(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    for batch, (X, y) in enumerate(dataloader):
        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, y)
        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
def test_loop(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

更改网络

  在pytorch的自带网络大家可与自行将复制出来,这样可与避免因自己的更改导致后续因为再次使用出现不必要的BUG,pytorch自带的models路径为:

   envs\pytorch\Lib\site-packages\torchvision\models

image.png

  大家可与设置model为自己需要调用的网络,在这里我们以vgg网络系列的vgg11为例子为大家介绍如何进行训练网络。我们依旧保持官网示例中的SGD训练函数作为optimizer,然后将各参数导入到train和test中进行训练自己的数据。由于网络比较多,我在这里就不一一为大家介绍了。

  大家可与移步我的Github

image.png

css

复制代码

from myNets.vgg import vgg11  # 可更换为其他
if __name__ == "__main__":
    model = vgg11(num_classes=5)  # 可更换为其他
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
    for t in range(epochs):
        print(f"Epoch {t + 1}\n-------------------------------")
        train_loop(train_dataloader, model, loss_fn, optimizer)
        test_loop(test_dataloader, model, loss_fn)
    print("Done!")

拓展

  在上一期:【实操】涨点神器你还不会,快点进来学习Label Smooth我们介绍了Label Smooth操作,大家可与尝试自行更改进行综合比对测试使用不同的Label Smooth操作对结果的影响,也可更换其他的学习率的衰减函数进行测试。


相关文章
|
3月前
|
机器学习/深度学习 人工智能 算法
AI 基础知识从 0.6 到 0.7—— 彻底拆解深度神经网络训练的五大核心步骤
本文以一个经典的PyTorch手写数字识别代码示例为引子,深入剖析了简洁代码背后隐藏的深度神经网络(DNN)训练全过程。
797 56
|
2月前
|
机器学习/深度学习 算法 调度
14种智能算法优化BP神经网络(14种方法)实现数据预测分类研究(Matlab代码实现)
14种智能算法优化BP神经网络(14种方法)实现数据预测分类研究(Matlab代码实现)
324 0
|
3月前
|
机器学习/深度学习 数据采集 传感器
【故障诊断】基于matlab BP神经网络电机数据特征提取与故障诊断研究(Matlab代码实现)
【故障诊断】基于matlab BP神经网络电机数据特征提取与故障诊断研究(Matlab代码实现)
136 0
|
4月前
|
数据采集 存储 算法
MyEMS 开源能源管理系统:基于 4G 无线传感网络的能源数据闭环管理方案
MyEMS 是开源能源管理领域的标杆解决方案,采用 Python、Django 与 React 技术栈,具备模块化架构与跨平台兼容性。系统涵盖能源数据治理、设备管理、工单流转与智能控制四大核心功能,结合高精度 4G 无线计量仪表,实现高效数据采集与边缘计算。方案部署灵活、安全性高,助力企业实现能源数字化与碳减排目标。
143 0
|
1月前
|
机器学习/深度学习 数据可视化 网络架构
PINN训练新思路:把初始条件和边界约束嵌入网络架构,解决多目标优化难题
PINNs训练难因多目标优化易失衡。通过设计硬约束网络架构,将初始与边界条件内嵌于模型输出,可自动满足约束,仅需优化方程残差,简化训练过程,提升稳定性与精度,适用于气候、生物医学等高要求仿真场景。
278 4
PINN训练新思路:把初始条件和边界约束嵌入网络架构,解决多目标优化难题
|
1月前
|
机器学习/深度学习 人工智能 算法
【基于TTNRBO优化DBN回归预测】基于瞬态三角牛顿-拉夫逊优化算法(TTNRBO)优化深度信念网络(DBN)数据回归预测研究(Matlab代码实现)
【基于TTNRBO优化DBN回归预测】基于瞬态三角牛顿-拉夫逊优化算法(TTNRBO)优化深度信念网络(DBN)数据回归预测研究(Matlab代码实现)
112 0
|
2月前
|
机器学习/深度学习 数据采集 运维
改进的遗传算法优化的BP神经网络用于电厂数据的异常检测和故障诊断
改进的遗传算法优化的BP神经网络用于电厂数据的异常检测和故障诊断
|
4月前
|
存储 监控 算法
基于 Python 跳表算法的局域网网络监控软件动态数据索引优化策略研究
局域网网络监控软件需高效处理终端行为数据,跳表作为一种基于概率平衡的动态数据结构,具备高效的插入、删除与查询性能(平均时间复杂度为O(log n)),适用于高频数据写入和随机查询场景。本文深入解析跳表原理,探讨其在局域网监控中的适配性,并提供基于Python的完整实现方案,优化终端会话管理,提升系统响应性能。
130 4
|
11月前
|
SQL 安全 网络安全
网络安全与信息安全:知识分享####
【10月更文挑战第21天】 随着数字化时代的快速发展,网络安全和信息安全已成为个人和企业不可忽视的关键问题。本文将探讨网络安全漏洞、加密技术以及安全意识的重要性,并提供一些实用的建议,帮助读者提高自身的网络安全防护能力。 ####
262 17
|
11月前
|
SQL 安全 网络安全
网络安全与信息安全:关于网络安全漏洞、加密技术、安全意识等方面的知识分享
随着互联网的普及,网络安全问题日益突出。本文将从网络安全漏洞、加密技术和安全意识三个方面进行探讨,旨在提高读者对网络安全的认识和防范能力。通过分析常见的网络安全漏洞,介绍加密技术的基本原理和应用,以及强调安全意识的重要性,帮助读者更好地保护自己的网络信息安全。
221 10
下一篇
oss云网关配置