手写数字识别基本思路

简介: 手写数字识别基本思路

问题

什么是MNIST?如何使用Pytorch实现手写数字识别?如何进行手写数字对模型进行检验?


方法

mnist数据集

MNIST数据集是美国国家标准与技术研究院收集整理的大型手写数字数据集,包含了60,000个样本的训练集以及10,000个样本的测试集。


使用Pytorch实现手写数字识别

1.进行数据预处理对于MNIST数据集,可以通过torchvision中的datasets进行下载。

root (string):表示数据集的根目录,其中根目录存在MNIST/processed/training.pt和MNIST/processed/test.pt的子目录。

train (bool, optional):如果为True,则从training.pt创建数据集,否则从test.pt创建数据集。

download (bool, optional):如果为True,则从internet下载数据集并将其放入根目录。如果数据集已下载,则不会再次下载。

transform (callable, optional):接收PIL图片并返回转换后版本图片的转换函数。

bat = 128
transform = transforms.Compose([
   transforms.ToTensor(),
   transforms.Normalize(0.1307, 0.3081)  # (均值,方差)
])  # Compoes 两个操作合为一个
train_ds = datasets.MNIST(root='data', download=False, train=True,
                         transform=transform)
train_ds, val_ds = torch.utils.data.random_split(train_ds, [50000, 10000])
test_ds = datasets.MNIST(root='data', download=True, train=False,
                        transform=transform)
train_loader = DataLoader(dataset=train_ds, batch_size=bat, shuffle=True)
val_loader = DataLoader(dataset=val_ds, batch_size=bat)
test_loader = DataLoader(dataset=test_ds, batch_size=bat)


2.构建模型

class MyNet(nn.Module):
   def __init__(self) -> None:
       super().__init__()

       self.flatten = nn.Flatten()  # 将28*28的图像拉伸为784维向量
       # 第一个全连接层Full Connection(FC)
       self.fc1 = nn.Linear(in_features=784,
                            out_features=256)
       self.fc2 = nn.Linear(in_features=256,
                            out_features=128)
       self.fc3 = nn.Linear(in_features=128,
                            out_features=10)

   def forward(self, x):
       x = self.flatten(x)
       x = torch.relu(self.fc1(x))
       x = torch.relu(self.fc2(x))
       out = torch.relu(self.fc3(x))
       return out

构建一个三层的神经网络MNIST数据集中的图片都是28×28大小的,而且是灰度图。而全连接神经网络的输入要是一个行向量,所以我们要把28×28的矩阵转换成28×28=764的行向量,作为神经网络的输入


3.优化器的选择,参数设置

使用优化器和损失函数。优化器选择SGD,SGD随机梯度下降,lr学习率取值0.2最优,momentum用于加速SGD在某一方向上的搜索以及抑制震荡的发生。

optimizer=torch.optim.SGD(net.parameters(),lr=0.2)#lr学习率,momentum用于加速SGD在某一方向上的搜索以及抑制震荡的发生
#损失函数
#衡量yy_hat之间的差异
loss_fn=nn.CrossEntropyLoss()


4.对模型进行训练测试,网络的输入,输入尺寸B*C*H*W B是batch,一个batch一个batch交给网络处理,x=torch.rand(size=(128,1,28,28)),基于loss信息利用优化器从后向前更新网络全部参数。

def train(dataloader, net, loss_fn, optimizer, epoch):
   size = len(dataloader.dataset)
   corrent = 0
   epoch_loss = 0.0
   batch_num = len(dataloader)
   net.train()

   # 一个batch一个batch的训练网络
   for batch_idx, (X, y) in enumerate(dataloader):
       pred = net(X)

       # 衡量y与y_hat之间的loss
       # y:128, pred:128x10 CrossEntropyloss
       loss = loss_fn(pred, y)

       # 基于loss信息利用优化器从后向前更新网络全部参数 <---
       optimizer.zero_grad()
       loss.backward()
       optimizer.step()
       epoch_loss += loss.item()
       corrent += (pred.argmax(1) == y).type(torch.float).sum().item()
       if batch_idx % 100 == 0:
           # f-string
           print(f'[{batch_idx + 1:>5d}/{batch_num + 1:>5d}],loss:{loss.item()}')
   avg_loss = epoch_loss / batch_num
   avg_accuracy = corrent / size
   # loss_list.append(avg_loss)
   return avg_accuracy, avg_loss
