利用未标记数据的半监督学习在模型训练中的效果评估

本文涉及的产品
实时数仓Hologres,5000CU*H 100GB 3个月
实时计算 Flink 版,5000CU*H 3个月
检索分析服务 Elasticsearch 版,2核4GB开发者规格 1个月
简介: 本文将介绍三种适用于不同类型数据和任务的半监督学习方法。我们还将在一个实际数据集上评估这些方法的性能,并与仅使用标记数据的基准进行比较。

数据科学家在实践中经常面临的一个关键挑战是缺乏足够的标记数据来训练可靠且准确的模型。标记数据对于监督学习任务(如分类或回归)至关重要。但是在许多领域,获取标记数据往往成本高昂、耗时或不切实际。相比之下,未标记数据通常较易获取,但无法直接用于模型训练。

如何利用未标记数据来改进监督学习模型?这正是半监督学习的应用场景。半监督学习是机器学习的一个分支,它结合标记和未标记数据来训练模型,旨在获得比仅使用标记数据更优的性能。半监督学习的基本原理是,未标记数据可以提供关于数据底层结构、分布和多样性的有用信息,从而帮助模型更好地泛化到新的和未见过的样本。

本文将介绍三种适用于不同类型数据和任务的半监督学习方法。我们还将在一个实际数据集上评估这些方法的性能,并与仅使用标记数据的基准进行比较。

半监督学习概述

半监督学习是一种利用标记和未标记数据来训练模型的机器学习方法。标记数据是指具有已知输出或目标变量的样本,例如分类任务中的类别标签或回归任务中的数值。未标记数据则是没有已知输出或目标变量的样本。半监督学习的优势在于它可以利用现实问题中通常大量存在的未标记数据,同时也充分利用通常较少且获取成本较高的标记数据。

利用未标记数据训练监督学习模型的核心思想是通过监督或无监督学习方法为这些数据生成标签。尽管这些生成的标签可能不如实际标签准确,但大量这样的数据仍可以显著提高监督学习方法的性能,相比于仅在有限的标记数据上训练模型。

scikit-learn库提供了三种半监督学习方法:

  1. 自训练(Self-training):首先在标记数据上训练分类器,用于预测未标记数据的标签。在后续迭代中,另一个分类器在标记数据和高置信度的未标记数据预测结果上进行训练。此过程重复进行,直到没有新的高置信度标签被预测或达到最大迭代次数。
  2. 标签传播(Label Propagation):构建一个图结构,其中节点表示数据点,边表示它们之间的相似性。标签通过图结构迭代传播,使算法能够基于未标记数据点与标记数据的连接关系为其分配标签。
  3. 标签扩散(Label Spreading):采用与标签传播类似的概念。不同之处在于标签扩散使用软分配策略,即标签根据数据点之间的相似性进行迭代更新。此方法还可能"覆盖"标记数据集中的原始标签。

为评估这些方法,本文使用了糖尿病预测数据集,其中包含患者的各种特征数据,如年龄和BMI,以及表示患者是否患有糖尿病的标签。该数据集共包含100,000条记录,我们将其随机划分为80,000条训练数据、10,000条验证数据和10,000条测试数据。为分析学习方法对标记数据数量的敏感性,还将训练数据进一步划分为标记集和未标记集,其中标记数据的数量作为一个可变参数。

数据集划分示意图

我们使用验证数据集来评估不同的参数设置,并使用测试数据集来评估参数调优后各方法的最终性能。选择XGBoost作为预测模型,并使用F1分数作为性能评估指标。

基准模型

为了比较自学习算法与不使用未标记数据的情况,我们首先建立一个基准模型。在不同大小的标记数据集上训练XGBoost模型,并在验证数据集上计算F1分数:

基准模型F1分数

结果显示,当训练样本少于100个时,F1分数相对较低。随着样本量增加到1,000,F1分数稳步提升至约79%。继续增加样本量后,F1分数的提升幅度变得微小。

自学习方法

