伪标签:教你玩转无标签数据的半监督学习方法

简介: 对于机器学习项目而言,数据是根本,但是往往我们拿到的是无标签数据,对于这些数据,我们该如何更好的利用它们呢?在本文中,作者提出了一个名为伪标签的半监督学习方法,通过这个方法,我们就可以使用无标签数据来提高机器学习模型的性能,也会让你在更多像Kaggle一样的比赛中受益。

更多深度文章,请关注云计算频道:https://yq.aliyun.com/cloud

对于每个机器学习项目而言,数据是基础,是不可或缺的一部分。在本文中,作者将会展示一个名为伪标签的简单的半监督学习方法,它可以通过使用无标签数据来提高机器学习模型的性能。

伪标签

为了训练机器学习模型,在监督学习中,数据必须是有标签的。那这是否意味着无标签的数据对于诸如分类和回归之类的监督任务就无用了呢?当然不是! 除了使用额外数据进行数据分析,还可以将无标签数据和标签数据结合起来,一同训练半监督学习模型。

4dadc981fd1da0e5e9956ed8ff1e96ef1bac4f67

该方法的主旨思想其实很简单。首先,在标签数据上训练模型,然后使用经过训练的模型来预测无标签数据的标签,从而创建伪标签。此外,将标签数据和新生成的伪标签数据结合起来作为新的训练数据。

这个方法的灵感来自于fast.ai MOOC(原文)。虽然这个方法是在深度学习(在线算法)的背景下提到的,但是我们在传统的机器学习模型上进行了尝试,并得到了细微的提升。

数据预处理与探索

通常在像Kaggle这样的比赛中,参赛者通常收到的数据是有标签的数据作为训练集,无标签的数据作为测试集。这是一个测试伪标签的好地方。我们这里使用的数据集来自Mercedes-Benz Greener Manufacturing competition,该竞赛的目标是根据提供的特征(回归)测试一辆汽车的持续时间。与往常一样,在本笔记本中可以找到附加描述的所有代码。

import pandas as pd

# Load the data
train = pd.read_csv('input/train.csv')
test = pd.read_csv('input/test.csv')

print(train.shape, test.shape)
# (4209, 378) (4209, 377)

从上面我们可以看到,训练数据并不理想,只有4209组数据,376个特征。为了改善数据集,我们应该减少特征数据,尽可能地增加数据量。我在之前的一篇博客文章中提到过特征的重要性(特征压缩),这个主题暂且略过不谈,因为这篇博客文章的主要重点将是增加带有伪标签的数据量。这个数据集可以很好地用于伪标签,因为小数据中有无标签的数据比例为1:1。

下表展示的是整个训练集的子集,特征x0-x8是分类变量,我们必须把它们转换成模型可用的数值变量。

dd1b618ac3477394c50c8b2eec1d2438d4e36b00

这里使用scikit- learn的LabelEncoder类完成的。

from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split

features = train.columns[2:]

for column_name in features:
    label_encoder = LabelEncoder() 
    
    # Get the column values
    train_column_values = list(train[column_name].values)
    test_column_values = list(test[column_name].values)
    
    # Fit the label encoder
    label_encoder.fit(train_column_values + test_column_values)
    
    # Transform the feature
    train[column_name] = label_encoder.transform(train_column_values)
    test[column_name] = label_encoder.transform(test_column_values)

结果如下:

276dc4502a3f26e3138ca3a2039bec0aaf75c561

现在,用于机器学习模型的数据就准备好了。

使用Python和scikit-learn实现伪标签

我们创建一个函数,包含伪标签数据和标签数据的“增强训练集“。函数的参数包括模型、训练集、测试集信息(数据和特征)和参数sample_rate。Sample_rate允许我们控制混合有真实标签数据的伪标签数据的百分比。将sample_rate设置为0.0意味着模型只使用真实标签的数据,而sample_rate为0.5时意味着模型使用了所有的真实的标签数据和一半的伪标签数据。无论哪种情况,模型都将使用所有真实标签的数据。

def create_augmented_train(X, y, model, test, features, target, sample_rate):
    '''
    Create and return the augmented_train set that consists
    of pseudo-labeled and labeled data.
    '''
    num_of_samples = int(len(test) * sample_rate)

    # Train the model and creat the pseudo-labeles
    model.fit(X, y)
    pseudo_labeles = model.predict(test[features])

    # Add the pseudo-labeles to the test set
    augmented_test = test.copy(deep=True)
    augmented_test[target] = pseudo_labeles

    # Take a subset of the test set with pseudo-labeles and append in onto
    # the training set
    sampled_test = augmented_test.sample(n=num_of_samples)
    temp_train = pd.concat([X, y], axis=1)
    augemented_train = pd.concat([sampled_test, temp_train])
    
    # Shuffle the augmented dataset and return it
    return shuffle(augemented_train)

