机器学习 —— MNIST手写体识别(上)

简介: 机器学习 —— MNIST手写体识别

pytorch配置

Jupyter notebook

1、打开 Anaconda Prompt(anaconda 3)

2、在终端输入:conda create -n torch python=3.9

3、查看环境确认创建成功

4、激活虚拟环境:activate torch

5、创建内核:conda install -n torch ipykernel

8、检查安装是否成功

方法2

Pycharm

1、选择环境

2、选择 torch 环境解释器

3、测试

独显安装参考(就是先装好CUDA和cudnn,然后剩下的和集显一样,

先创建好虚拟环境和相应的 kernel,然后安装):Pytorch环境配置(anaconda安装+独显+CUDA+cuDNN)


数据集介绍

   给定数据集MNIST,Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz

   MNIST是一个计算机视觉数据集,它包含各种手写数字图片0,1,2,…,9

   MNIST数据集包含:60000行的训练数据集(mnist.train)和10000行的测试数据集(mnist.test)。训练数据集和测试数据集都包含有一张手写的数字,和对应的标签,训练数据集的图片是 mnist.train.images ,训练数据集的标签是 mnist.train.labels;测试数据集的图片是 mnist.test.images ,训练数据集的标签是 mnist.test.labels每一张图片都是28*28像素(784个像素点),可以用点矩阵来表示每张图片

MNIST手写体识别

(一) 加包

import torch # 导入PyTorch库,用于构建和训练神经网络模型。
import torch.nn as nn # 导入PyTorch的神经网络模块,用于定义神经网络的层和模型。
import torch.nn.functional as F # 导入PyTorch的函数式接口模块,包含一些常用的非线性激活函数和损失函数。
import torch.optim as optim # 导入PyTorch的优化器模块,用于定义和实现不同的优化算法。
from torchvision import datasets, transforms # 从torchvision库中导入数据集和数据处理的功能。
import time # 导入Python的time模块,用于计时和时间相关的操作。
from matplotlib import pyplot as plt # 从matplotlib库中导入绘图功能,用于可视化数据和结果。

图1:加包

(二) 数据预处理

       这段代码主要涉及数据预处理和数据加载的过程。

       pipline_trainpipline_test是数据预处理的管道,通过transforms.Compose方法将多个预处理操作组合在一起。将transforms.RandomRotation添加到训练数据预处理中来进行随机旋转图片。使用transforms.Resize调整图片的尺寸,确保在将图片转化为Tensor格式之前调整好。将transforms.Normalize的参数从元组形式改为单个值形式,以便适应PyTorch 1.2.0版本及以上的要求。

       train_settest_set是使用MNIST数据集构建的训练集和测试集对象,参数说明如下:root=“./data”:数据集存储的根目录。train=True(对应train_set)或train=False(对应test_set):指定加载训练集还是测试集。download=True:如果数据集不存在,则自动从互联网上下载。transform=pipline_train(对应train_set)或transform=pipline_test(对应test_set):应用之前定义的数据预处理管道。

       trainloadertestloader是数据加载器,用于将数据集分批次加载到模型进行训练和测试。torch.utils.data.DataLoader用于构建数据加载器。batch_size控制每个批次的样本数量。将trainloader和testloader中的shuffle参数设置为布尔值,指示是否对数据进行洗牌操作。

pipline_train = transforms.Compose([
    #随机旋转图片
    transforms.RandomRotation(10),
    #将图片尺寸resize到32x32
    transforms.Resize((32, 32)),
    #将图片转化为Tensor格式
    transforms.ToTensor(),
    #正则化(当模型出现过拟合的情况时,用来降低模型的复杂度)
    transforms.Normalize((0.1307,),(0.3081,))    
])
pipline_test = transforms.Compose([
    #将图片尺寸resize到32x32
    transforms.Resize((32, 32)),
    #将图片转化为Tensor格式
    transforms.ToTensor(),
    # 正则化
    transforms.Normalize((0.1307,),(0.3081,))
])
#下载数据集 使用MNIST数据集构建的训练集和测试集对象
train_set = datasets.MNIST(root="./data", train=True, download=True, transform=pipline_train)
test_set = datasets.MNIST(root="./data", train=False, download=True, transform=pipline_test)
#加载数据集 torch.utils.data.DataLoader用于构建数据加载器
trainloader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
testloader = torch.utils.data.DataLoader(test_set, batch_size=32, shuffle=False)

