ResNet残差网络Pytorch实现——对花的种类进行单数据预测

简介: ResNet残差网络Pytorch实现——对花的种类进行单数据预测

✌ 使用ResNet进行对花的种类进行单数据预测

import os
import json
import torch
from torchvision import transforms
from PIL import Image
# 加载运算设备
device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# 数据处理
data_transform = transforms.Compose(
        [transforms.Resize(256),
         transforms.CenterCrop(224),
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
# 预测图片的路径
img_path='13290033_ebd7c7abba_n.jpg'
# 加载图片
img=Image.open(img_path)
# 将图片进行处理,返回的是tensor
img=data_transform(img)
# 将数据升维,图片是三维,而训练时是4维,因为训练时第一个维度为每个批次的训练数据大小
# 本次预测一个图片,就要升维,变成1,表明该批次图片有1个
img=torch.unsqueeze(img,dim=0)
# 读取预测结果和真实分类的映射
json_path='./class_indices.json'
json_file=open(json_path,'r')
# 加载成字典 0:'A',1:'B',2:'C'
class_indict=json.load(json_file)
# 创建网络
model=resnet34(num_classes=5).to(device)
# 加载模型训练好的参数
weitcht_path='./resNet34.pth'
model.load_state_dict(torch.load(weitcht_path,map_location=device))
# 开启验证模式
model.eval()
# 不需要求导
with torch.no_grad():
    # 每个数据对应输出,(1,5)维度,将其降维,直接是(5,)
    output=torch.squeeze(model(img.to(device))).cpu()
    # 如果是预测,完全可以用touch.max,返回最大值索引,但是下面为了输出预测的概率,就要将其标准化,概率和为1
    # 按道理说应该是dim=1,每行的所有列,但是现在不是二维,所以要用0
    predict=torch.softmax(output,dim=0)
    # 返回最大值索引,dim=0和上面的同理
    predict_cla=torch.argmax(predict,dim=0).item()
print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
                                                 predict[predict_cla].numpy())
print(print_res)


目录
相关文章
|
16天前
|
机器学习/深度学习 算法 调度
14种智能算法优化BP神经网络(14种方法)实现数据预测分类研究(Matlab代码实现)
14种智能算法优化BP神经网络(14种方法)实现数据预测分类研究(Matlab代码实现)
102 0
|
2月前
|
机器学习/深度学习 数据采集 传感器
【故障诊断】基于matlab BP神经网络电机数据特征提取与故障诊断研究(Matlab代码实现)
【故障诊断】基于matlab BP神经网络电机数据特征提取与故障诊断研究(Matlab代码实现)
|
3月前
|
数据采集 存储 算法
MyEMS 开源能源管理系统:基于 4G 无线传感网络的能源数据闭环管理方案
MyEMS 是开源能源管理领域的标杆解决方案,采用 Python、Django 与 React 技术栈,具备模块化架构与跨平台兼容性。系统涵盖能源数据治理、设备管理、工单流转与智能控制四大核心功能,结合高精度 4G 无线计量仪表,实现高效数据采集与边缘计算。方案部署灵活、安全性高,助力企业实现能源数字化与碳减排目标。
66 0
|
4月前
|
Python
LBA-ECO CD-32 通量塔网络数据汇编,巴西亚马逊:1999-2006,V2
该数据集汇集了1999年至2006年间巴西亚马逊地区九座观测塔的碳和能量通量、气象、辐射等多类数据,涵盖小时至月度时间步长。作为第二版汇编,数据经过协调与质量控制,扩展了第一版内容,并新增生态系统呼吸等相关计算数据,支持综合研究与模型合成。数据以36个制表符分隔文本文件形式提供,配套PDF说明文件,适用于生态与气候研究。引用来源为Restrepo-Coupe等人(2021)。
50 1
|
1月前
|
机器学习/深度学习 算法 PyTorch
【Pytorch框架搭建神经网络】基于DQN算法、优先级采样的DQN算法、DQN + 人工势场的避障控制研究(Python代码实现)
【Pytorch框架搭建神经网络】基于DQN算法、优先级采样的DQN算法、DQN + 人工势场的避障控制研究(Python代码实现)
|
24天前
|
机器学习/深度学习 算法 PyTorch
【DQN实现避障控制】使用Pytorch框架搭建神经网络,基于DQN算法、优先级采样的DQN算法、DQN + 人工势场实现避障控制研究(Matlab、Python实现)
【DQN实现避障控制】使用Pytorch框架搭建神经网络,基于DQN算法、优先级采样的DQN算法、DQN + 人工势场实现避障控制研究(Matlab、Python实现)
|
29天前
|
机器学习/深度学习 数据采集 运维
改进的遗传算法优化的BP神经网络用于电厂数据的异常检测和故障诊断
改进的遗传算法优化的BP神经网络用于电厂数据的异常检测和故障诊断
|
3月前
|
存储 监控 算法
基于 Python 跳表算法的局域网网络监控软件动态数据索引优化策略研究
局域网网络监控软件需高效处理终端行为数据,跳表作为一种基于概率平衡的动态数据结构,具备高效的插入、删除与查询性能(平均时间复杂度为O(log n)),适用于高频数据写入和随机查询场景。本文深入解析跳表原理,探讨其在局域网监控中的适配性,并提供基于Python的完整实现方案,优化终端会话管理,提升系统响应性能。
82 4
|
4月前
|
开发者
鸿蒙仓颉语言开发教程:网络请求和数据解析
本文介绍了在仓颉开发语言中实现网络请求的方法,以购物应用的分类列表为例,详细讲解了从权限配置、发起请求到数据解析的全过程。通过示例代码,帮助开发者快速掌握如何在网络请求中处理数据并展示到页面上,减少开发中的摸索成本。
鸿蒙仓颉语言开发教程:网络请求和数据解析
|
5月前
|
机器学习/深度学习 PyTorch 算法框架/工具
基于Pytorch 在昇腾上实现GCN图神经网络
本文详细讲解了如何在昇腾平台上使用PyTorch实现图神经网络(GCN)对Cora数据集进行分类训练。内容涵盖GCN背景、模型特点、网络架构剖析及实战分析。GCN通过聚合邻居节点信息实现“卷积”操作,适用于非欧氏结构数据。文章以两层GCN模型为例,结合Cora数据集(2708篇科学出版物,1433个特征,7种类别),展示了从数据加载到模型训练的完整流程。实验在NPU上运行,设置200个epoch,最终测试准确率达0.8040,内存占用约167M。
基于Pytorch 在昇腾上实现GCN图神经网络

热门文章

最新文章

推荐镜像

更多