Scikit-learn Pipeline完全指南:高效构建机器学习工作流

简介: Scikit-learn管道是构建高效、鲁棒、可复用的机器学习工作流程的利器。通过掌握管道的使用,我们可以轻松地完成从数据预处理到模型训练、评估和部署的全流程,极大地提高工作效率。

在机器学习工作流程中,组合估计器通过将多个转换器(Transformer)和预测器(Predictor)整合到一个管道(Pipeline)中,可以有效简化整个过程。这种方法不仅简化了数据预处理环节,还能确保处理过程的一致性,最大限度地降低数据泄露的风险。构建组合估计器最常用的工具是Scikit-learn提供的Pipeline类。

关键术语

估计器(Estimator)泛指任何实现了

fit

方法的对象,该方法可以从数据中学习参数。估计器的概念涵盖了模型、预处理器以及管道等多种类型。

转换器(Transformer)是一种特殊的估计器,主要用于数据预处理或特征工程。转换器同时实现了

fit

方法(从数据中学习转换规则)和

transform

方法(将学习到的转换规则应用到数据上)。常见的转换器包括缩放器(Scaler)、降维器(Dimensionality Reduction)、编码器(Encoder)等。

预测器(Predictor)是用于监督学习任务(如分类或回归)的一类估计器。预测器需要实现

fit

方法(用于在训练集上学习)和

predict

方法(用于在测试集上进行预测)。

管道(Pipeline)

管道(或者叫流水线)可以将多个估计器串联起来,形成一个完整的工作流程。在数据处理过程中通常需要遵循一系列固定的步骤,例如特征选择、数据归一化以及模型训练等,所以一般会用这种形式来串联我们的训练过程。

使用管道有以下几个主要目的:

  • 便捷性和封装性: 只需调用一次fitpredict方法,即可完成从数据预处理到模型训练的全部步骤。
  • 联合参数选择: 可以使用网格搜索等方法,一次性地对管道中所有估计器的超参数进行优化。
  • 避免数据泄露: 通过交叉验证等方式,管道可以有效防止在训练过程中发生数据泄露。

管道中除最后一个估计器以外,其余估计器都必须是转换器(即必须实现

transform

方法)。最后一个估计器可以是任意类型,包括转换器、分类器等。

构建管道

构建管道需要提供一个由

(key, value)

元组组成的列表,其中

key

是字符串类型,表示当前步骤的名称;

value

则是一个估计器对象。下面是一个构建管道的示例:

 fromsklearn.pipelineimportPipeline
 fromsklearn.linear_modelimportLogisticRegression
 fromsklearn.decompositionimportPCA

 pipeline=Pipeline([
     ('transformer_1', StandardScaler()),
     ('predictor', LogisticRegression())
 ])

 pipeline

在上述示例中,我们首先使用

StandardScaler

对数据进行标准化处理,确保所有特征都经过适当的缩放。然后再将

LogisticRegression

模型作为预测器,对数据进行二分类。通过管道可以方便地对整个训练集进行拟合和预测,代码如下所示:

 # 拟合管道
 pipeline.fit(X_train, y_train)

 # 管道预测
 y_pred=pipeline.predict(X_test)

在拟合阶段,训练数据将依次通过管道中的各个转换器,依次完成拟合和转换操作。处理后的数据最终被用于训练预测模型。在预测阶段,管道会对测试数据应用与训练时相同的转换操作,再由预测器给出最终的预测结果。

网格搜索与交叉验证

手动调优超参数费时费力,而且往往难以取得理想的效果。这时就可以借助Scikit-learn提供的GridSearchCV类,自动化地搜索最优超参数组合。

 fromsklearn.model_selectionimportGridSearchCV

 # 定义网格搜索参数
 grid_params= {
     'transformer_1__with_mean': [True, False],
     'predictor__C': [0.1, 1, 10]
 }

 # 执行网格搜索
 grid=GridSearchCV(pipeline, grid_params, cv=10)  
 grid.fit(X_train, y_train)

grid_params

字典中,指定了需要优化的超参数及其候选取值:

  • transformer_1__with_mean: 管道中transformer_1步骤的with_mean参数,取值为布尔类型。
  • predictor__C: 管道中predictor步骤的正则化强度C,取值为数值类型。
  • cv=10: 指定交叉验证的折数为10。

通过网格搜索可以找到模型在当前数据集上的最优超参数组合。这个过程可以确保管道在性能上得到充分优化。

保存和加载管道

一旦通过

GridSearchCV

完成了管道的训练和优化,就可以将其保存起来,供日后使用。下面的代码展示了如何保存和加载一个已经训练好的管道:

 importjoblib

 # 保存管道
 joblib.dump(pipeline, 'pipeline.pkl')

 # 加载管道
 loaded_pipeline=joblib.load('pipeline.pkl')

这一功能在实际生产环境中尤为重要。通过保存训练好的管道可以直接将其部署到线上系统,用于对新数据进行实时预测,而无需重新训练模型。

为什么要保存管道?

保存管道有以下几个主要原因:

  • 复用性: 避免了每次使用都需要重新训练模型和执行数据预处理的繁琐步骤。
  • 一致性: 确保对不同数据集应用相同的转换操作和模型,提高结果的可重复性。
  • 部署便捷: 将管道整体保存为一个对象,可以方便地集成到生产系统中,实现实时预测。
  • 时间效率: 对于复杂的管道或大规模数据集,重用已训练的管道可以显著节省计算时间。

完整示例代码

