训练模型

简介: 【8月更文挑战第1天】

在Python目录下新建NLP目录,在NLP目录下新建speech目录,在speech目录下新建speech_commands项目目录并将源码下载到该项目目录下。
在开始训练过程之前需要获取数据集。有两种方式可获取数据集:一种是手动下载语音指令数据集,下载后的数据集解压后需要放在speech\speech_commands\tmp下;另一种是运行训练程序后脚本会自动下载该数据集。语音指令数据集中包括超过105000个WAVE音频文件,音频内容是30个不同的字词,这些数据由Google收集,并依据知识共享许可协议(Creative Commons License,CC协议)发布,读者可以提交5min的录音来帮助改进该数据集。
运行speech_commands目录下的train.py,在训练过程中会打印出训练日志,如图所示。

打印出的训练日志
以第一步的打印日志为例:
训练步Step #1,共设置了18000个训练步,可以通过该信息观察训练进程;学习率rate 0.001000,刚开始学习率比较大,为0.001,训练后期会减小到0.000 1;准确率accuracy 5.0%,表示训练在本步中预测正确的类别数量,该值通常会有较大的波动,但会随着训练的进行总体有所提高,准确率的值会在0%~100%之间波动,始终不会超过100%;一般而言,这个值越高,训练出来的模型越好;损失函数的值cross entropy 2.560931,表示训练过程的损失函数的结果,它是一个得分,通过将当前训练运行的得分向量与正确标签进行比较计算得出,该得分应在训练期间呈下滑趋势,但是其并不一定是平滑地下滑。
在训练过程中,每100步会保存一次模型,第100步打印的日志如图所示。

第100步打印的日志
由于训练过程比较久,如果是CPU的TensorFlow训练,需要十几个小时,所以可能一次无法完成训练。当中途有其他事情的时候可以先中止训练,下次训练时先检查上次训练到哪一步了,查找上次保存的检查点,然后将“--start_checkpoint=tmp/speech_commands_train/conv.ckpt-100”用作命令行参数重启该脚本,从该点继续训练,命令行参数中的100是上次中止训练的最近的步数,同时也包含在检查点文件名中。
每训练400步,会生成混淆矩阵(Confusion Matrix)。在监督学习中,混淆矩阵可作为可视化工具,在无监督学习中其一般被称为匹配矩阵。混淆矩阵是通过将实际的分类与预测分类相比较计算出来的。
由于该语音识别项目中生成的混淆矩阵相对比较复杂,所以这里先给出一个实例来进行概念理解,如现有一个猫狗分类的二分类问题,有10只猫和8只狗,在预测之后计算出的混淆矩阵如表所示。
预测之后计算出的混淆矩阵
混淆矩阵 预测值
猫 狗
真实标签 猫 7 3
狗 0 8

由上表可知,混淆矩阵中数字的每一行都是真实标签(即第一行的7和3,第二行的0和8),数字的每一列是预测值(即第一列的7和0,第二列的3和8)。实际上猫的数量为7+3=10,狗的数量为0+8=8,但是预测的结果中,猫的数量为7+0=7,狗的数量为3+8=11。在每次输出混淆矩阵之后会打印该模型在验证集上的准确率(Validation accuracy),该准确率是由混淆矩阵从左上到右下的对角线上值的和除以总体数据量N得到的,此时的预测准确率可以计算为(7+8)÷18×100%≈83.3%。在理想情况下,该准确率应该接近于训练准确率,如果训练准确率有所提高,但验证准确率没有提高,则表明存在过拟合,模型只学习了有关训练数据集的信息,而在验证或者预测数据集上表现不佳,所以混淆矩阵也可以看作预测准确率的可视化。
下图是训练到400步时生成的混淆矩阵。

训练到400步时生成的混淆矩阵
在语音识别项目中,标签分别为“silence”“unknown”“yes”“no”“up”“down”“left”“right”“on”“off”“stop”“go”。
每一行都代表真实的标签,如第一行是“silence”(无声的)的所有音频片段,第二行是“unknown”(未知字词的)的所有音频片段,第三行是“yes”的所有音频片段,以此类推。
每一列都代表预测的值,第一列代表预测为“silence”(无声的)的所有音频片段,第二列代表预测为“unknown”(未知字词的)的所有音频片段,第三列代表预测为“yes”的所有音频片段,以此类推。
训练到18000步时的混淆矩阵如图所示,可以看出它与训练到400步时混淆矩阵的区别。

