深度学习网络训练,Loss出现Nan的解决办法

简介: 深度学习网络训练,Loss出现Nan的解决办法

前言

模型的训练不是单纯的调参,重要的是能针对出现的各种问题提出正确的解决方案。本文就训练网络loss出现Nan的原因做了具体分析,并给出了详细的解决方案,希望对大家训练模型有所帮助。


一、原因

一般来说,出现NaN有以下几种情况:

  1. 如果在迭代的100轮数以内,出现NaN,一般情况下的原因是你的学习率过高,需要降低学习率。可以不断降低学习率直至不出现NaN为止,一般来说低于现有学习率1-10倍即可。
  2. 如果当前的网络是类似于RNN的循环神经网络的话,出现NaN可能是因为梯度爆炸的原因,一个有效的方式是增加“gradient clipping”(梯度截断来解决)。
  3. 可能用0作了除数。
  4. 可能用0或者负数作为自然对数。
  5. 需要计算loss的数组越界(尤其是自己定义了一个新的网络,可能出现这种情况)。
  6. 在某些涉及指数计算,可能最后算得值为INF(无穷)(比如不做其他处理的softmax中分子分母需要计算ex(x),值过大,最后可能为INF/INF,得到NaN,此时你要确认你使用的softmax中在计算exp(x) 做了相关处理(比如减去最大值等等))。
  7. 训练深度网络的时候,label缺失问题也会导致loss一直是nan,需要检查label。

二、典型实例

1. 梯度爆炸

原因:梯度变得非常大,使得学习过程难以继续。

现象:观察log,注意每一轮迭代后的loss。loss随着每轮迭代越来越大,最终超过了浮点型表示的范围,就变成了NaN。

措施:

  • 减小solver.prototxt中的base_lr,至少减小一个数量级。如果有多个loss layer,需要找出哪个损失导致了梯度爆炸,并在train_val.prototxt中减小该层的loss_weight,而非是减小通用的base_lr。
  • 设置clip gradient,用于限制过大的diff。

2. 不当的损失函数

原因:有时候损失层中的loss的计算可能导致NaN的出现。比如,给InfogainLoss层(信息熵损失)输入没有归一化的值,使用带有bug的自定义损失层等等。

现象:观测训练产生的log时一开始并不能看到异常,loss也在逐步的降低,但突然之间NaN就出现了。

措施:看看你是否能重现这个错误,在loss layer中加入一些输出以进行调试。

3. 不当的输入

原因:输入中就含有NaN。

现象:每当学习的过程中碰到这个错误的输入,就会变成NaN。观察log的时候也许不能察觉任何异常,loss逐步的降低,但突然间就变成NaN了。

措施:重整你的数据集,确保训练集和验证集里面没有损坏的图片。调试中你可以使用一个简单的网络来读取输入层,有一个缺省的loss,并过一遍所有输入,如果其中有错误的输入,这个缺省的层也会产生NaN。

参考:https://zhuanlan.zhihu.com/p/599887666

