DL之NN:基于(sklearn自带手写数字图片识别数据集)+自定义NN类(三层64→100→10)实现97.5%准确率

简介: DL之NN:基于(sklearn自带手写数字图片识别数据集)+自定义NN类(三层64→100→10)实现97.5%准确率

输出结果

image.png

image.png

核心代码

#DL之NN:基于sklearn自带手写数字图片识别数据集+自定义NN类(三层64→100→10)实现97.5%准确率

#输入64+1(偏置)个神经元,隐藏层神经元个数可以自定义,输出层10个神经元

import numpy as np  

from sklearn.datasets import load_digits              #sklearn自带数据集

from sklearn.metrics import confusion_matrix, classification_report  

from sklearn.preprocessing import LabelBinarizer      #标签二值化

from sklearn.cross_validation import train_test_split  #将数据切分分训练数据和测试数据

import matplotlib.pyplot as plt

def sigmoid(x):

   return 1/(1+np.exp(-x))

def dsigmoid(x):

   return x*(1-x)

……

             

   def predict(self,x):  #预测函数,也需要先添加偏置

       #添加偏置,最初的数据64上需额外加入偏置列

       temp=np.ones(x.shape[0]+1)   #

       temp[0:-1]=x #该矩阵的0列到-1列

       x=temp   #通过转换行没有变,但是多了一列

       x=np.atleast_2d(x) #转为2维数据

       L1=sigmoid(np.dot(x,self.V)) #隐藏层输出

       L2=sigmoid(np.dot(L1,self.W)) #输出层输出

       return L2

 

digits = load_digits()  #下载数据集

X = digits.data         #输入数据

y = digits.target       #标签

#输入数据归一化:把最初的数据都变为[0~1]之间的数据

X -= X.min()  

X /= X.max()  

nn = NeuralNetwork([64, 100, 10]) #构建神经网络,神经元个数

X_train, X_test, y_train, y_test = train_test_split(X, y)  #分割数据,75%为训练25%为测试

#对标签二值化,将输出变为神经网络的风格:比如若输出3→0001000000

labels_train = LabelBinarizer().fit_transform(y_train)  

labels_test = LabelBinarizer().fit_transform(y_test)

print ("start")  

nn.train(X_train, labels_train, epochs=30000)  

print ("over")  


相关文章
|
消息中间件 监控 数据挖掘
NineData:从Kafka到ClickHouse的数据同步解决方案
NineData 提供了强大的数据转换和映射功能,以解决 Kafka 和 ClickHouse 之间的格式和结构差异,确保数据在同步过程中的一致性和准确性。
768 2
NineData:从Kafka到ClickHouse的数据同步解决方案
|
数据采集 数据挖掘
【数据挖掘】利用sklearn进行数据预处理讲解与实战(超详细 附源码)
【数据挖掘】利用sklearn进行数据预处理讲解与实战(超详细 附源码)
420 0
|
关系型数据库 MySQL PHP
LAMP架构及搭建LAMP+Discuz论坛
LAMP架构及搭建LAMP+Discuz论坛
657 0
|
数据采集 存储 SQL
数据中台全景架构及模块解析!一文入门中台架构师!
数据中台全景架构及模块解析!包括数据采集、数据存储、数据开发处理、数据资产管理、数据质量和安全、数据服务。一文入门中台架构师!
|
机器学习/深度学习 算法 PyTorch
YOLO如何入门?
YOLO如何入门?
|
数据安全/隐私保护
OAuth2.0实战案例
OAuth2.0实战案例
309 0
OAuth2.0实战案例
|
Java
Java“NullPointerException”解决
Java中的“NullPointerException”是常见的运行时异常,发生在尝试使用null对象实例的方法或字段时。解决方法包括:1. 检查变量是否被正确初始化;2. 使用Optional类避免null值;3. 增加空指针检查逻辑。
2030 2
|
机器学习/深度学习 算法 数据挖掘
Python数据分析革命:Scikit-learn库,让机器学习模型训练与评估变得简单高效!
在数据驱动时代,Python 以强大的生态系统成为数据科学的首选语言,而 Scikit-learn 则因简洁的 API 和广泛的支持脱颖而出。本文将指导你使用 Scikit-learn 进行机器学习模型的训练与评估。首先通过 `pip install scikit-learn` 安装库,然后利用内置数据集进行数据准备,选择合适的模型(如逻辑回归),并通过交叉验证评估其性能。最终,使用模型对新数据进行预测,简化整个流程。无论你是新手还是专家,Scikit-learn 都能助你一臂之力。
482 8
|
Java
IDEA的fxml打开Scene Builder后空白! Scene Builder下载依赖后还是空白不显示 无论如何都不显示,网上的教程试过来了遍还是不显示
本文提供了三种方法来解决IDEA中fxml文件在Scene Builder中打开后显示空白的问题:检查JavaFX是否安装、切换IDEA版本、下载Scene Builder插件。
909 1
|
人工智能 文字识别 自然语言处理
Nougat:一种用于科学文档OCR的Transformer 模型
随着人工智能领域的不断进步,其子领域,包括自然语言处理,自然语言生成,计算机视觉等,由于其广泛的用例而迅速获得了大量的普及。光学字符识别(OCR)是计算机视觉中一个成熟且被广泛研究的领域。它有许多用途,如文档数字化、手写识别和场景文本识别。数学表达式的识别是OCR在学术研究中受到广泛关注的一个领域。
546 0