ResNet残差网络Pytorch实现——cifar10数据集训练

简介: ResNet残差网络Pytorch实现——cifar10数据集训练

✌ 使用ResNet进行对cifar10数据集进行训练

import torchvision
import torch
from torchvision import transforms
import os
import json
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms,datasets
from tqdm import tqdm
# 加载运算设备
device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# 数据处理
data_transform = transforms.Compose(
        [transforms.Resize(256),
         transforms.CenterCrop(224),
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
# 每个批次的数据大小
batch_size=100
# 加载训练数据,不需要下载
train_dataset = torchvision.datasets.CIFAR10(root='./cifar10', 
                                             train=True,
                                             download=False, 
                                             transform=data_transform)
# 训练数据的加载器
train_loader = torch.utils.data.DataLoader(train_dataset, 
                                           batch_size=batch_size,
                                           shuffle=True)
# 训练数据大小 
train_num=len(train_dataset)
print('using {} images for training.'.format(train_num))
# 预测结果与真实分类的映射
cifar10_list=train_dataset.class_to_idx
cla_dict=dict((value,key) for key,value in cifar10_list.items())
json_str=json.dumps(cla_dict,indent=10)
with open('class_indices.json','w') as json_file:
    json_file.write(json_str)
# 构建网络
net=resnet34()
# 加载模型参数
model_weight_path='./resnet34-pre.pth'
net.load_state_dict(torch.load(model_weight_path,map_location=device))
# 将每个参数置为False,反向传播时不会进行梯度更新
for param in net.parameters():
    param.requires_grad=False
# 修改全连接层
in_channel=net.fc.in_features
net.fc=nn.Linear(in_channel,10)
net.to(device)
# 交叉熵损失函数
loss_function=nn.CrossEntropyLoss()
# 获得需要训练的参数
params=[p for p in net.parameters() if p.requires_grad]
# 优化器
optimizer=optim.Adam(params,lr=0.0001)
epochs=1
loss_sum=999
save_path='./resNet34_cifar10.pth'
train_steps=len(train_loader)
# 开始训练,所有数据只训练1次
for epoch in range(epochs):
    net.train()
    running_loss=0
    train_bar=tqdm(train_loader)
    # 训练集总共50000张图片,我设置的每批数据是100,所以对应是500*100
    # 循环500次,每次训练的数据为100张
    for data in train_bar:
        images,labels=data
        optimizer.zero_grad()
        output=net(images.to(device))
        loss=loss_function(output,labels.to(device))
        loss.backward()
        optimizer.step()
        running_loss+=loss.item()
        train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)
    # 保存最好的模型参数
    if running_loss/train_steps<loss_sum:
        loss_sum=running_loss/train_steps
        torch.save(net.state_dict(),save_path)


目录
相关文章
|
8月前
|
机器学习/深度学习 人工智能 算法
AI 基础知识从 0.6 到 0.7—— 彻底拆解深度神经网络训练的五大核心步骤
本文以一个经典的PyTorch手写数字识别代码示例为引子,深入剖析了简洁代码背后隐藏的深度神经网络(DNN)训练全过程。
1258 56
|
9月前
|
机器学习/深度学习 PyTorch 测试技术
从训练到推理:Intel Extension for PyTorch混合精度优化完整指南
PyTorch作为主流深度学习框架,凭借动态计算图和异构计算支持,广泛应用于视觉与自然语言处理。Intel Extension for PyTorch针对Intel硬件深度优化,尤其在GPU上通过自动混合精度(AMP)提升训练与推理性能。本文以ResNet-50在CIFAR-10上的实验为例,详解如何利用该扩展实现高效深度学习优化。
468 0
|
12月前
|
机器学习/深度学习 存储 算法
NoProp:无需反向传播,基于去噪原理的非全局梯度传播神经网络训练,可大幅降低内存消耗
反向传播算法虽是深度学习基石,但面临内存消耗大和并行扩展受限的问题。近期,牛津大学等机构提出NoProp方法,通过扩散模型概念,将训练重塑为分层去噪任务,无需全局前向或反向传播。NoProp包含三种变体(DT、CT、FM),具备低内存占用与高效训练优势,在CIFAR-10等数据集上达到与传统方法相当的性能。其层间解耦特性支持分布式并行训练,为无梯度深度学习提供了新方向。
658 1
NoProp:无需反向传播,基于去噪原理的非全局梯度传播神经网络训练,可大幅降低内存消耗
|
6月前
|
机器学习/深度学习 数据可视化 网络架构
PINN训练新思路:把初始条件和边界约束嵌入网络架构,解决多目标优化难题
PINNs训练难因多目标优化易失衡。通过设计硬约束网络架构,将初始与边界条件内嵌于模型输出,可自动满足约束,仅需优化方程残差,简化训练过程,提升稳定性与精度,适用于气候、生物医学等高要求仿真场景。
689 4
PINN训练新思路:把初始条件和边界约束嵌入网络架构,解决多目标优化难题
|
机器学习/深度学习 PyTorch 测试技术
|
机器学习/深度学习 算法 PyTorch
昇腾910-PyTorch 实现 ResNet50图像分类
本实验基于PyTorch,在昇腾平台上使用ResNet50对CIFAR10数据集进行图像分类训练。内容涵盖ResNet50的网络架构、残差模块分析及训练代码详解。通过端到端的实战讲解,帮助读者理解如何在深度学习中应用ResNet50模型,并实现高效的图像分类任务。实验包括数据预处理、模型搭建、训练与测试等环节,旨在提升模型的准确率和训练效率。
750 54
|
机器学习/深度学习 PyTorch 算法框架/工具
【从零开始学习深度学习】31. 卷积神经网络之残差网络(ResNet)介绍及其Pytorch实现
【从零开始学习深度学习】31. 卷积神经网络之残差网络(ResNet)介绍及其Pytorch实现
|
机器学习/深度学习 PyTorch 算法框架/工具
ResNet代码复现+超详细注释(PyTorch)
ResNet代码复现+超详细注释(PyTorch)
5470 1
|
机器学习/深度学习 数据采集 PyTorch
PyTorch搭建卷积神经网络(ResNet-50网络)进行图像分类实战(附源码和数据集)
PyTorch搭建卷积神经网络(ResNet-50网络)进行图像分类实战(附源码和数据集)
1234 2

热门文章

最新文章

推荐镜像

更多