训练到18000步时的混淆矩阵
所以,完美模型生成的混淆矩阵,除了从左上到右下的对角线上的条目外,所有其他条目几乎都接近0。混淆矩阵有助于了解模型最容易在哪些方面混淆,确定问题所在后,可以通过添加更多数据或清理类别来解决问题。所以,混淆矩阵要比直接打印准确率和损失函数的值更能够准确地找到网络的问题。
在整个训练过程中,使用TensorBoard可以很好地观察训练进度。默认情况下,脚本会将事件保存到tmp/retrain_logs,可以在命令行运行以下命令:
tensorboard --logdir tmp/retrain_logs
模型进度的图表如图所示。

模型进度的图表
训练完成后,准确率介于85%~90%之间。在训练完成后进行模型转换,在命令行运行以下命令:
python speech_commands/freeze.py \
--start_checkpoint=speech_commands/tmp/speech_commands_train/conv.ckpt-18000 \
--output_file=speech_commands/tmp/my_frozen_graph.pb
运行以上命令会在tmp目录下生成my_frozen_graph.pb模型文件。

相关实践学习
通过日志服务实现云资源OSS的安全审计
本实验介绍如何通过日志服务实现云资源OSS的安全审计。
相关文章
|
4月前
|
存储 运维 监控
120_检查点管理:故障恢复 - 实现分布式保存机制
在大型语言模型(LLM)的训练过程中,检查点管理是确保训练稳定性和可靠性的关键环节。2025年,随着模型规模的不断扩大,从百亿参数到千亿参数,训练时间通常长达数周甚至数月,硬件故障、软件错误或网络中断等问题随时可能发生。有效的检查点管理机制不仅能够在故障发生时快速恢复训练,还能优化存储使用、提高训练效率,并支持实验管理和模型版本控制。
120_检查点管理:故障恢复 - 实现分布式保存机制
|
5月前
|
人工智能 文字识别 运维
AR眼镜在巡检业务中的软件架构设计|阿法龙XR云平台
引入AR眼镜与AI融合的巡检方案,构建“端-边-云”协同架构,实现工单可视化、AR叠加数据、智能识别表计与异常、远程协作及自动报告生成,提升工业巡检效率与智能化水平。
|
存储 分布式计算 大数据
大数据 优化数据读取
【11月更文挑战第4天】
355 2
|
11月前
|
SQL 关系型数据库 MySQL
|
12月前
|
机器学习/深度学习 存储 人工智能
《LSTM与HMM:序列建模领域的双雄对决》
长短期记忆网络(LSTM)和隐马尔可夫模型(HMM)是序列建模中的重要工具。两者都能处理序列数据并基于概率预测,且都使用状态概念建模。然而,LSTM通过门控机制捕捉复杂长期依赖,适用于长序列任务;HMM基于马尔可夫假设,适合短期依赖关系。LSTM训练复杂、适应性强但解释性差,而HMM训练简单、解释性好,适用于离散数据。两者在不同场景中各有优势。
281 7
|
数据安全/隐私保护
3分钟部署 七日杀(7DaysToDie)联机服务
通过计算巢快速部署 七日杀 联机服务。
3分钟部署 七日杀(7DaysToDie)联机服务
|
NoSQL Java Linux
CentOS7下部署Graylog开源日志管理系统
CentOS7下部署Graylog开源日志管理系统
1369 0
CentOS7下部署Graylog开源日志管理系统
|
C++ 存储 安全
C++中`std::function`和`std::bind`的详细解析
C++中`std::function`和`std::bind`的详细解析
497 0
C++中`std::function`和`std::bind`的详细解析
|
Ubuntu Unix Linux
合肥中科深谷嵌入式项目实战——基于ARM语音识别的智能家居系统(一)
合肥中科深谷嵌入式项目实战——基于ARM语音识别的智能家居系统(一)
|
JSON API 语音技术
Android语音识别(本地+第三方)
Android语音识别(本地+第三方)
1197 0
Android语音识别(本地+第三方)