此外,我们还需要一个可以接受增强训练集的方法来训练模型。这是另一个函数,我们在准备参数之前已经写过了。这是一个很好的机会,可以创建一个类来增强内聚性,使代码更简洁,并且把方法放入这个类中。我们将要创建的类叫PseudoLabeler.。这个类将采用scikit-learn模型,并利用增强训练集来训练它。Scikit-learn允许我们创建自己的回归类库,但是我们必须遵守他们的库标准。

from sklearn.utils import shuffle
from sklearn.base import BaseEstimator, RegressorMixin

class PseudoLabeler(BaseEstimator, RegressorMixin):
    
    def __init__(self, model, test, features, target, sample_rate=0.2, seed=42):
        self.sample_rate = sample_rate
        self.seed = seed
        self.model = model
        self.model.seed = seed
        
        self.test = test
        self.features = features
        self.target = target
        
    def get_params(self, deep=True):
        return {
            "sample_rate": self.sample_rate,
            "seed": self.seed,
            "model": self.model,
            "test": self.test,
            "features": self.features,
            "target": self.target
        }

    def set_params(self, **parameters):
        for parameter, value in parameters.items():
            setattr(self, parameter, value)
        return self

        
    def fit(self, X, y):
        if self.sample_rate > 0.0:
            augemented_train = self.__create_augmented_train(X, y)
            self.model.fit(
                augemented_train[self.features],
                augemented_train[self.target]
            )
        else:
            self.model.fit(X, y)
        
        return self


    def __create_augmented_train(self, X, y):
        num_of_samples = int(len(test) * self.sample_rate)
        
        # Train the model and creat the pseudo-labels
        self.model.fit(X, y)
        pseudo_labels = self.model.predict(self.test[self.features])
        
        # Add the pseudo-labels to the test set
        augmented_test = test.copy(deep=True)
        augmented_test[self.target] = pseudo_labels
        
        # Take a subset of the test set with pseudo-labels and append in onto
        # the training set
        sampled_test = augmented_test.sample(n=num_of_samples)
        temp_train = pd.concat([X, y], axis=1)
        augemented_train = pd.concat([sampled_test, temp_train])

        return shuffle(augemented_train)
        
    def predict(self, X):
        return self.model.predict(X)
    
    def get_model_name(self):
        return self.model.__class__.__name__

除“fit”和“__create_augmented_train”方法以外,scikit-learn还需要一些较小的方法来使用这个类作为回归类库(可从官方文档了解更多信息)。现在我们已经为伪标签创建了scikit-learn类,我们来举个例子。

target = 'y'

# Preprocess the data
X_train, X_test = train[features], test[features]
y_train = train[target]

# Create the PseudoLabeler with XGBRegressor as the base regressor
model = PseudoLabeler(
    XGBRegressor(nthread=1),
    test,
    features,
    target
)

# Train the model and use it to predict
model.fit(X_train, y_train)
model.predict(X_train)

在这个例子中,PseudoLabeler类使用了XGBRegressor来实现伪标签的回归。Sample_rate参数的默认值为0.2,意味着PseudoLabeler将会使用20%的无标签数据集。

结果

为了测试PseudoLabeler,我使用XGBoost(当现场比赛时,使用XGBoost会得到最好的结果)。为了评估模型,我们将原始XGBoost与伪标签XGBoost进行比较。使用8折交叉验证(在4k数据量上,每折都有一个小数据集——大约500个数据)。评估指标是r2 - score,即比赛的官方评价指标。

2a72f4ad68e971539db71362a2d38c9aab0d87dc

PseudoLabeler的平均分略高,偏差较低,这使它(略微)优于原始模型。我在笔记本上做了一个更详细的分析,可以在这里看到。性能增长也许不高,但是要记住,Kaggle比赛中,每增加一个分数都有可能使你在排行榜上排名更高。这里介绍的复杂性并不是太大(70行左右代码),但是这个示例中的问题和模型都很简单,当试图使用这个方法解决更复杂的问题或领域时要切记。

结论

