K折交叉验证的原理以及实战&使用StratifiedKFold来实现分层抽样

本文涉及的产品
MSE Nacos/ZooKeeper 企业版试用,1600元额度,限量50份
服务治理 MSE Sentinel/OpenSergo,Agent数量 不受限
云原生网关 MSE Higress,422元/月
简介: K折交叉验证的原理以及实战&使用StratifiedKFold来实现分层抽样

前言


交叉验证的由来:在机器学习的过程中,我们不能将全部数据都用于数据的模型训练,否则会导致我们没有数据集对该模型进行验证,无法评估模型的预测效果。


一、交叉验证(Cross-Validation)


众所周知,模型训练的数据量越大时,通常训练出来的模型效果会越好,所以如何充分利用我们手头的数据呢?


1-1、LOOCV(Leave-One-Out Cross Validation)(留一交叉验证)


这个方法是将数据集分为训练集和测试集,只用一个数据作为测试集,其它的数据都作为训练集,并将此步骤重复N次。


05629aae9e8a4ccf8ae6cd6bf4734eb6.png



结果就是我们训练了n个模型,每次都得到一个MSE,计算最终的MSE就是将这n个MSE取平均。

缺点是计算量太大。


1-2、K-fold Cross Validation


为了解决LOOCV计算量太大的问题,我们提出了K折交叉验证,测试集不再只是包含一个数据,而是包含多个数据,具体数目根据K的选取而决定,比如说K=5。即:

1、将所有数据集分为5份。

2、不重复地每次取其中一份作为测试集,其它四份做训练集来训练模型,之后计算该模型在测试集上的MSE

3、5次的MSE取平均,就得到最后的MSE。


优点

1、相比于LOOCV,K折交叉验证的计算量小了很多,而且和LOOCV估计很相似,效果差不多

2、K折交叉验证可以有效的避免过拟合和欠拟合的发生。


1-3、k的选取


根据经验,k一般都选择为5或者是10。

也可以通过网格搜索来确定最佳的参数


1-4、k折交叉验证的作用


1、可以有效的避免过拟合的情况

2、在各种比赛的过程中,常常会遇到数据量不够大的情况,那么这种技巧可以帮助提高精度!


二、K折交叉验证实战。


2-1、K折交叉验证实战


# 导入包
from sklearn.model_selection import KFold
import numpy as np
# 构建数据集
X = np.arange(24).reshape(12,2)
print(X)
# KFold()
# 参数:
# n_splits: 分为几折交叉验证
# shuffle: 是否随机,设置为True后每次的结果都不一样。
# random_state: 设置随机因子,设置了这个参数之后,每次生成的结果是一样的,而且设置了random_state之后就没必要设置shuffle了。
kf = KFold(n_splits=3,shuffle=True)
for train,test in kf.split(X):
    # 返回值是元组,训练集和验证集组成的元组
    print('%s %s' % (train, test))

输出:

[[ 0 1]

[ 2 3]

[ 4 5]

[ 6 7]

[ 8 9]

[10 11]

[12 13]

[14 15]

[16 17]

[18 19]

[20 21]

[22 23]]

[ 0 1 2 3 8 9 10 11] [4 5 6 7]

[ 0 4 5 6 7 8 9 10] [ 1 2 3 11]

[ 1 2 3 4 5 6 7 11] [ 0 8 9 10]


三、使用StratifiedKFold(分层K折交叉验证器)实现分层抽样


Tips:使用StratifiedKFold可以实现分层抽样方法,StratifiedKFold是K-fold的变种。(解决训练集和测试集分布不一致的问题)

import numpy as np
from sklearn.model_selection import StratifiedKFold
X = np.array([[1, 2], [3, 4], [1, 2], [3, 4], [5, 6], [7, 8], [5, 6], [7, 8]])
y = np.array([0, 0, 1, 1, 2, 2, 3, 3])
skf = StratifiedKFold(n_splits=2).split(X, y)
#c= skf.get_n_splits(X, y)
for train_index, test_index in skf:
     print("TRAIN:", train_index, "TEST:", test_index)
     X_train, X_test = X[train_index], X[test_index]
     y_train, y_test = y[train_index], y[test_index]
# 注意:这里输出的是索引,想输出分割后的结果,直接输出X_train,X_test ,y_train,y_test 就可以。
# 可以看到分割后训练集和测试集的分布是相同的。

