一份半监督学习的指南-伪标签学习

简介: 在ML中,有3种机器学习方法-监督学习、无监督学习和强化学习技术。 我们所知道的监督学习是指数据带有标签的情况, 无监督学习是仅存在数据而没有标签的情况,强化学习算法的思路非常简单,以游戏为例,如果在游戏中采取某种策略可以取得较高的得分,那么就进一步“强化”这种策略,以期继续取得较好的结果。

1 引言


在ML中,有3种机器学习方法-监督学习、无监督学习和强化学习技术。 我们所知道的监督学习是指数据带有标签的情况, 无监督学习是仅存在数据而没有标签的情况,强化学习算法的思路非常简单,以游戏为例,如果在游戏中采取某种策略可以取得较高的得分,那么就进一步“强化”这种策略,以期继续取得较好的结果。


想象一下这样一种情况,在训练中,标记数据的数量更少,而未标记数据的数量更多。 一种称为半监督学习( [Semi-Supervised Learning],SSL)的新技术,它是监督学习和非监督学习的混合体。 顾名思义,半监督学习中同时存在一组标记的训练数据和另一组未标记的训练数据。 我们可以将这种情况想像成Google图片或Facebook通过其面孔(数据)识别出图片中的人物并根据该人物先前存储的图像生成建议名称(标签)的情况。


41.png


在本文中,我们将讨论如何使用半监督学习技术生成伪标签。


2 Pseudo-Labelling 伪标签


伪标签是使用标记的数据模型预测未标记数据并进行标记的过程。 首先,模型已经训练了包含标签的数据集,该模型用于为未标记的数据集生成伪标签。 最后,将数据集和标签(原始标签和伪标签)组合在一起以进行最终模型训练。 之所以称为伪(意味着虚幻),是因为它们可能是真实标签,也可能不是真实标签,并且是通过我们基于类似的数据模型生成的标签。


42.png


该方法的主旨思想其实很简单。首先,在标签数据上训练模型,然后使用经过训练的模型来预测无标签数据的标签,从而创建伪标签。此外,将标签数据和新生成的伪标签数据结合起来作为新的训练数据。


3 Python 实现


在这个例子中,我们使用了sklearn中的breast cancer数据集。我们知道整个已经包含了标签,但我们要修改它,将数据分成两部分,一部分有标签,另一部分没有标签。我们将从经过训练的带标签数据模型中为未带标签的数据生成我们自己的标签,然后最后使用两者合并的数据集来训练最终的模型。


3.1 数据集


Breast cancer dataset是预测肿瘤是良性(B)还是恶性(M)的分类问题。前两列为1)id和2)diagnosis(标签):


43.png

a)radius_mean(从中心到外围点的距离的平均值)
b)texture_mean(灰度值的标准偏差)
c)perimeter_mean(周长)
d)area_mean(面积)
e)smoothness_mean(半径长度的局部变化)
f)compactness_mean(周长^ 2 /面积– 1.0)
g)concavity_mean(轮廓凹部的严重程度)
h) concave points_mean(轮廓的凹面部分的数量)


3.2 导入包

import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier


3.3 加载数据集

X,y = load_breast_cancer(True)
X.shape

(569, 30)


3.4 分割数据集

x_train,x_test,y_train,_ = train_test_split(X,y,test_size=.6)
x_train.shape,y_train.shape,x_test.shape