伪标签允许我们在训练机器学习模型的同时使用伪标签数据。这听起来像是一种强大的技术,是的,它经常会增加我们的模型性能。然而,它可能很难调整以使它正常工作,即使它有效,也会带来轻微的性能提升。在像Kaggle这样的比赛中,我相信这项技术是很有用的,因为通常即使是轻微的分数提高也能让你在排行榜上得到提升。尽管如此,在生产环境中使用这种方法之前,我还是会再三考虑,因为它似乎在没有大幅度提高性能的情况下带来了额外的复杂性,而这可能不是我们所希望看到的。

 

作者介绍:Vinko Kodžoman,数据和软件爱好者,游戏玩家和冒险家

f8b83ab92abe6cad75c9be95cf67894ff2d0d9f2

Githubhttps://github.com/Weenkus

Emailvinko.kodzoman@yahoo.com

https://datawhatnow.com/

 

以上为译文

本文由北邮@爱可可-爱生活老师推荐阿里云云栖社区组织翻译。

文章原标题《Pseudo-labeling a simple semi-supervised learning method》,作者:Vinko Kodžoman

译者:郭翔云, 审校:袁虎。

文章为简译,更为详细的内容,请查看原文

本文由用户为个人学习及研究之目的自行翻译发表,如发现侵犯原作者的版权,请与社区联系处理yqgroup@service.aliyun.com

 

相关文章
|
机器学习/深度学习 自然语言处理 算法
【多标签文本分类】《多粒度信息关系增强的多标签文本分类》
提出一种多粒度的多标签文本分类方法。一共3个粒度:文档级分类模块、词级分类模块、标签约束性关系匹配辅助模块。
183 0
|
移动开发 文字识别 算法
论文推荐|[PR 2019]SegLink++:基于实例感知与组件组合的任意形状密集场景文本检测方法
本文简要介绍Pattern Recognition 2019论文“SegLink++: Detecting Dense and Arbitrary-shaped Scene Text by Instance-aware Component Grouping”的主要工作。该论文提出一种对文字实例敏感的自下而上的文字检测方法,解决了自然场景中密集文本和不规则文本的检测问题。
1956 0
论文推荐|[PR 2019]SegLink++:基于实例感知与组件组合的任意形状密集场景文本检测方法
|
4月前
|
前端开发 BI
前端基础(十)_标签分类(行级标签、块级标签、行块标签)
本文阐述了HTML标签的分类,包括行级标签、块级标签和行块标签,并展示了如何使用CSS的display属性实现标签类型之间的转换。
94 3
|
5月前
|
SQL 自然语言处理 算法
评估数据集CGoDial问题之计算伪OOD样本的软标签的问题如何解决
评估数据集CGoDial问题之计算伪OOD样本的软标签的问题如何解决
|
5月前
|
数据采集 机器学习/深度学习 算法
5.2.3 检测头设计(计算预测框位置和类别)
这篇文章详细介绍了YOLOv3目标检测模型中的检测头设计,包括预测框是否包含物体的概率计算、预测物体的位置和形状、预测物体类别的概率,并展示了如何通过网络输出得到预测值,以及如何建立损失函数来训练模型。
|
5月前
|
SQL 自然语言处理 算法
预训练模型STAR问题之计算伪OOD样本的软标签的问题如何解决
预训练模型STAR问题之计算伪OOD样本的软标签的问题如何解决
|
机器学习/深度学习 算法 计算机视觉
【多标签文本分类】层次多标签文本分类方法
【多标签文本分类】层次多标签文本分类方法
780 0
【多标签文本分类】层次多标签文本分类方法
|
计算机视觉
【多标签文本分类】《采用平衡函数的大规模多标签文本分类》
使用最常见的BERT+fc的多标签文本分类模型,只是改进了一下损失函数。
104 0
|
机器学习/深度学习 算法 数据挖掘
书写自动智慧文本分类器的开发与应用:支持多分类、多标签分类、多层级分类和Kmeans聚类
书写自动智慧文本分类器的开发与应用:支持多分类、多标签分类、多层级分类和Kmeans聚类
书写自动智慧文本分类器的开发与应用:支持多分类、多标签分类、多层级分类和Kmeans聚类
|
数据采集 机器学习/深度学习 自然语言处理
实现文本数据数值化、方便后续进行回归分析等目的,需要对文本数据进行多标签分类和关系抽取
实现文本数据数值化、方便后续进行回归分析等目的,需要对文本数据进行多标签分类和关系抽取
216 0