机器学习之卷积神经网络使用cifar10数据集和alexnet网络模型训练分类模型

简介: 机器学习之卷积神经网络使用cifar10数据集和alexnet网络模型训练分类模型

使用cifar10数据集和alexnet网络模型训练分类模型

下载cifar10数据集

在这里插入图片描述

代码:

import torchvision
import torch
transform = torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor(),
     torchvision.transforms.Resize(224)]
)

train_set = torchvision.datasets.CIFAR10(root='./',download=False,train=True,transform=transform)
test_set = torchvision.datasets.CIFAR10(root='./',download=False,train=False,transform=transform)
train_loader = torch.utils.data.DataLoader(train_set,batch_size=8,shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set,batch_size=8,shuffle=True)

class Alexnet(torch.nn.Module):  #1080 2080
    def __init__(self,num_classes=10):
        super(Alexnet,self).__init__()
        net = torchvision.models.alexnet(pretrained=False)  #迁移学习
        net.classifier = torch.nn.Sequential()
        self.features = net
        self.classifier = torch.nn.Sequential(
            torch.nn.Dropout(0.3),
            torch.nn.Linear(256 * 6 * 6, 4096),
            torch.nn.ReLU(inplace=True),
            torch.nn.Dropout(0.3),
            torch.nn.Linear(4096, 4096),
            torch.nn.ReLU(inplace=True),
            torch.nn.Linear(4096, num_classes),
        )
    def forward(self,x):
        x = self.features(x)
        x = x.view(x.size(0),-1)
        x = self.classifier(x)

        return x
device = torch.device('cpu')
net = Alexnet().to(device)
loss_func = torch.nn.CrossEntropyLoss().to(device)
optim = torch.optim.Adam(net.parameters(),lr=0.001)

net.train()
for epoch in range(10):
    for step,(x,y) in enumerate(train_loader):  # 28*28*1  32*32*3
        x,y = x.to(device),y.to(device)
        output = net(x)
        loss = loss_func(output,y)
        optim.zero_grad()
        loss.backward()
        optim.step()
    print("epoch:",epoch,'loss:',loss)
目录
相关文章
|
1天前
|
机器学习/深度学习 人工智能 编解码
Backbone往事 | AlexNet~EfficientNet,10多个网络演变铺满了炼丹师们的青葱岁月
Backbone往事 | AlexNet~EfficientNet,10多个网络演变铺满了炼丹师们的青葱岁月
|
14天前
|
机器学习/深度学习 数据采集 算法
GEE机器学习——利用支持向量机SVM进行土地分类和精度评定
GEE机器学习——利用支持向量机SVM进行土地分类和精度评定
7 0
|
15天前
|
机器学习/深度学习 编解码 算法
YOLOv8改进 | 主干篇 | 低照度增强网络PE-YOLO改进主干(改进暗光条件下的物体检测模型)
YOLOv8改进 | 主干篇 | 低照度增强网络PE-YOLO改进主干(改进暗光条件下的物体检测模型)
16 0
|
15天前
|
机器学习/深度学习 编解码 网络架构
YOLOv8改进 | 主干篇 | 华为移动端模型Ghostnetv2改进特征提取网络
YOLOv8改进 | 主干篇 | 华为移动端模型Ghostnetv2改进特征提取网络
23 0
|
15天前
|
机器学习/深度学习 测试技术
YOLOv8改进 | 主干篇 | 华为移动端模型Ghostnetv1改进特征提取网络
YOLOv8改进 | 主干篇 | 华为移动端模型Ghostnetv1改进特征提取网络
20 0
|
15天前
|
机器学习/深度学习 人工智能 API
人工智能应用工程师技能提升系列2、——TensorFlow2——keras高级API训练神经网络模型
人工智能应用工程师技能提升系列2、——TensorFlow2——keras高级API训练神经网络模型
14 0
|
16天前
|
网络协议 Linux
Linux下的网络编程——B/S模型HTTP(四)
Linux下的网络编程——B/S模型HTTP(四)
22 0
|
16天前
|
网络协议 大数据 Linux
Linux下的网络编程——C/S模型 UDP(三)
Linux下的网络编程——C/S模型 UDP(三)
38 0
Linux下的网络编程——C/S模型 UDP(三)
|
16天前
|
网络协议 关系型数据库 MySQL
Linux下的网络编程——C/S模型TCP(二)
Linux下的网络编程——C/S模型TCP(二)
22 0
Linux下的网络编程——C/S模型TCP(二)
|
17天前
|
机器学习/深度学习 编解码 PyTorch
Pytorch实现手写数字识别 | MNIST数据集(CNN卷积神经网络)
Pytorch实现手写数字识别 | MNIST数据集(CNN卷积神经网络)

热门文章

最新文章

相关产品