目录
相关文章
|
1月前
|
机器学习/深度学习 数据可视化 算法
PyTorch生态系统中的连续深度学习:使用Torchdyn实现连续时间神经网络
神经常微分方程(Neural ODEs)是深度学习领域的创新模型,将神经网络的离散变换扩展为连续时间动力系统。本文基于Torchdyn库介绍Neural ODE的实现与训练方法,涵盖数据集构建、模型构建、基于PyTorch Lightning的训练及实验结果可视化等内容。Torchdyn支持多种数值求解算法和高级特性,适用于生成模型、时间序列分析等领域。
202 77
PyTorch生态系统中的连续深度学习:使用Torchdyn实现连续时间神经网络
|
9天前
|
机器学习/深度学习 数据采集 算法
基于MobileNet深度学习网络的MQAM调制类型识别matlab仿真
本项目基于Matlab2022a实现MQAM调制类型识别,使用MobileNet深度学习网络。完整程序运行效果无水印,核心代码含详细中文注释和操作视频。MQAM调制在无线通信中至关重要,MobileNet以其轻量化、高效性适合资源受限环境。通过数据预处理、网络训练与优化,确保高识别准确率并降低计算复杂度,为频谱监测、信号解调等提供支持。
|
11天前
|
机器学习/深度学习 人工智能 算法
基于Python深度学习的【害虫识别】系统~卷积神经网络+TensorFlow+图像识别+人工智能
害虫识别系统,本系统使用Python作为主要开发语言,基于TensorFlow搭建卷积神经网络算法,并收集了12种常见的害虫种类数据集【"蚂蚁(ants)", "蜜蜂(bees)", "甲虫(beetle)", "毛虫(catterpillar)", "蚯蚓(earthworms)", "蜚蠊(earwig)", "蚱蜢(grasshopper)", "飞蛾(moth)", "鼻涕虫(slug)", "蜗牛(snail)", "黄蜂(wasp)", "象鼻虫(weevil)"】 再使用通过搭建的算法模型对数据集进行训练得到一个识别精度较高的模型,然后保存为为本地h5格式文件。最后使用Djan
50 1
基于Python深度学习的【害虫识别】系统~卷积神经网络+TensorFlow+图像识别+人工智能
|
3天前
|
机器学习/深度学习 存储 算法
基于MobileNet深度学习网络的活体人脸识别检测算法matlab仿真
本内容主要介绍一种基于MobileNet深度学习网络的活体人脸识别检测技术及MQAM调制类型识别方法。完整程序运行效果无水印,需使用Matlab2022a版本。核心代码包含详细中文注释与操作视频。理论概述中提到,传统人脸识别易受非活体攻击影响,而MobileNet通过轻量化的深度可分离卷积结构,在保证准确性的同时提升检测效率。活体人脸与非活体在纹理和光照上存在显著差异,MobileNet可有效提取人脸高级特征,为无线通信领域提供先进的调制类型识别方案。
|
1月前
|
机器学习/深度学习 人工智能 算法
基于Python深度学习的【蘑菇识别】系统~卷积神经网络+TensorFlow+图像识别+人工智能
蘑菇识别系统,本系统使用Python作为主要开发语言,基于TensorFlow搭建卷积神经网络算法,并收集了9种常见的蘑菇种类数据集【"香菇(Agaricus)", "毒鹅膏菌(Amanita)", "牛肝菌(Boletus)", "网状菌(Cortinarius)", "毒镰孢(Entoloma)", "湿孢菌(Hygrocybe)", "乳菇(Lactarius)", "红菇(Russula)", "松茸(Suillus)"】 再使用通过搭建的算法模型对数据集进行训练得到一个识别精度较高的模型,然后保存为为本地h5格式文件。最后使用Django框架搭建了一个Web网页平台可视化操作界面,
102 11
基于Python深度学习的【蘑菇识别】系统~卷积神经网络+TensorFlow+图像识别+人工智能
|
1月前
|
机器学习/深度学习 文件存储 异构计算
YOLOv11改进策略【模型轻量化】| 替换骨干网络为EfficientNet v2,加速训练,快速收敛
YOLOv11改进策略【模型轻量化】| 替换骨干网络为EfficientNet v2,加速训练,快速收敛
114 18
YOLOv11改进策略【模型轻量化】| 替换骨干网络为EfficientNet v2,加速训练,快速收敛
|
19天前
|
机器学习/深度学习 数据可视化 API
DeepSeek生成对抗网络(GAN)的训练与应用
生成对抗网络(GANs)是深度学习的重要技术,能生成逼真的图像、音频和文本数据。通过生成器和判别器的对抗训练,GANs实现高质量数据生成。DeepSeek提供强大工具和API,简化GAN的训练与应用。本文介绍如何使用DeepSeek构建、训练GAN,并通过代码示例帮助掌握相关技巧,涵盖模型定义、训练过程及图像生成等环节。
|
1月前
|
机器学习/深度学习 文件存储 异构计算
RT-DETR改进策略【模型轻量化】| 替换骨干网络为EfficientNet v2,加速训练,快速收敛
RT-DETR改进策略【模型轻量化】| 替换骨干网络为EfficientNet v2,加速训练,快速收敛
39 1
|
2月前
|
机器学习/深度学习 监控 算法
基于yolov4深度学习网络的排队人数统计系统matlab仿真,带GUI界面
本项目基于YOLOv4深度学习网络,利用MATLAB 2022a实现排队人数统计的算法仿真。通过先进的计算机视觉技术,系统能自动、准确地检测和统计监控画面中的人数,适用于银行、车站等场景,优化资源分配和服务管理。核心程序包含多个回调函数,用于处理用户输入及界面交互,确保系统的高效运行。仿真结果无水印,操作步骤详见配套视频。
70 18
|
2月前
|
机器学习/深度学习 运维 安全
深度学习在安全事件检测中的应用:守护数字世界的利器
深度学习在安全事件检测中的应用:守护数字世界的利器
105 22

热门文章

最新文章