ResNet残差网络Pytorch实现——cifar10数据集训练

简介: ResNet残差网络Pytorch实现——cifar10数据集训练

✌ 使用ResNet进行对cifar10数据集进行训练

import torchvision
import torch
from torchvision import transforms
import os
import json
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms,datasets
from tqdm import tqdm
# 加载运算设备
device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# 数据处理
data_transform = transforms.Compose(
        [transforms.Resize(256),
         transforms.CenterCrop(224),
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
# 每个批次的数据大小
batch_size=100
# 加载训练数据,不需要下载
train_dataset = torchvision.datasets.CIFAR10(root='./cifar10', 
                                             train=True,
                                             download=False, 
                                             transform=data_transform)
# 训练数据的加载器
train_loader = torch.utils.data.DataLoader(train_dataset, 
                                           batch_size=batch_size,
                                           shuffle=True)
# 训练数据大小 
train_num=len(train_dataset)
print('using {} images for training.'.format(train_num))
# 预测结果与真实分类的映射
cifar10_list=train_dataset.class_to_idx
cla_dict=dict((value,key) for key,value in cifar10_list.items())
json_str=json.dumps(cla_dict,indent=10)
with open('class_indices.json','w') as json_file:
    json_file.write(json_str)
# 构建网络
net=resnet34()
# 加载模型参数
model_weight_path='./resnet34-pre.pth'
net.load_state_dict(torch.load(model_weight_path,map_location=device))
# 将每个参数置为False,反向传播时不会进行梯度更新
for param in net.parameters():
    param.requires_grad=False
# 修改全连接层
in_channel=net.fc.in_features
net.fc=nn.Linear(in_channel,10)
net.to(device)
# 交叉熵损失函数
loss_function=nn.CrossEntropyLoss()
# 获得需要训练的参数
params=[p for p in net.parameters() if p.requires_grad]
# 优化器
optimizer=optim.Adam(params,lr=0.0001)
epochs=1
loss_sum=999
save_path='./resNet34_cifar10.pth'
train_steps=len(train_loader)
# 开始训练,所有数据只训练1次
for epoch in range(epochs):
    net.train()
    running_loss=0
    train_bar=tqdm(train_loader)
    # 训练集总共50000张图片,我设置的每批数据是100,所以对应是500*100
    # 循环500次,每次训练的数据为100张
    for data in train_bar:
        images,labels=data
        optimizer.zero_grad()
        output=net(images.to(device))
        loss=loss_function(output,labels.to(device))
        loss.backward()
        optimizer.step()
        running_loss+=loss.item()
        train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)
    # 保存最好的模型参数
    if running_loss/train_steps<loss_sum:
        loss_sum=running_loss/train_steps
        torch.save(net.state_dict(),save_path)


目录
相关文章
|
19天前
|
机器学习/深度学习 PyTorch 算法框架/工具
【从零开始学习深度学习】32. 卷积神经网络之稠密连接网络(DenseNet)介绍及其Pytorch实现
【从零开始学习深度学习】32. 卷积神经网络之稠密连接网络(DenseNet)介绍及其Pytorch实现
|
7天前
|
机器学习/深度学习 人工智能 PyTorch
PyTorch框架和MNIST数据集
6月更文挑战20天
40 2
|
19天前
|
机器学习/深度学习 自然语言处理 算法
【从零开始学习深度学习】49.Pytorch_NLP项目实战:文本情感分类---使用循环神经网络RNN
【从零开始学习深度学习】49.Pytorch_NLP项目实战:文本情感分类---使用循环神经网络RNN
|
19天前
|
机器学习/深度学习 PyTorch 算法框架/工具
【从零开始学习深度学习】31. 卷积神经网络之残差网络(ResNet)介绍及其Pytorch实现
【从零开始学习深度学习】31. 卷积神经网络之残差网络(ResNet)介绍及其Pytorch实现
|
19天前
|
机器学习/深度学习 PyTorch 算法框架/工具
【从零开始学习深度学习】30. 神经网络中批量归一化层(batch normalization)的作用及其Pytorch实现
【从零开始学习深度学习】30. 神经网络中批量归一化层(batch normalization)的作用及其Pytorch实现
|
19天前
|
机器学习/深度学习 算法 PyTorch
【从零开始学习深度学习】38. Pytorch实战案例:梯度下降、随机梯度下降、小批量随机梯度下降3种优化算法对比【含数据集与源码】
【从零开始学习深度学习】38. Pytorch实战案例:梯度下降、随机梯度下降、小批量随机梯度下降3种优化算法对比【含数据集与源码】
|
9天前
|
机器学习/深度学习 人工智能 算法
【昆虫识别系统】图像识别Python+卷积神经网络算法+人工智能+深度学习+机器学习+TensorFlow+ResNet50
昆虫识别系统,使用Python作为主要开发语言。通过TensorFlow搭建ResNet50卷积神经网络算法(CNN)模型。通过对10种常见的昆虫图片数据集('蜜蜂', '甲虫', '蝴蝶', '蝉', '蜻蜓', '蚱蜢', '蛾', '蝎子', '蜗牛', '蜘蛛')进行训练,得到一个识别精度较高的H5格式模型文件,然后使用Django搭建Web网页端可视化操作界面,实现用户上传一张昆虫图片识别其名称。
140 7
【昆虫识别系统】图像识别Python+卷积神经网络算法+人工智能+深度学习+机器学习+TensorFlow+ResNet50
|
19天前
|
机器学习/深度学习 PyTorch 算法框架/工具
【从零开始学习深度学习】36. 门控循环神经网络之长短期记忆网络(LSTM)介绍、Pytorch实现LSTM并进行训练预测
【从零开始学习深度学习】36. 门控循环神经网络之长短期记忆网络(LSTM)介绍、Pytorch实现LSTM并进行训练预测
|
19天前
|
机器学习/深度学习 PyTorch 算法框架/工具
【从零开始学习深度学习】29.卷积神经网络之GoogLeNet模型介绍及用Pytorch实现GoogLeNet模型【含完整代码】
【从零开始学习深度学习】29.卷积神经网络之GoogLeNet模型介绍及用Pytorch实现GoogLeNet模型【含完整代码】
|
5天前
|
并行计算 PyTorch 程序员
老程序员分享:Pytorch入门之Siamese网络
老程序员分享:Pytorch入门之Siamese网络