如何让训练神经网络不无聊?试试迁移学习和多任务学习

简介:

训练深度神经网络是一个乏味的过程。更实际的方法,如重新使用训练好的网络解决其他任务,或针对许多任务使用相同的网络。这篇文章中,我们会讨论两个重要的方法:迁移学习和多任务学习。

迁移学习

在迁移学习中,我们希望利用源任务学到的知识帮助学习目标任务。例如,一个训练好的图像分类网络能够被用于另一个图像相关的任务。再比如,一个网络在仿真环境学习的知识可以被迁移到真实环境的网络。

总的来说,神经网络迁移学习有两种方案:特征提取和微调。迁移学习一个典型的例子就是载入训练好VGG网络,这个大规模分类网络能将图像分到1000个类别,然后把这个网络用于另一个任务,如医学图像分类。

如何让训练神经网络不无聊?试试迁移学习和多任务学习

1特征提取:

特征提取是针对目标任务把一个简单的分类器加在源任务上预训练的网络上,将预训练的网络作为特征提取器。仅有添加的分类器的参数需要更新,预训练的网络的参数不变。这能使新任务从源任务中学习到的特征中受益。但是,这些特征更加适合源任务。

2微调:

微调允许学习目标任务时修改预训练的网络参数。通常,在预训练的网络之上加一个新的随机初始化的层。预训练网络的参数使用很小的学习率更新防止大的改变。通常会冻结网络底层的参数,这些层学到更通用的特征,微调顶部的层,这些层学到更具体的特征。同时,冻结一些层能够减少需要训练的参数的数量,避免过拟合问题,尤其时在目标任务数据量不够大的情况下。实践中,微调胜过特征提取因为他针对新的任务优化了预训练的网络。

迁移学习的基本情形:

迁移学习可以分为4种情形基于以下两个因素:1)目标任务数据集的大小,2)源任务与目标任务的相似度:

情形1:目标数据集很小,目标任务与源任务相似:这种情况使用特征提取,因为目标数据集小容易造成过拟合。

情形2:目标数据集很小,目标任务与源任务不同:这时我们微调底层网络,并移除高层网络。换句话说,我们使用较早的特征提取。

情形3:目标数据集很大,目标任务与源任务相似:我们有了大量的数据,我们可以随机初始化参数,从头开始训练网络。然而,最好还是使用预训练的网络初始化参数并微调几层。

情形4:目标数据集很大,目标任务与源任务不同。这时,我们微调大部分层甚至整个网络。

多任务学习

多任务学习的主要目标是通过使用多个任务的样本优化网络的参数改进任务的性能。例如,我们希望有一个网络可以根据输入的脸部图像区分是男性还是女性,同时可以预测这个人的年龄。这时,我们有两个相关的任务,一个是二分类,一个是回归任务。显然两个任务是相关的,对一个任务的学习可以改进另外一个任务。

如何让训练神经网络不无聊?试试迁移学习和多任务学习

一个简单的网络设计实例,可以在任务和任务之间共享一部分网络。共享部分学习任务通用的中间表达,有助于这些共同的学习任务。另一方面,针对特定的学习任务,特定的头部会学习如何使用这些共享表达。


原文发布时间为:2018-05-29

本文来自云栖社区合作伙伴“雷锋网”,了解相关信息可以关注“雷锋网”。

相关文章
|
2月前
|
机器学习/深度学习
神经网络与深度学习---验证集(测试集)准确率高于训练集准确率的原因
本文分析了神经网络中验证集(测试集)准确率高于训练集准确率的四个可能原因,包括数据集大小和分布不均、模型正则化过度、批处理后准确率计算时机不同,以及训练集预处理过度导致分布变化。
|
1月前
|
监控 网络协议 Linux
网络学习
网络学习
133 68
|
11天前
|
网络协议 网络架构
网络协议介绍与学习
网络协议介绍与学习
28 4
|
11天前
|
网络协议 网络安全 数据安全/隐私保护
网络基础知识学习
如果你打算深入学习网络技术,建议从上述基础知识入手,并逐渐扩展到更高级的主题,如网络编程、网络安全、网络管理等。同时,实践是学习网络技术的关键,可以通过搭建自己的小型网络环境来进行实验和探索。
13 2
|
1月前
|
机器学习/深度学习 数据采集 数据可视化
深度学习实践:构建并训练卷积神经网络(CNN)对CIFAR-10数据集进行分类
本文详细介绍如何使用PyTorch构建并训练卷积神经网络(CNN)对CIFAR-10数据集进行图像分类。从数据预处理、模型定义到训练过程及结果可视化,文章全面展示了深度学习项目的全流程。通过实际操作,读者可以深入了解CNN在图像分类任务中的应用,并掌握PyTorch的基本使用方法。希望本文为您的深度学习项目提供有价值的参考与启示。
|
1月前
|
网络协议 安全 网络安全
网络基础知识学习
【9月更文挑战第1天】
48 0
|
2月前
|
机器学习/深度学习
|
1月前
|
安全 Linux 网络安全
网络安全学习
【9月更文挑战第1天】
54 0
|
2月前
|
安全 Apache 数据安全/隐私保护
你的Wicket应用安全吗?揭秘在Apache Wicket中实现坚不可摧的安全认证策略
【8月更文挑战第31天】在当前的网络环境中,安全性是任何应用程序的关键考量。Apache Wicket 是一个强大的 Java Web 框架,提供了丰富的工具和组件,帮助开发者构建安全的 Web 应用程序。本文介绍了如何在 Wicket 中实现安全认证,
36 0
|
2月前
|
机器学习/深度学习 数据采集 TensorFlow
从零到精通:TensorFlow与卷积神经网络(CNN)助你成为图像识别高手的终极指南——深入浅出教你搭建首个猫狗分类器,附带实战代码与训练技巧揭秘
【8月更文挑战第31天】本文通过杂文形式介绍了如何利用 TensorFlow 和卷积神经网络(CNN)构建图像识别系统,详细演示了从数据准备、模型构建到训练与评估的全过程。通过具体示例代码,展示了使用 Keras API 训练猫狗分类器的步骤,旨在帮助读者掌握图像识别的核心技术。此外,还探讨了图像识别在物体检测、语义分割等领域的广泛应用前景。
13 0

热门文章

最新文章