自学习(Self-training)是一种迭代过程,用于为未标记数据生成标签,并在后续迭代中将这些生成的标签用于模型训练。选择预测结果作为下一次迭代的标记数据主要有两种策略:

  1. 阈值法(默认):选择所有置信度高于预设阈值的预测结果。
  2. K最佳法:选择置信度最高的K个预测结果。

我们评估了以下三种配置:

  • ST默认:使用默认参数的自学习
  • ST阈值调优:基于验证数据集调整阈值的自学习
  • ST KB调优:基于验证数据集调整K值的自学习

这些模型的性能在测试数据集上进行了评估,结果如下图所示:

自学习方法性能比较

分析结果显示:

  • 对于小样本量(<100),默认参数配置(红线)的表现不如基准模型(蓝线)。
  • 在较大样本量下,自学习方法略优于基准模型。
  • 阈值调优(绿线)带来了显著的性能提升。例如,在标记数据量为200时,基准模型的F1分数为57%,而使用调优阈值的自学习算法达到了70%。
  • K最佳法调优(紫线)的性能与基准模型相近,仅在标记数据量为30时出现例外。

标签传播方法

标签传播算法内置了两种核函数:RBF(径向基函数)和KNN(K近邻)。RBF核使用密集矩阵生成完全连接的图,这对大型数据集而言计算成本高且内存密集。考虑到内存限制,所以我们对RBF核的训练规模上限设为3,000个样本。KNN核利用更节省内存的稀疏矩阵表示,使我们能够在全部训练数据(最多80,000个样本)上进行拟合。下图比较了这两种核函数方法的性能:

标签传播方法性能比较

图中展示了不同标签传播配置在测试数据集上的F1分数,作为标记数据量的函数。主要观察结果如下:

  • 蓝线代表基准模型,与自学习实验中的基准相同。
  • 红线表示使用默认参数的标签传播,其性能在所有标记数据量下均明显低于基准模型。
  • 绿线代表使用RBF核和经过调优的gamma参数的标签传播。Gamma参数定义了单个训练样本的影响范围。调优后的RBF核在小样本量(<=100)时优于基准模型,但在较大样本量时表现较差。
  • 紫线代表使用KNN核和经过调优的K参数的标签传播,K确定了要使用的最近邻数量。KNN核的整体表现与RBF核相似。

标签扩散方法

标签扩散是一种与标签传播相似的方法,但引入了一个额外的参数alpha,用于控制实例采纳邻居信息的程度。Alpha取值范围为0到1,其中0表示实例完全保持其原始标签,1表示完全采纳邻居的标签。我们对标签扩散的RBF和KNN核方法都进行了参数调优。标签扩散的性能结果如下图所示:

标签扩散方法性能比较

分析结果显示:

  • 标签扩散的整体性能趋势与标签传播非常相似,但存在一个显著的例外。
  • 标签扩散的RBF核方法在所有标记数据量下的测试分数均低于基准模型,而不仅仅是在小样本量情况下。
  • 这一现象表明,对于本数据集,邻居标签的"覆写"效果可能产生了负面影响。这可能是由于数据集中异常值或噪声标签较少导致的。
  • 相比之下,KNN核方法似乎不受alpha参数的显著影响。这暗示alpha参数主要与RBF核方法相关。

综合比较

为了全面评估各种半监督学习方法的性能,我们将所有使用最优参数配置的方法进行了对比。下图展示了这一综合比较的结果:

各方法最优性能比较

图中展示了不同半监督学习方法在测试数据集上的F1分数,作为标记数据量的函数。主要观察结果如下:

  1. 自训练方法(Self-training)在大多数情况下表现优于基准模型。这表明该方法能够有效利用未标记数据来提升模型性能。
  2. 标签传播(Label Propagation)和标签扩散(Label Spreading)方法仅在标记数据量较小时优于基准模型。随着标记数据量的增加,这两种方法的性能相对下降。
  3. 在标记数据量较大时,基准模型的性能接近或超过了半监督学习方法。这说明当有足够的标记数据时,传统的监督学习方法可能已经足够有效。

研究结论