图2:数据预处理

图3:下载数据集

机器学习 —— MNIST手写体识别(下)https://developer.aliyun.com/article/1507859?spm=a2c6h.13148508.setting.22.1b484f0e3Nvb2Q


相关文章
|
3月前
|
机器学习/深度学习 人工智能 文字识别
文本,文字扫描01,OCR文本识别技术展示,一个安卓App,一个简单的设计,文字识别可以应用于人工智能,机器学习,车牌识别,身份证识别,银行卡识别,PaddleOCR+SpringBoot+Andr
文本,文字扫描01,OCR文本识别技术展示,一个安卓App,一个简单的设计,文字识别可以应用于人工智能,机器学习,车牌识别,身份证识别,银行卡识别,PaddleOCR+SpringBoot+Andr
|
4月前
|
机器学习/深度学习 分布式计算 算法
在机器学习项目中,选择算法涉及问题类型识别(如回归、分类、聚类、强化学习)
【6月更文挑战第28天】在机器学习项目中,选择算法涉及问题类型识别(如回归、分类、聚类、强化学习)、数据规模与特性(大数据可能适合分布式算法或深度学习)、性能需求(准确性、速度、可解释性)、资源限制(计算与内存)、领域知识应用以及实验验证(交叉验证、模型比较)。迭代过程包括数据探索、模型构建、评估和优化,结合业务需求进行决策。
53 0
|
4月前
|
机器学习/深度学习 算法 PyTorch
【从零开始学习深度学习】45. Pytorch迁移学习微调方法实战:使用微调技术进行2分类图片热狗识别模型训练【含源码与数据集】
【从零开始学习深度学习】45. Pytorch迁移学习微调方法实战:使用微调技术进行2分类图片热狗识别模型训练【含源码与数据集】
|
5月前
|
机器学习/深度学习 数据可视化 Python
机器学习 —— MNIST手写体识别(下)
机器学习 —— MNIST手写体识别(下)
|
5月前
|
机器学习/深度学习 数据可视化 数据处理
python 机器学习 sklearn——一起识别数字吧
python 机器学习 sklearn——一起识别数字吧
|
5月前
|
机器学习/深度学习 数据采集 存储
【机器学习】数据清洗之识别重复点
【机器学习】数据清洗之识别重复点
233 1
|
5月前
|
数据采集 机器学习/深度学习 算法
【机器学习】数据清洗之识别异常点
【机器学习】数据清洗之识别异常点
291 1
|
12天前
|
机器学习/深度学习 人工智能 自然语言处理
【MM2024】阿里云 PAI 团队图像编辑算法论文入选 MM2024
阿里云人工智能平台 PAI 团队发表的图像编辑算法论文在 MM2024 上正式亮相发表。ACM MM(ACM国际多媒体会议)是国际多媒体领域的顶级会议,旨在为研究人员、工程师和行业专家提供一个交流平台,以展示在多媒体领域的最新研究成果、技术进展和应用案例。其主题涵盖了图像处理、视频分析、音频处理、社交媒体和多媒体系统等广泛领域。此次入选标志着阿里云人工智能平台 PAI 在图像编辑算法方面的研究获得了学术界的充分认可。
【MM2024】阿里云 PAI 团队图像编辑算法论文入选 MM2024
|
8天前
|
机器学习/深度学习 人工智能 算法
【玉米病害识别】Python+卷积神经网络算法+人工智能+深度学习+计算机课设项目+TensorFlow+模型训练
玉米病害识别系统,本系统使用Python作为主要开发语言,通过收集了8种常见的玉米叶部病害图片数据集('矮花叶病', '健康', '灰斑病一般', '灰斑病严重', '锈病一般', '锈病严重', '叶斑病一般', '叶斑病严重'),然后基于TensorFlow搭建卷积神经网络算法模型,通过对数据集进行多轮迭代训练,最后得到一个识别精度较高的模型文件。再使用Django搭建Web网页操作平台,实现用户上传一张玉米病害图片识别其名称。
22 0
【玉米病害识别】Python+卷积神经网络算法+人工智能+深度学习+计算机课设项目+TensorFlow+模型训练
|
16天前
|
机器学习/深度学习 算法 决策智能
【机器学习】揭秘深度学习优化算法:加速训练与提升性能
【机器学习】揭秘深度学习优化算法:加速训练与提升性能