ResNet残差网络Pytorch实现——对花的种类进行训练

简介: ResNet残差网络Pytorch实现——对花的种类进行训练

✌ 使用ResNet进行对花的种类进行训练

import os
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms,datasets
from tqdm import tqdm
# 加载设备,使用cpu还是显卡
device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('using {} device'.format(device))
# 图片处理,对应train和验证集
data_transform={
    'train':transforms.Compose([transforms.RandomResizedCrop(224),  # 将图片裁剪为224*224
                                transforms.RandomHorizontalFlip(),  # 将图片随机反转
                                transforms.ToTensor(),  # 转化为ToTensor
                                # 进行标准化
                                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
    'val':transforms.Compose([transforms.Resize(256), # 调整图片大小
                              transforms.CenterCrop(224), # 中心裁剪224*224
                              transforms.ToTensor(),
                              transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
}
# os.getcwd()获得当前位置的绝对路径
# os.path.join()将两个路径进行拼接
img_path=os.path.join(os.getcwd(),'flower_data')
# 加载训练数据,同时对需要训练的图片进行处理
train_dataset=datasets.ImageFolder(root=os.path.join(img_path,'train'),
                                   transform=data_transform['train'])
# 加载验证集
val_dataset=datasets.ImageFolder(root=os.path.join(img_path,'val'),
                                transform=data_transform['val'])
# 定义每个训练批次的数据数量,对应每次训练16张图片
batch_size=16
# 训练数据的加载器
# 根据批次大小进行将数据进行分批
# 一般来说训练数据需要打乱,而验证集不需要
train_loader=torch.utils.data.DataLoader(train_dataset,
                                         batch_size,
                                         shuffle=True)
# 验证数据的加载器
val_loader=torch.utils.data.DataLoader(val_dataset,
                                       batch_size,
                                       shuffle=False)
# 训练和验证的数据大小
train_num=len(train_dataset)
val_num=len(val_dataset)
print('using {} images for training, {} images for validation.'.format(train_num,val_num))
# train_dataset.class_to_idx会返回'A':0,'B':1,'C':2,即每个类别对应的数值映射
flower_list=train_dataset.class_to_idx
# 将其逆置,为了预测时根据预测结果的分类找出对应的字符真实分类,如果不做,最终预测只知道是0,1,2这种,不知道花的真实类别
# data_dataset加载图片会根据图片所在的文件夹确定其分类
cla_dict=dict((value,key) for key,value in flower_list.items())
# 将字典转成json串
json_str=json.dumps(cla_dict,indent=4)
# 将json串写入到文件
with open('class_indices.json','w') as json_file:
    json_file.write(json_str)
# 创建网络
net=resnet34()
# 模型参数路径
model_weight_path='./resnet34-pre.pth'
# 加载模型参数
net.load_state_dict(torch.load(model_weight_path,map_location=device))
# 取出全连接层进行替换,因为模型中默认是1000分类,而本题中是5分类,
# 所以要取出全连接层获得全连接层的输入层加上现在的新输出5分类,构建新的全连接层
in_channel=net.fc.in_features
net.fc=nn.Linear(in_channel,5)
net.to(device)
# 交叉熵损失函数
loss_function=nn.CrossEntropyLoss()
# 获得可以更新梯度的参数
params=[p for p in net.parameters() if p.requires_grad]
# 定义优化器
optimizer=optim.Adam(params,lr=0.0001)
# 训练轮数,将所有的数据训练epochs次
epochs=3
# 最好的准确度
best_acc=0
# 训练的模型参数路径
save_path='./resNet34.pth'
# 训练集的批数
train_steps=len(train_loader)
for epoch in range(epochs):
    # 开启训练模式
    net.train()
    running_loss=0
    train_bar=tqdm(train_loader)
    for data in train_bar:
        # 获得每个训练批的数据和标签
        images,labels=data
        optimizer.zero_grad()
        output=net(images.to(device))
        loss=loss_function(output,labels.to(device))
        loss.backward()
        optimizer.step()
        running_loss+=loss.item()
        train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                  loss)
    # 开启验证模式
    net.eval()
    acc=0
    # 不需要进行梯度下降求导
    with torch.no_grad():
        val_bar=tqdm(val_loader)
        for data in val_bar:
            images,labels=data
            output=net(images.to(device))
            # touch.max()返回指定维度的最大值和该值所在的索引                                                
            y_pred=torch.max(output,dim=1)[1]
            # 计算预测正确的个数
            acc+=torch.eq(y_pred,labels.to(device)).sum().item()
            val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1,
                                                           epochs)
    val_accurate=acc/val_num
    print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))
    # 如果当前的准确率>最高的准确率就将其取代,保存当前训练的参数
    if val_accurate>best_acc:
        best_acc=val_accurate
        torch.save(net.state_dict(),save_path)