((227, 30), (227,), (342, 30)


3.5 训练模型

model1 = RandomForestClassifier()
history = model1.fit(x_train,y_train)
history

RandomForestRegressor(bootstrap=True, ccp_alpha=0.0, criterion=’mse’,
max_depth=None, max_features=’auto’, max_leaf_nodes=None,
max_samples=None, min_impurity_decrease=0.0,
min_impurity_split=None, min_samples_leaf=1,
min_samples_split=2, min_weight_fraction_leaf=0.0,
n_estimators=100, n_jobs=None, oob_score=False,
random_state=None, verbose=0, warm_start=False)


3.6 评分

model1.score(x_train,y_train)

1.0


3.7 预测

y_new = model1.predict(x_test)
y_new.shape

(342,)


合并数据集

final_X = np.concatenate((x_train,x_test))
final_X.shape

(569, 30)


合并原始标签与伪标签

final_Y = np.concatenate((y_train,y_test))
final_Y.shape

(569,)


基于合并的数据集训练最终模型

model2 = RandomForestRegressor()
model2.fit(final_X,final_Y)
model2.score(final_X,final_Y)

1.0


4 结论


伪标签的实现到此为止,大家可以根据自己的想法去比赛中尝试吧。


相关文章
|
机器学习/深度学习 计算机视觉
秒懂Precision精确率、Recall召回率-附代码和案例
秒懂Precision精确率、Recall召回率-附代码和案例
|
监控 JavaScript 前端开发
影刀RPA(初级)(二)
影刀RPA(初级)(二)
9076 2
|
Linux 开发工具 git
[笔记]ubuntun18.0+clion+qt5 搭建跨平台应用环境
[笔记]ubuntun18.0+clion+qt5 搭建跨平台应用环境
446 0
|
机器学习/深度学习 人工智能 自然语言处理
Emotion-LLaMA:用 AI 读懂、听懂、看懂情绪,精准捕捉文本、音频和视频中的复杂情绪
Emotion-LLaMA 是一款多模态情绪识别与推理模型,融合音频、视觉和文本输入,通过特定情绪编码器整合信息,广泛应用于人机交互、教育、心理健康等领域。
1223 11
Emotion-LLaMA:用 AI 读懂、听懂、看懂情绪,精准捕捉文本、音频和视频中的复杂情绪
|
11月前
|
机器学习/深度学习 人工智能 API
解锁HarmonyOS新姿势:金融风控中的AI类目标签实战
在金融行业中,风险控制是保障稳定与安全的核心。随着业务复杂化和数字化加深,传统风控手段难以应对新挑战。AI类目标签技术凭借强大的数据处理能力,为金融风控带来全新解决方案。本文探讨基于HarmonyOS NEXT API 12及以上版本,如何运用AI类目标签技术构建高效金融风控体系,助力开发者在鸿蒙生态中创新应用。通过精准风险识别、实时监测预警和优化信用评估,提升风控效果;结合鸿蒙系统的分布式软总线和隐私保护优势,实现无缝协同与数据安全。具体应用场景如信用卡欺诈防控和贷款审批风险评估,展示了技术的实际效益。
378 0
|
机器学习/深度学习 人工智能 自然语言处理
一文速通半监督学习(Semi-supervised Learning):桥接有标签与无标签数据
一文速通半监督学习(Semi-supervised Learning):桥接有标签与无标签数据
1348 0
|
存储 缓存 Dart
Flutter&鸿蒙next 封装 Dio 网络请求详解:登录身份验证与免登录缓存
本文详细介绍了如何在 Flutter 中使用 Dio 封装网络请求,实现用户登录身份验证及免登录缓存功能。首先在 `pubspec.yaml` 中添加 Dio 和 `shared_preferences` 依赖,然后创建 `NetworkService` 类封装 Dio 的功能,包括请求拦截、响应拦截、Token 存储和登录请求。最后,通过一个登录界面示例展示了如何在实际应用中使用 `NetworkService` 进行身份验证。希望本文能帮助你在 Flutter 中更好地处理网络请求和用户认证。
773 1
|
存储 Java 数据库连接
南大通用GBase 8s大对象类型clob和text的比较说明
本文探讨了GBase数据库中用于存储大对象数据的字段类型,包括TEXT、CLOB、BYTE和BLOB,分析了它们的特点、适用场景及在实际应用中的最佳实践。重点介绍了不同数据大小对应的字段类型选择,以及在数据库工具和程序中操作这些类型的方法,强调了合理选择字段类型对提升数据库性能的重要性。
|
SQL 分布式计算 DataWorks
DataWorks产品使用合集之如何对多个表进行历史数据的回刷(即补数据)
DataWorks作为一站式的数据开发与治理平台,提供了从数据采集、清洗、开发、调度、服务化、质量监控到安全管理的全套解决方案,帮助企业构建高效、规范、安全的大数据处理体系。以下是对DataWorks产品使用合集的概述,涵盖数据处理的各个环节。
385 1
|
数据采集 监控 JavaScript
工厂生产管理系统MES十大核心功能模块
MES提供了对生产现场的实时可视化,帮助企业管理生产计划、物料追踪、工艺控制、产品质量和生产设备等方面的工作。
771 11

热门文章

最新文章