基于本研究的实验结果得出以下主要结论:

  1. 方法有效性因数据集而异:半监督学习方法的性能可能会因不同的数据集、分类器算法和评估指标而显著变化。因此不应将本研究的发现直接推广到其他应用场景,而应在具体应用中进行适当的测试和验证。
  2. 参数调优的重要性:参数调优对于显著提高半监督学习方法的性能至关重要。例如,经过优化的自训练方法在各种标记数据量下均优于基准模型,F1分数最高提升可达13个百分点。
  3. 方法选择的权衡:标签传播和标签扩散方法仅在非常小的样本量时表现出性能优势。使用这些方法时需要格外谨慎,以避免在某些情况下获得比不使用半监督学习更差的结果。
  4. 自训练方法的稳健性:在文中自训练方法展现出了较好的稳健性和性能提升,特别是在中等规模的标记数据集上。
  5. 计算资源考虑:RBF核方法在大规模数据集上可能面临计算资源限制,而KNN核方法在这方面表现更为灵活。

实际应用建议

我们为实际应用中的半监督学习提出以下建议:

  1. 方法选择:在标记数据有限的情况下,优先考虑使用自训练方法。对于极小规模的标记数据集,可以尝试标签传播或标签扩散方法。
  2. 参数调优:无论选择哪种半监督学习方法,都应进行充分的参数调优。这可能会显著提高模型性能。
  3. 性能评估:始终将半监督学习方法与仅使用标记数据的基准模型进行比较,以确保所选方法确实带来了性能提升。
  4. 数据质量:关注未标记数据的质量和相关性。高质量的未标记数据更可能对模型性能产生积极影响。
  5. 计算资源平衡:在选择核方法时,要考虑可用的计算资源。对于大规模数据集,KNN核可能是更实用的选择。

未来研究方向

为进一步推进半监督学习在实际应用中的有效性,建议以下几个潜在的研究方向:

  1. 探索更多类型的数据集和应用场景,以评估半监督学习方法的泛化能力。
  2. 研究如何自动选择最适合特定数据集和任务的半监督学习方法。
  3. 开发更高效的参数调优策略,以减少计算开销。
  4. 调查半监督学习方法在处理不平衡数据集和多类分类问题时的效果。
  5. 结合其他机器学习技术(如迁移学习或主动学习)与半监督学习,探索可能的协同效应。

如果你想手动研究和测试,本文代码在这里:

https://avoid.overfit.cn/post/9db781a0b2e941fca451fed2be16b09a

作者:Reinhard Sellmair

