ResNet残差网络Pytorch实现——cifar10数据集训练

简介: ResNet残差网络Pytorch实现——cifar10数据集训练

✌ 使用ResNet进行对cifar10数据集进行训练

import torchvision
import torch
from torchvision import transforms
import os
import json
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms,datasets
from tqdm import tqdm
# 加载运算设备
device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# 数据处理
data_transform = transforms.Compose(
        [transforms.Resize(256),
         transforms.CenterCrop(224),
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
# 每个批次的数据大小
batch_size=100
# 加载训练数据,不需要下载
train_dataset = torchvision.datasets.CIFAR10(root='./cifar10', 
                                             train=True,
                                             download=False, 
                                             transform=data_transform)
# 训练数据的加载器
train_loader = torch.utils.data.DataLoader(train_dataset, 
                                           batch_size=batch_size,
                                           shuffle=True)
# 训练数据大小 
train_num=len(train_dataset)
print('using {} images for training.'.format(train_num))
# 预测结果与真实分类的映射
cifar10_list=train_dataset.class_to_idx
cla_dict=dict((value,key) for key,value in cifar10_list.items())
json_str=json.dumps(cla_dict,indent=10)
with open('class_indices.json','w') as json_file:
    json_file.write(json_str)
# 构建网络
net=resnet34()
# 加载模型参数
model_weight_path='./resnet34-pre.pth'
net.load_state_dict(torch.load(model_weight_path,map_location=device))
# 将每个参数置为False,反向传播时不会进行梯度更新
for param in net.parameters():
    param.requires_grad=False
# 修改全连接层
in_channel=net.fc.in_features
net.fc=nn.Linear(in_channel,10)
net.to(device)
# 交叉熵损失函数
loss_function=nn.CrossEntropyLoss()
# 获得需要训练的参数
params=[p for p in net.parameters() if p.requires_grad]
# 优化器
optimizer=optim.Adam(params,lr=0.0001)
epochs=1
loss_sum=999
save_path='./resNet34_cifar10.pth'
train_steps=len(train_loader)
# 开始训练,所有数据只训练1次
for epoch in range(epochs):
    net.train()
    running_loss=0
    train_bar=tqdm(train_loader)
    # 训练集总共50000张图片,我设置的每批数据是100,所以对应是500*100
    # 循环500次,每次训练的数据为100张
    for data in train_bar:
        images,labels=data
        optimizer.zero_grad()
        output=net(images.to(device))
        loss=loss_function(output,labels.to(device))
        loss.backward()
        optimizer.step()
        running_loss+=loss.item()
        train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)
    # 保存最好的模型参数
    if running_loss/train_steps<loss_sum:
        loss_sum=running_loss/train_steps
        torch.save(net.state_dict(),save_path)


目录
相关文章
|
1月前
|
机器学习/深度学习 PyTorch 算法框架/工具
目标检测实战(一):CIFAR10结合神经网络加载、训练、测试完整步骤
这篇文章介绍了如何使用PyTorch框架,结合CIFAR-10数据集,通过定义神经网络、损失函数和优化器,进行模型的训练和测试。
99 2
目标检测实战(一):CIFAR10结合神经网络加载、训练、测试完整步骤
|
1月前
|
机器学习/深度学习 数据可视化 计算机视觉
目标检测笔记(五):详细介绍并实现可视化深度学习中每层特征层的网络训练情况
这篇文章详细介绍了如何通过可视化深度学习中每层特征层来理解网络的内部运作,并使用ResNet系列网络作为例子,展示了如何在训练过程中加入代码来绘制和保存特征图。
63 1
目标检测笔记(五):详细介绍并实现可视化深度学习中每层特征层的网络训练情况
|
1月前
|
机器学习/深度学习 数据采集 算法
目标分类笔记(一): 利用包含多个网络多种训练策略的框架来完成多目标分类任务(从数据准备到训练测试部署的完整流程)
这篇博客文章介绍了如何使用包含多个网络和多种训练策略的框架来完成多目标分类任务,涵盖了从数据准备到训练、测试和部署的完整流程,并提供了相关代码和配置文件。
52 0
目标分类笔记(一): 利用包含多个网络多种训练策略的框架来完成多目标分类任务(从数据准备到训练测试部署的完整流程)
|
1月前
|
机器学习/深度学习 算法 TensorFlow
深度学习笔记(五):学习率过大过小对于网络训练有何影响以及如何解决
学习率是深度学习中的关键超参数,它影响模型的训练进度和收敛性,过大或过小的学习率都会对网络训练产生负面影响,需要通过适当的设置和调整策略来优化。
309 0
深度学习笔记(五):学习率过大过小对于网络训练有何影响以及如何解决
|
7天前
|
SQL 安全 网络安全
网络安全与信息安全:关于网络安全漏洞、加密技术、安全意识等方面的知识分享
【10月更文挑战第40天】在数字化时代,网络安全和信息安全已成为我们生活中不可或缺的一部分。本文将介绍网络安全漏洞、加密技术以及安全意识等方面的知识,帮助读者更好地了解网络安全的重要性,并提供一些实用的技巧和建议,以保护个人和组织的信息安全。
29 6
|
1天前
|
安全 网络安全 数据安全/隐私保护
网络安全与信息安全:关于网络安全漏洞、加密技术、安全意识等方面的知识分享
在数字化时代,网络安全和信息安全已成为我们生活中不可或缺的一部分。本文将介绍网络安全漏洞、加密技术和安全意识等方面的知识,并提供一些实用的技巧和建议,帮助读者更好地保护自己的网络安全和信息安全。
|
2天前
|
安全 算法 网络协议
网络安全与信息安全:关于网络安全漏洞、加密技术、安全意识等方面的知识分享
在数字时代,网络安全和信息安全已经成为了我们生活中不可或缺的一部分。本文将介绍网络安全漏洞、加密技术和安全意识等方面的内容,帮助读者更好地了解网络安全的重要性和应对措施。通过阅读本文,您将了解到网络安全的基本概念、常见的网络安全漏洞、加密技术的原理和应用以及如何提高个人和组织的网络安全意识。
|
4天前
|
存储 安全 算法
网络安全与信息安全:漏洞、加密与意识的三重防线
在数字时代的浪潮中,网络安全与信息安全成为维护数据完整性、确保个人隐私和企业资产安全的基石。本文将深入探讨网络漏洞的成因、加密技术的应用以及安全意识的培养,旨在通过技术与教育的结合,构建起一道坚固的防御体系。我们将从实际案例出发,分析常见的网络安全威胁,揭示如何通过加密算法保护数据安全,并强调提升个人和组织的安全意识在防范网络攻击中的重要性。
|
1天前
|
监控 安全 网络安全
网络安全与信息安全:关于网络安全漏洞、加密技术、安全意识等方面的知识分享
在数字化时代,网络安全和信息安全已成为全球关注的焦点。本文将探讨网络安全漏洞、加密技术以及安全意识的重要性,并提供一些实用的建议来保护个人和组织的数据安全。我们将从网络安全漏洞的识别和防范开始,然后介绍加密技术的原理和应用,最后强调安全意识在维护网络安全中的关键作用。无论你是个人用户还是企业管理者,这篇文章都将为你提供有价值的信息和指导。
|
1天前
|
安全 网络安全 数据安全/隐私保护
网络安全与信息安全:关于网络安全漏洞、加密技术、安全意识等方面的知识分享
在数字化时代,网络安全和信息安全成为了我们生活中不可或缺的一部分。本文将介绍网络安全漏洞、加密技术以及安全意识等方面的内容,帮助读者更好地了解网络安全和信息安全的重要性,并提供一些实用的技巧和方法来保护个人信息和数据安全。

热门文章

最新文章