目录
相关文章
|
3月前
|
机器学习/深度学习 人工智能 算法
AI 基础知识从 0.6 到 0.7—— 彻底拆解深度神经网络训练的五大核心步骤
本文以一个经典的PyTorch手写数字识别代码示例为引子,深入剖析了简洁代码背后隐藏的深度神经网络(DNN)训练全过程。
822 56
|
4月前
|
机器学习/深度学习 PyTorch 测试技术
从训练到推理:Intel Extension for PyTorch混合精度优化完整指南
PyTorch作为主流深度学习框架,凭借动态计算图和异构计算支持,广泛应用于视觉与自然语言处理。Intel Extension for PyTorch针对Intel硬件深度优化,尤其在GPU上通过自动混合精度(AMP)提升训练与推理性能。本文以ResNet-50在CIFAR-10上的实验为例,详解如何利用该扩展实现高效深度学习优化。
273 0
|
1月前
|
机器学习/深度学习 数据可视化 网络架构
PINN训练新思路:把初始条件和边界约束嵌入网络架构,解决多目标优化难题
PINNs训练难因多目标优化易失衡。通过设计硬约束网络架构,将初始与边界条件内嵌于模型输出,可自动满足约束,仅需优化方程残差,简化训练过程,提升稳定性与精度,适用于气候、生物医学等高要求仿真场景。
289 4
PINN训练新思路:把初始条件和边界约束嵌入网络架构,解决多目标优化难题
|
2月前
|
机器学习/深度学习 算法 PyTorch
【Pytorch框架搭建神经网络】基于DQN算法、优先级采样的DQN算法、DQN + 人工势场的避障控制研究(Python代码实现)
【Pytorch框架搭建神经网络】基于DQN算法、优先级采样的DQN算法、DQN + 人工势场的避障控制研究(Python代码实现)
|
2月前
|
机器学习/深度学习 算法 PyTorch
【DQN实现避障控制】使用Pytorch框架搭建神经网络,基于DQN算法、优先级采样的DQN算法、DQN + 人工势场实现避障控制研究(Matlab、Python实现)
【DQN实现避障控制】使用Pytorch框架搭建神经网络,基于DQN算法、优先级采样的DQN算法、DQN + 人工势场实现避障控制研究(Matlab、Python实现)
144 0
|
6月前
|
机器学习/深度学习 PyTorch 算法框架/工具
基于Pytorch 在昇腾上实现GCN图神经网络
本文详细讲解了如何在昇腾平台上使用PyTorch实现图神经网络(GCN)对Cora数据集进行分类训练。内容涵盖GCN背景、模型特点、网络架构剖析及实战分析。GCN通过聚合邻居节点信息实现“卷积”操作,适用于非欧氏结构数据。文章以两层GCN模型为例,结合Cora数据集(2708篇科学出版物,1433个特征,7种类别),展示了从数据加载到模型训练的完整流程。实验在NPU上运行,设置200个epoch,最终测试准确率达0.8040,内存占用约167M。
基于Pytorch 在昇腾上实现GCN图神经网络
|
6月前
|
机器学习/深度学习 搜索推荐 PyTorch
基于昇腾用PyTorch实现CTR模型DIN(Deep interest Netwok)网络
本文详细讲解了如何在昇腾平台上使用PyTorch训练推荐系统中的经典模型DIN(Deep Interest Network)。主要内容包括:DIN网络的创新点与架构剖析、Activation Unit和Attention模块的实现、Amazon-book数据集的介绍与预处理、模型训练过程定义及性能评估。通过实战演示,利用Amazon-book数据集训练DIN模型,最终评估其点击率预测性能。文中还提供了代码示例,帮助读者更好地理解每个步骤的实现细节。
|
6月前
|
机器学习/深度学习 自然语言处理 PyTorch
基于Pytorch Gemotric在昇腾上实现GAT图神经网络
本实验基于昇腾平台,使用PyTorch实现图神经网络GAT(Graph Attention Networks)在Pubmed数据集上的分类任务。内容涵盖GAT网络的创新点分析、图注意力机制原理、多头注意力机制详解以及模型代码实战。实验通过两层GAT网络对Pubmed数据集进行训练,验证模型性能,并展示NPU上的内存使用情况。最终,模型在测试集上达到约36.60%的准确率。
|
6月前
|
算法 PyTorch 算法框架/工具
PyTorch 实现FCN网络用于图像语义分割
本文详细讲解了在昇腾平台上使用PyTorch实现FCN(Fully Convolutional Networks)网络在VOC2012数据集上的训练过程。内容涵盖FCN的创新点分析、网络架构解析、代码实现以及端到端训练流程。重点包括全卷积结构替换全连接层、多尺度特征融合、跳跃连接和反卷积操作等技术细节。通过定义VOCSegDataset类处理数据集,构建FCN8s模型并完成训练与测试。实验结果展示了模型在图像分割任务中的应用效果,同时提供了内存使用优化的参考。
|
6月前
|
机器学习/深度学习 算法 PyTorch
基于Pytorch Gemotric在昇腾上实现GraphSage图神经网络
本实验基于PyTorch Geometric,在昇腾平台上实现GraphSAGE图神经网络,使用CiteSeer数据集进行分类训练。内容涵盖GraphSAGE的创新点、算法原理、网络架构及实战分析。GraphSAGE通过采样和聚合节点邻居特征,支持归纳式学习,适用于未见节点的表征生成。实验包括模型搭建、训练与验证,并在NPU上运行,最终测试准确率达0.665。

热门文章

最新文章

推荐镜像

更多