下面的代码展示了如何使用Scikit-learn管道完成端到端的机器学习流程:

  1. 定义包含数据转换和模型的管道;
  2. 使用GridSearchCV搜索最优超参数,并拟合管道;
  3. 使用训练好的管道对测试集进行预测。
 fromsklearn.pipelineimportPipeline
 fromsklearn.linear_modelimportLogisticRegression
 fromsklearn.decompositionimportPCA
 fromsklearn.model_selectionimportGridSearchCV

 # 创建管道
 pipeline=Pipeline([
     ('transformer_1', StandardScaler()),
     ('predictor', LogisticRegression())
 ])

 # 定义网格搜索参数
 grid_params= {
     'transformer_1__with_mean': [True, False],
     'predictor__C': [0.1, 1, 10]
 }

 # 执行网格搜索
 grid=GridSearchCV(pipeline, grid_params, cv=10)
 grid.fit(X_train, y_train)

 # 使用管道进行预测
 y_pred=pipeline.predict(X_test)

总结

Scikit-learn管道是构建高效、鲁棒、可复用的机器学习工作流程的利器。通过掌握管道的使用,我们可以轻松地完成从数据预处理到模型训练、评估和部署的全流程,极大地提高工作效率。建议在实际项目中多多尝试和运用管道,以期进一步优化您的机器学习流程。

https://avoid.overfit.cn/post/915632324fa14e3588539d4294f41077

Mohammed Shammeer

目录
相关文章
|
7月前
|
机器学习/深度学习 人工智能 Kubernetes
Argo Workflows 加速在 Kubernetes 上构建机器学习 Pipelines
Argo Workflows 是 Kubernetes 上的工作流引擎,支持机器学习、数据处理、基础设施自动化及 CI/CD 等场景。作为 CNCF 毕业项目,其扩展性强、云原生轻量化,受到广泛采用。近期更新包括性能优化、调度策略增强、Python SDK 支持及 AI/大数据任务集成,助力企业高效构建 AI、ML、Data Pipelines。
796 1
|
8月前
|
机器学习/深度学习 存储 运维
机器学习异常检测实战:用Isolation Forest快速构建无标签异常检测系统
本研究通过实验演示了异常标记如何逐步完善异常检测方案和主要分类模型在欺诈检测中的应用。实验结果表明,Isolation Forest作为一个强大的异常检测模型,无需显式建模正常模式即可有效工作,在处理未见风险事件方面具有显著优势。
672 46
|
9月前
|
机器学习/深度学习 人工智能 算法
Scikit-learn:Python机器学习的瑞士军刀
想要快速入门机器学习但被复杂算法吓退?本文详解Scikit-learn如何让您无需深厚数学背景也能构建强大AI模型。从数据预处理到模型评估,从垃圾邮件过滤到信用风险评估,通过实用案例和直观图表,带您掌握这把Python机器学习的'瑞士军刀'。无论您是AI新手还是经验丰富的数据科学家,都能从中获取将理论转化为实际应用的关键技巧。了解Scikit-learn与大语言模型的最新集成方式,抢先掌握机器学习的未来发展方向!
1161 12
Scikit-learn:Python机器学习的瑞士军刀
|
8月前
|
存储 人工智能 运维
企业级MLOps落地:基于PAI-Studio构建自动化模型迭代流水线
本文深入解析MLOps落地的核心挑战与解决方案,涵盖技术断层分析、PAI-Studio平台选型、自动化流水线设计及实战构建,全面提升模型迭代效率与稳定性。
345 6
|
8月前
|
机器学习/深度学习 PyTorch API
昇腾AI4S图机器学习:DGL图构建接口的PyG替换
本文探讨了在图神经网络中将DGL接口替换为PyG实现的方法,重点以RFdiffusion蛋白质设计模型中的SE3Transformer为例。SE3Transformer通过SE(3)等变性提取三维几何特征,其图构建部分依赖DGL接口。文章详细介绍了两个关键函数的替换:`make_full_graph` 和 `make_topk_graph`。前者构建完全连接图,后者生成k近邻图。通过PyG的高效实现(如`knn_graph`),我们简化了图结构创建过程,并调整边特征处理逻辑以兼容不同框架,从而更好地支持昇腾NPU等硬件环境。此方法为跨库迁移提供了实用参考。
|
8月前
|
机器学习/深度学习 数据采集 分布式计算
阿里云PAI AutoML实战:20分钟构建高精度电商销量预测模型
本文介绍了如何利用阿里云 PAI AutoML 平台,在20分钟内构建高精度的电商销量预测模型。内容涵盖项目背景、数据准备与预处理、模型训练与优化、部署应用及常见问题解决方案,助力企业实现数据驱动的精细化运营,提升市场竞争力。
1361 0
|
4月前
|
机器学习/深度学习 数据采集 人工智能
【机器学习算法篇】K-近邻算法
K近邻(KNN)是一种基于“物以类聚”思想的监督学习算法,通过计算样本间距离,选取最近K个邻居投票决定类别。支持多种距离度量,如欧式、曼哈顿、余弦相似度等,适用于分类与回归任务。结合Scikit-learn可高效实现,需合理选择K值并进行数据预处理,常用于鸢尾花分类等经典案例。(238字)
|
机器学习/深度学习 算法 数据挖掘
K-means聚类算法是机器学习中常用的一种聚类方法,通过将数据集划分为K个簇来简化数据结构
K-means聚类算法是机器学习中常用的一种聚类方法,通过将数据集划分为K个簇来简化数据结构。本文介绍了K-means算法的基本原理,包括初始化、数据点分配与簇中心更新等步骤,以及如何在Python中实现该算法,最后讨论了其优缺点及应用场景。
1425 6
|
9月前
|
机器学习/深度学习 数据采集 人工智能
20分钟掌握机器学习算法指南
在短短20分钟内,从零开始理解主流机器学习算法的工作原理,掌握算法选择策略,并建立对神经网络的直观认识。本文用通俗易懂的语言和生动的比喻,帮助你告别算法选择的困惑,轻松踏入AI的大门。
594 8