【基础实操】借用torch自带网络进行训练自己的图像数据

简介: 【基础实操】借用torch自带网络进行训练自己的图像数据

前言

  在本文里将为大家带来如何进行使用pytorch中的自带的深度网络进行训练自己的数据。本文讲解可分两部分,第一部分为大家介绍如何进行目录式的读取自己的数据;第二部分为大家介绍如何进行更改为其他网络进行调试。(alexnet\densenet\mnasnet\moblienet\resnet\shufflenet\squeezenet\vgg)

目录式读取

  由于大家在做图像分类的时候,一般是往把搜集到的同类图像放置在同一个文件夹内,因此我们在这里采用目录式读取自己制作的数据集进行训练网络。

数据组成:

image.png

  在这里我采用鲜花数据集为基础数据集并对此数据集进行修改。在鲜花数据集中我们确定总数据类别为5类,在训练集中每一类中的图像数为500涨,在测试集中的每一类的图像数为150张,对训练测试内的图像进行修改大小为224x224x3。

  参考pytorch官网示例,我们可以将示例修改进行如下修改:

ini

复制代码

import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
train_root = './datas/train/'
test_root = './datas/test/'
# 将文件夹的内容载入dataset
train_dataset = torchvision.datasets.ImageFolder(root=train_root, transform=torchvision.transforms.ToTensor())
test_dataset = torchvision.datasets.ImageFolder(root=test_root, transform=torchvision.transforms.ToTensor())
train_dataloader = DataLoader(train_dataset, batch_size=64)
test_dataloader = DataLoader(test_dataset, batch_size=64)

  这这一步的时候我们可捎带的将超参数进行设置一下,如下设置:

ini

复制代码

learning_rate = 0.1
batch_size = 64
epochs = 100
# Initialize the loss function
loss_fn = nn.CrossEntropyLoss()

  由于训练集是需要训练而测试集不需要进行训练,那么可参考官网的示例分别对训练集的操作和测试集的操作保持不变如下:

scss

复制代码

def train_loop(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    for batch, (X, y) in enumerate(dataloader):
        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, y)
        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