def test(dataloader, net, loss_fn):
   size = len(dataloader.dataset)
   batch_num = len(dataloader)
   corrent = 0
   losses = 0
   net.eval()
   with torch.no_grad():
       for X, y in test_loader:
           pred = net(X)
           correct = (pred.argmax(1) == y).type(torch.int).sum().item()
           # print(y.size(0))
           # print(correct)
           corrent += correct
   accuracy = corrent / size
   avg_loss = losses / batch_num
   return accuracy, avg_loss



5.保存最优的模型

net.load_state_dict(torch.load('model_best.pth'))
   test(test_loader,net,loss_fn)


6.读入自己的写入数字,进行识别

model = MyNet()
model.load_state_dict(torch.load('model_best.pth'))
img = Image.open("7.png").convert("L")  # 转为灰度图像
img = transform(img)
# img = np.array(img)
# print(img)
result = model(img)
_, predict = torch.max(result.data, dim=1)
print(result)
print("the result is:",predict.item())


结语

minist是一个28*28的图像,所以输入就是28*28=784的维度,输出为10,0-9十个数字。手写数字识别首先需要初始化全局变量,构建数据集。然后构建模型,构建迭代器与损失函数,进行训练测试。最后可以将训练的模型进行保存,通过读取自己写的数字进行识别验证,完成一个简单的深度学习。

目录
相关文章
|
8月前
|
机器学习/深度学习 自然语言处理 算法
【模式识别】探秘判别奥秘:Fisher线性判别算法的解密与实战
【模式识别】探秘判别奥秘:Fisher线性判别算法的解密与实战
169 0
|
机器学习/深度学习 人工智能 异构计算
知识蒸馏的基本思路
知识蒸馏(Knowledge Distillation)是一种模型压缩方法,在人工智能领域有广泛应用。目前深度学习模型在训练过程中对硬件资源要求较高,例如采用GPU、TPU等硬件进行训练加速。但在模型部署阶段,对于复杂的深度学习模型,要想达到较快的推理速度,部署的硬件成本很高,在边缘终端上特别明显。而知识蒸馏利用较复杂的预训练教师模型,指导轻量级的学生模型训练,将教师模型的知识传递给学生网络,实现模型压缩,减少对部署平台的硬件要求,可提高模型的推理速度。
564 0
|
机器学习/深度学习 算法 Python
|
6月前
|
机器学习/深度学习 数据采集 算法
Python实现GBDT(梯度提升树)回归模型(GradientBoostingRegressor算法)项目实战
Python实现GBDT(梯度提升树)回归模型(GradientBoostingRegressor算法)项目实战
|
6月前
|
算法 Python
决策树算法详细介绍原理和实现
决策树算法详细介绍原理和实现
|
机器学习/深度学习 算法 数据可视化
决策树算法的原理是什么样的?
决策树算法的原理是什么样的?
235 0
决策树算法的原理是什么样的?
|
机器学习/深度学习 数据采集 资源调度
【机器学习】朴素贝叶斯分类器原理(理论+图解)
【机器学习】朴素贝叶斯分类器原理(理论+图解)
196 0
|
机器学习/深度学习 人工智能 算法
【机器学习】线性分类——感知机算法(理论+图解+公式推导)
【机器学习】线性分类——感知机算法(理论+图解+公式推导)
354 0
【机器学习】线性分类——感知机算法(理论+图解+公式推导)
|
机器学习/深度学习 人工智能 算法
【机器学习】集成学习(Boosting)——提升树算法(BDT)(理论+图解+公式推导)
【机器学习】集成学习(Boosting)——提升树算法(BDT)(理论+图解+公式推导)
297 0
【机器学习】集成学习(Boosting)——提升树算法(BDT)(理论+图解+公式推导)
|
机器学习/深度学习 人工智能 移动开发
【机器学习】线性分类——高斯判别分析GDA(理论+图解+公式推导)
【机器学习】线性分类——高斯判别分析GDA(理论+图解+公式推导)
465 0
【机器学习】线性分类——高斯判别分析GDA(理论+图解+公式推导)