目录
相关文章
|
3月前
|
机器学习/深度学习 人工智能 专有云
人工智能平台PAI使用问题之怎么将DLC的数据写入到另一个阿里云主账号的OSS中
阿里云人工智能平台PAI是一个功能强大、易于使用的AI开发平台,旨在降低AI开发门槛,加速创新,助力企业和开发者高效构建、部署和管理人工智能应用。其中包含了一系列相互协同的产品与服务,共同构成一个完整的人工智能开发与应用生态系统。以下是对PAI产品使用合集的概述,涵盖数据处理、模型开发、训练加速、模型部署及管理等多个环节。
|
16天前
|
机器学习/深度学习 算法 数据挖掘
Python数据分析革命:Scikit-learn库,让机器学习模型训练与评估变得简单高效!
在数据驱动时代,Python 以强大的生态系统成为数据科学的首选语言,而 Scikit-learn 则因简洁的 API 和广泛的支持脱颖而出。本文将指导你使用 Scikit-learn 进行机器学习模型的训练与评估。首先通过 `pip install scikit-learn` 安装库,然后利用内置数据集进行数据准备,选择合适的模型(如逻辑回归),并通过交叉验证评估其性能。最终,使用模型对新数据进行预测,简化整个流程。无论你是新手还是专家,Scikit-learn 都能助你一臂之力。
70 8
|
18天前
|
机器学习/深度学习 数据采集 监控
探索机器学习:从数据到决策
【9月更文挑战第18天】在这篇文章中,我们将一起踏上一段激动人心的旅程,穿越机器学习的世界。我们将探讨如何通过收集和处理数据,利用算法的力量来预测未来的趋势,并做出更加明智的决策。无论你是初学者还是有经验的开发者,这篇文章都将为你提供新的视角和思考方式。
|
24天前
|
机器学习/深度学习 算法 数据挖掘
从菜鸟到大师:Scikit-learn库实战教程,模型训练、评估、选择一网打尽!
【9月更文挑战第13天】在数据科学与机器学习领域,Scikit-learn是不可或缺的工具。本文通过问答形式,指导初学者从零开始使用Scikit-learn进行模型训练、评估与选择。首先介绍了如何安装库、预处理数据并训练模型;接着展示了如何利用多种评估指标确保模型性能;最后通过GridSearchCV演示了系统化的参数调优方法。通过这些实战技巧,帮助读者逐步成长为熟练的数据科学家。
67 3
|
2月前
|
监控 数据安全/隐私保护 异构计算
借助PAI-EAS一键部署ChatGLM,并应用LangChain集成外部数据
【8月更文挑战第8天】借助PAI-EAS一键部署ChatGLM,并应用LangChain集成外部数据
66 1
|
2月前
|
机器学习/深度学习 数据挖掘
机器学习模型的选择与评估:技术深度解析
【8月更文挑战第21天】机器学习模型的选择与评估是一个复杂而重要的过程。通过深入理解问题、选择合适的评估指标和交叉验证方法,我们可以更准确地评估模型的性能,并选择出最适合当前问题的模型。然而,机器学习领域的发展日新月异,新的模型和评估方法不断涌现。因此,我们需要保持对新技术的学习和关注,不断优化和改进我们的模型选择与评估策略。
|
2月前
|
数据采集 机器学习/深度学习 算法
"揭秘数据质量自动化的秘密武器:机器学习模型如何精准捕捉数据中的‘隐形陷阱’,让你的数据分析无懈可击?"
【8月更文挑战第20天】随着大数据成为核心资源,数据质量直接影响机器学习模型的准确性和效果。传统的人工审查方法效率低且易错。本文介绍如何运用机器学习自动化评估数据质量,解决缺失值、异常值等问题,提升模型训练效率和预测准确性。通过Python和scikit-learn示例展示了异常值检测的过程,最后强调在自动化评估的同时结合人工审查的重要性。
58 2
|
2月前
|
机器学习/深度学习 JSON API
【Python奇迹】FastAPI框架大显神通:一键部署机器学习模型,让数据预测飞跃至Web舞台,震撼开启智能服务新纪元!
【8月更文挑战第16天】在数据驱动的时代,高效部署机器学习模型至关重要。FastAPI凭借其高性能与灵活性,成为搭建模型API的理想选择。本文详述了从环境准备、模型训练到使用FastAPI部署的全过程。首先,确保安装了Python及相关库(fastapi、uvicorn、scikit-learn)。接着,以线性回归为例,构建了一个预测房价的模型。通过定义FastAPI端点,实现了基于房屋大小预测价格的功能,并介绍了如何运行服务器及测试API。最终,用户可通过HTTP请求获取预测结果,极大地提升了模型的实用性和集成性。
141 1
|
2月前
|
缓存 开发者 测试技术
跨平台应用开发必备秘籍:运用 Uno Platform 打造高性能与优雅设计兼备的多平台应用,全面解析从代码共享到最佳实践的每一个细节
【8月更文挑战第31天】Uno Platform 是一种强大的工具,允许开发者使用 C# 和 XAML 构建跨平台应用。本文探讨了 Uno Platform 中实现跨平台应用的最佳实践,包括代码共享、平台特定功能、性能优化及测试等方面。通过共享代码、采用 MVVM 模式、使用条件编译指令以及优化性能,开发者可以高效构建高质量应用。Uno Platform 支持多种测试方法,确保应用在各平台上的稳定性和可靠性。这使得 Uno Platform 成为个人项目和企业应用的理想选择。
38 0
下一篇
无影云桌面