def test_loop(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

更改网络

  在pytorch的自带网络大家可与自行将复制出来,这样可与避免因自己的更改导致后续因为再次使用出现不必要的BUG,pytorch自带的models路径为:

   envs\pytorch\Lib\site-packages\torchvision\models

image.png

  大家可与设置model为自己需要调用的网络,在这里我们以vgg网络系列的vgg11为例子为大家介绍如何进行训练网络。我们依旧保持官网示例中的SGD训练函数作为optimizer,然后将各参数导入到train和test中进行训练自己的数据。由于网络比较多,我在这里就不一一为大家介绍了。

  大家可与移步我的Github

image.png

css

复制代码

from myNets.vgg import vgg11  # 可更换为其他
if __name__ == "__main__":
    model = vgg11(num_classes=5)  # 可更换为其他
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
    for t in range(epochs):
        print(f"Epoch {t + 1}\n-------------------------------")
        train_loop(train_dataloader, model, loss_fn, optimizer)
        test_loop(test_dataloader, model, loss_fn)
    print("Done!")

拓展

  在上一期:【实操】涨点神器你还不会,快点进来学习Label Smooth我们介绍了Label Smooth操作,大家可与尝试自行更改进行综合比对测试使用不同的Label Smooth操作对结果的影响,也可更换其他的学习率的衰减函数进行测试。


相关文章
|
3天前
|
机器学习/深度学习 存储 监控
数据分享|Python卷积神经网络CNN身份识别图像处理在疫情防控下口罩识别、人脸识别
数据分享|Python卷积神经网络CNN身份识别图像处理在疫情防控下口罩识别、人脸识别
11 0
|
3天前
|
数据可视化 数据挖掘
【视频】复杂网络分析CNA简介与R语言对婚礼数据聚类社区检测和可视化|数据分享
【视频】复杂网络分析CNA简介与R语言对婚礼数据聚类社区检测和可视化|数据分享
10 2
|
1天前
|
人工智能 数据可视化
【数据分享】维基百科Wiki负面有害评论(网络暴力)文本数据多标签分类挖掘可视化
【数据分享】维基百科Wiki负面有害评论(网络暴力)文本数据多标签分类挖掘可视化
12 2
|
2天前
|
数据可视化 数据挖掘
R语言用igraph对上海公交巴士路线数据进行复杂网络、网络图可视化
R语言用igraph对上海公交巴士路线数据进行复杂网络、网络图可视化
|
2天前
|
机器学习/深度学习 算法 TensorFlow
【视频】神经网络正则化方法防过拟合和R语言CNN分类手写数字图像数据MNIST|数据分享
【视频】神经网络正则化方法防过拟合和R语言CNN分类手写数字图像数据MNIST|数据分享
|
4天前
|
存储 SQL 安全
网络安全与信息安全:保护数据的关键策略
【4月更文挑战第24天】 在数字化时代,数据成为了新的货币。然而,随着网络攻击的日益猖獗,如何确保信息的安全和隐私成为了一个亟待解决的问题。本文将深入探讨网络安全漏洞的概念、加密技术的重要性以及提升安全意识的必要性,旨在为读者提供一套综合性的网络安全防护策略。通过对这些关键知识点的分享,我们希望能够增强个人和组织在面对网络威胁时的防御能力。
|
5天前
|
安全 JavaScript 前端开发
第十六届山东省职业院校技能大赛中职组 “网络安全”赛项竞赛试题—B模块安全事件响应/网络安全数据取证/应用安全
该内容描述了一次网络安全演练,包括七个部分:Linux渗透提权、内存取证、页面信息发现、数字取证调查、网络安全应急响应、Python代码分析和逆向分析。参与者需在模拟环境中收集Flag值,涉及任务如获取服务器信息、提权、解析内存片段、分析网络数据包、处理代码漏洞、解码逆向操作等。每个部分都列出了若干具体任务,要求提取或生成特定信息作为Flag提交。
9 0
|
5天前
|
安全 测试技术 网络安全
2024年山东省职业院校技能大赛中职组 “网络安全”赛项竞赛试题-C安全事件响应/网络安全数据取证/应用安全
B模块涵盖安全事件响应和应用安全,包括Windows渗透测试、页面信息发现、Linux系统提权及网络安全应急响应。在Windows渗透测试中,涉及系统服务扫描、DNS信息提取、管理员密码、.docx文件名及内容、图片中单词等Flag值。页面信息发现任务包括服务器端口、主页Flag、脚本信息、登录成功信息等。Linux系统渗透需收集SSH端口号、主机名、内核版本,并实现提权获取root目录内容和密码。网络安全应急响应涉及删除后门用户、找出ssh后门时间、恢复环境变量文件、识别修改的bin文件格式及定位挖矿病毒钱包地址。
11 0
|
5天前
|
安全 测试技术 Linux
2024年山东省职业院校技能大赛中职组 “网络安全”赛项竞赛试题-A模块安全事件响应/网络安全数据取证/应用安全
该内容描述了一个网络安全挑战,涉及Windows和Linux系统的渗透测试以及隐藏信息探索和内存取证。挑战包括使用Kali Linux对Windows Server进行服务扫描、DNS信息提取、密码获取、文件名和内容查找等。对于Linux系统,任务包括收集服务器信息、提权并查找特定文件内容和密码。此外,还有对Server2007网站的多步骤渗透,寻找登录界面和页面中的隐藏FLAG。最后,需要通过FTP获取win20230306服务器的内存片段,从中提取密码、地址、主机名、挖矿程序信息和浏览器搜索关键词。
8 0
|
5天前
|
安全 测试技术 网络安全
2024年甘肃省职业院校技能大赛中职组 “网络安全”赛项竞赛样题-C模块安全事件响应/网络安全数据取证/应用安全
涉及安全事件响应和应用安全测试。需使用Kali对Windows Server2105进行渗透测试,包括服务扫描、DNS信息提取、管理员密码、文件名与内容、图片中单词等。另外,需收集win20230305的服务器端口、页面信息、脚本、登录后信息等。在Linux Server2214上,要获取SSH端口、主机名、内核版本并进行提权操作。网络安全响应针对Server2228,涉及删除后门用户、查找SSH后门时间、恢复环境变量、识别篡改文件格式和矿池钱包地址。最后,对lin20230509进行网站渗透,获取端口号、数据库服务版本、脚本创建时间、页面路径、内核版本和root目录下的flag文件内容
6 0