训练神经网络的一些实用技巧

简介: 训练神经网络的一些实用技巧

神经网络参数的调节和选取一般都比较玄学,需要有比较丰富的经验才能训练出比较SOTA的网络。下面总结出几个比较常见且实用的训练技巧。

ce41a9a0ef3446a1656523fda98fdd47.jpg

为模型选择正确的最后一层激活和损失函数

batch_size的选择


使用大的batch size有害身体健康。更重要的是,它对测试集的error不利。一个真正的朋友不会让你使用大于32的batch size。直说了吧:2012年来人们开始转而使用更大batch size的原因只是我们的GPU不够强大,处理小于32的batch size时效率太低。这是个糟糕的理由,只说明了我们的硬件还很辣鸡。也就是最好的实验表现都是在batch size处于2~32之间得到的。因为batch_size越小时每次更新时由于没有使用全量数据而仅仅使用batch内数据,从而人为给训练带来了噪声,而这个操作却往往能够带领算法走出局部最优(鞍点)。当模型训练到尾声,想更精细化地提高成绩(比如论文实验/比赛到最后),有一个有用的trick,就是设置batch size为1,即做纯SGD,慢慢把error磨低。

一些技巧


一旦得到了具有统计功效的模型,问题就变成了:模型是否足够强大?它是否具有足够多的层和参数来对问题进行建模?例如,只有单个隐藏层且只有两个单元的网络,在 MNIST 问题上具有统计功效,但并不足以很好地解决问题。请记住,机器学习中无处不在的对立是优化和泛化的对立,理想的模型是刚好在欠拟合和过拟合的界线上,在容量不足和容量过大的界线上。为了找到这条界线,你必须穿过它。要搞清楚你需要多大的模型,就必须开发一个过拟合的模型,这很简单。

  • 添加更多的层。
  • 让每一层变得更大。
  • 训练更多的轮次。

要始终监控训练损失和验证损失,以及你所关心的指标的训练值和验证值。如果你发现模型在验证数据上的性能开始下降,那么就出现了过拟合。下一阶段将开始正则化和调节模型,以便尽可能地接近理想模型,既不过拟合也不欠拟合。

模型正则化与调节超参数


这一步是最费时间的:你将不断地调节模型、训练、在验证数据上评估(这里不是测试数据)、再次调节模型,然后重复这一过程,直到模型达到最佳性能。你应该尝试以下几项:

1)添加 dropout。

2)尝试不同的架构:增加或减少层数。

3)添加 L1 和 / 或 L2 正则化。

4) 尝试不同的超参数(比如每层的单元个数或优化器的学习率),以找到最佳配置。

5)(可选)反复做特征工程:添加新特征或删除没有信息量的特征。

请注意:每次使用验证过程的反馈来调节模型,都会将有关验证过程的信息泄露到模型中。如果只重复几次,那么无关紧要;但如果系统性地迭代许多次,最终会导致模型对验证过程过拟合(即使模型并没有直接在验证数据上训练)。这会降低验证过程的可靠性。

一旦开发出令人满意的模型配置,你就可以在所有可用数据(训练数据 + 验证数据)上训练最终的生产模型,然后在测试集上最后评估一次。如果测试集上的性能比验证集上差很多,那么这可能意味着你的验证流程不可靠,或者你在调节模型参数时在验证数据上出现了过拟合。在这种情况下,你可能需要换用更加可靠的评估方法,比如重复的 K 折验证。

相关文章
|
15小时前
|
机器学习/深度学习 算法 数据挖掘
神经网络训练失败的原因总结 !!
神经网络训练失败的原因总结 !!
46 0
|
15小时前
|
机器学习/深度学习 PyTorch 算法框架/工具
【PyTorch实战演练】AlexNet网络模型构建并使用Cifar10数据集进行批量训练(附代码)
【PyTorch实战演练】AlexNet网络模型构建并使用Cifar10数据集进行批量训练(附代码)
90 0
|
15小时前
|
机器学习/深度学习 PyTorch 算法框架/工具
【PyTorch实战演练】使用Cifar10数据集训练LeNet5网络并实现图像分类(附代码)
【PyTorch实战演练】使用Cifar10数据集训练LeNet5网络并实现图像分类(附代码)
100 0
|
15小时前
|
机器学习/深度学习
深度学习网络训练,Loss出现Nan的解决办法
深度学习网络训练,Loss出现Nan的解决办法
8 0
|
15小时前
|
机器学习/深度学习 并行计算 数据可视化
Batch Size 对神经网络训练的影响
Batch Size 对神经网络训练的影响
13 0
|
15小时前
|
机器学习/深度学习 数据可视化 数据挖掘
R语言深度学习卷积神经网络 (CNN)对 CIFAR 图像进行分类:训练与结果评估可视化
R语言深度学习卷积神经网络 (CNN)对 CIFAR 图像进行分类:训练与结果评估可视化
|
15小时前
|
机器学习/深度学习 算法 数据挖掘
SAS使用鸢尾花(iris)数据集训练人工神经网络(ANN)模型
SAS使用鸢尾花(iris)数据集训练人工神经网络(ANN)模型
|
15小时前
|
机器学习/深度学习 人工智能 自然语言处理
|
15小时前
|
机器学习/深度学习 数据采集 算法
|
15小时前
|
机器学习/深度学习 人工智能 算法
训练神经网络的7个技巧
训练神经网络的7个技巧
44 1

热门文章

最新文章