输出

TRAIN: [1 3 5 7] TEST: [0 2 4 6]

TRAIN: [0 2 4 6] TEST: [1 3 5 7]

直接使用KFold来输出


import numpy as np
from sklearn.model_selection import StratifiedKFold
X = np.array([[1, 2], [3, 4], [1, 2], [3, 4], [5, 6], [7, 8], [5, 6], [7, 8]])
y = np.array([0, 0, 1, 1, 2, 2, 3, 3])
skf = KFold(n_splits=2).split(X, y)
#c= skf.get_n_splits(X, y)
for train_index, test_index in skf:
     print("TRAIN:", train_index, "TEST:", test_index)
     X_train, X_test = X[train_index], X[test_index]
     y_train, y_test = y[train_index], y[test_index]
     print(X_train, y_train) 
     print(X_test, y_test)
     # 可以看到分布是不均匀的。


输出

TRAIN: [4 5 6 7] TEST: [0 1 2 3]

[[5 6]

[7 8]

[5 6]

[7 8]] [2 2 3 3]

[[1 2]

[3 4]

[1 2]

[3 4]] [0 0 1 1]

TRAIN: [0 1 2 3] TEST: [4 5 6 7]

[[1 2]

[3 4]

[1 2]

[3 4]] [0 0 1 1]

[[5 6]

[7 8]

[5 6]

[7 8]] [2 2 3 3]

参考文章

【机器学习】Cross-Validation(交叉验证)详解.

交叉验证(Cross Validation).

python 利用sklearn.cross_validation的KFold构造交叉验证数据集.

K折交叉验证法原理及python实现.

sklearn官方文档.


总结


每天中午都睡不醒,很烦。

相关文章
|
机器学习/深度学习 算法 数据挖掘
交叉验证之KFold和StratifiedKFold的使用(附案例实战)
交叉验证之KFold和StratifiedKFold的使用(附案例实战)
1704 0
|
机器学习/深度学习 数据采集 人工智能
Machine Learning机器学习之贝叶斯网络(BayesianNetwork)
Machine Learning机器学习之贝叶斯网络(BayesianNetwork)
|
机器学习/深度学习 存储 监控
Elasticsearch 在日志分析中的应用
【9月更文第2天】随着数字化转型的推进,日志数据的重要性日益凸显。日志不仅记录了系统的运行状态,还提供了宝贵的洞察,帮助企业改进产品质量、优化用户体验以及加强安全防护。Elasticsearch 作为一个分布式搜索和分析引擎,因其出色的性能和灵活性,成为了日志分析领域的首选工具之一。本文将探讨如何使用 Elasticsearch 作为日志分析平台的核心组件,并详细介绍 ELK(Elasticsearch, Logstash, Kibana)栈的搭建和配置流程。
823 4
|
12月前
|
机器学习/深度学习 计算机视觉 Python
模型预测笔记(三):通过交叉验证网格搜索机器学习的最优参数
本文介绍了网格搜索(Grid Search)在机器学习中用于优化模型超参数的方法,包括定义超参数范围、创建参数网格、选择评估指标、构建模型和交叉验证策略、执行网格搜索、选择最佳超参数组合,并使用这些参数重新训练模型。文中还讨论了GridSearchCV的参数和不同机器学习问题适用的评分指标。最后提供了使用决策树分类器进行网格搜索的Python代码示例。
1044 1
|
机器学习/深度学习 索引
|
12月前
|
数据采集 传感器 大数据
大数据中数据采集 (Data Collection)
【10月更文挑战第17天】
661 2
|
机器学习/深度学习 计算机视觉 文件存储
【轻量化网络系列(3)】MobileNetV3论文超详细解读(翻译 +学习笔记+代码实现)
【轻量化网络系列(3)】MobileNetV3论文超详细解读(翻译 +学习笔记+代码实现)
5843 0
【轻量化网络系列(3)】MobileNetV3论文超详细解读(翻译 +学习笔记+代码实现)
|
机器学习/深度学习 自然语言处理 数据可视化
LlamaFactory可视化微调大模型 - 参数详解
LlamaFactory可视化微调大模型 - 参数详解
2815 4
|
机器学习/深度学习 数据可视化 大数据
K值进行交叉验证
8月更文挑战第16天
|
机器学习/深度学习 数据采集 大数据