sklearn.model_selection.learning_curve介绍(评估多大的样本量用于训练才能达到最佳效果)

简介: sklearn.model_selection.learning_curve介绍(评估多大的样本量用于训练才能达到最佳效果)

前言


学习曲线函数:可以用于检验数据是否过拟合,并且可以评估多大的样本量用于训练才能达到最佳效果(了解数据如何影响模型的性能)。还可以用于测试模型的超参数。


一、learning_curve介绍


learning_curve函数介绍: 用于确定不同训练集大小的交叉验证训练和测试分数,交叉验证生成器在训练和测试数据中对整个数据集进行k次拆分。将使用具有不同大小的训练集的子集来训练估计器,并将计算每个训练子集大小的分数和测试集。之后,将对每个训练子集大小的所有k次运行的分数求平均。

sklearn.model_selection.learning_curve(
  estimator, X, y, *, groups=None, 
  train_sizes=array([0.1, 0.33, 0.55, 0.78, 1.]), 
  cv=None, scoring=None, exploit_incremental_learning=False, 
  n_jobs=None, pre_dispatch='all', 
  verbose=0, shuffle=False, random_state=None, 
  error_score=nan, return_times=False, 
  fit_params=None)


注意:

1、当训练集的准确率比其他独立数据集上的测试结果的准确率要高时,一般都是过拟合。

2、当训练集和验证集的准确率都很低,很可能是欠拟合。


一些常用参数:

1、estimator:传入的模型对象。

2、X:传入的特征

3、y:传入的标签

4、train_sizes:数组,代表训练示例的相对或绝对数量,将用于生成学习曲线。如果dtype为float,则视为训练集最大尺寸的一部分(由所选的验证方法确定),即,它必须在(0,1]之内,否则将被解释为绝对大小注意,为了进行分类,样本的数量通常必须足够大,以包含每个类中的至少一个样本(默认值:np.linspace(0.1,1.0,5))

5、cv

cv的可能输入是:

1)None,要使用默认的三折交叉验证(v0.22版本中将改为五折)

2)整数,用于指定(分层)KFold中的折叠数,比如说10

3)CV splitter:分割器,例如我们这里用到的ShuffleSplit

6、n_jobs:运行的cpu个数。 -1表示使用所有处理器。

7、random_state:随机数种子。


返回值:

train_sizes_abs:返回生成的训练样本数量列表。

train_scores:数组,形状(n_ticks,n_cv_folds),训练集得分列表。

test_scores:数组,形状(n_ticks,n_cv_folds)测试集得分列表。


二、实战


import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
def plot_learning_curve(estimator, title, X, y, ylim=None, cv=None,
                        n_jobs=1, train_sizes=np.linspace(.1, 1.0, 5)):
    '''
    定义画出学习曲线的方法,核心是调用learning_curve方法。
    '''
    plt.figure()
    plt.title(title)
    if ylim is not None:
        plt.ylim(*ylim)
    plt.xlabel("Training examples")
    plt.ylabel("Score")
    # 这里交叉验证调用的必须是ShuffleSplit。
    # n_jobs: 要并行运行的作业数,这里默认为1,如果-1则表示使用所有处理器。
    train_sizes, train_scores, test_scores = learning_curve(
        estimator, X, y, cv=cv, n_jobs=n_jobs, train_sizes=train_sizes)
    # 返回值:
    # train_sizes: 训练示例的等分比例,代表着横坐标有几个点。
    # train_scores: 训练集得分: 二维数组,shape:(5, 100)
    # test_scores:测试集得分
#     print(train_scores)
#     print(test_scores)
    # 最后求得五个平均值
    train_scores_mean = np.mean(train_scores, axis=1)
    train_scores_std = np.std(train_scores, axis=1)
    test_scores_mean = np.mean(test_scores, axis=1)
    test_scores_std = np.std(test_scores, axis=1)
    print(train_sizes)
    print(train_scores_mean)
    print(test_scores_mean)
    plt.grid()
    plt.fill_between(train_sizes, train_scores_mean - train_scores_std,
                     train_scores_mean + train_scores_std, alpha=0.1,
                     color="r")
    plt.fill_between(train_sizes, test_scores_mean - test_scores_std,
                     test_scores_mean + test_scores_std, alpha=0.1, color="g")
    plt.plot(train_sizes, train_scores_mean, 'o-', color="r",
             label="Training score")
    plt.plot(train_sizes, test_scores_mean, 'o-', color="g",
             label="Cross-validation score")
    plt.legend(loc="best")
    return plt
# X是2310行,37列的数据
X = train_data.values
y = train_target.values
# 图一
title = r"LinearRegression"
# ShuffleSplit:将样例打散,随机取出20%的数据作为测试集,这样取出100次,
cv = ShuffleSplit(n_splits=100, test_size=0.2, random_state=0)
estimator = LinearRegression()    #建模
plot_learning_curve(estimator, title, X, y, ylim=(0.5, 1), cv=cv, n_jobs=1)


输出


3bdb2871478540ca9518afd427f60a73.png


参考文章:

Sklearn — 检视过拟合Learning curve.

官方文档.

sklearn中的学习曲线learning_curve函数.


总结


好耶好耶。

相关文章
|
7月前
|
机器学习/深度学习
大模型训练loss突刺原因和解决办法
【1月更文挑战第19天】大模型训练loss突刺原因和解决办法
1154 1
大模型训练loss突刺原因和解决办法
|
7月前
|
机器学习/深度学习 监控 数据可视化
训练损失图(Training Loss Plot)
训练损失图(Training Loss Plot)是一种在机器学习和深度学习过程中用来监控模型训练进度的可视化工具。损失函数是衡量模型预测结果与实际结果之间差距的指标,训练损失图展示了模型在训练过程中,损失值随着训练迭代次数的变化情况。通过观察损失值的变化,我们可以评估模型的拟合效果,调整超参数,以及确定合适的训练停止条件。
1304 5
|
4月前
|
API 算法框架/工具
【Tensorflow+keras】使用keras API保存模型权重、plot画loss损失函数、保存训练loss值
使用keras API保存模型权重、plot画loss损失函数、保存训练loss值
37 0
|
7月前
|
机器学习/深度学习 算法 数据可视化
模型训练(Model Training)
模型训练(Model Training)是指使用数据集对模型进行训练,使其能够从数据中学习到特征和模式,进而完成特定的任务。在深度学习领域,通常使用反向传播算法来训练模型,其中模型会根据数据集中的输入和输出,不断更新其参数,以最小化损失函数。
612 1
|
机器学习/深度学习
推理(Inference)与预测(Prediction)
推理(Inference)与预测(Prediction)
521 1
推理(Inference)与预测(Prediction)
|
机器学习/深度学习 Web App开发 人工智能
一个项目帮你了解数据集蒸馏Dataset Distillation
一个项目帮你了解数据集蒸馏Dataset Distillation
283 0
|
机器学习/深度学习 数据中心
基于Fashion-MNIST数据集的模型剪枝(下)
1. 介绍 1.1 背景介绍 目前在深度学习中存在一些困境,对于移动是设备来说,主要是算不好;穿戴设备算不来;数据中心,大多数人又算不起 。这就是做模型做压缩与加速的初衷。
145 0
基于Fashion-MNIST数据集的模型剪枝(下)
|
机器学习/深度学习 存储 算法
基于Fashion-MNIST数据集的模型剪枝(上)
1. 介绍 1.1 背景介绍 目前在深度学习中存在一些困境,对于移动是设备来说,主要是算不好;穿戴设备算不来;数据中心,大多数人又算不起 。这就是做模型做压缩与加速的初衷。
488 0
基于Fashion-MNIST数据集的模型剪枝(上)
ML之LassoR&RidgeR:基于datasets糖尿病数据集利用LassoR和RidgeR算法(alpha调参)进行(9→1)回归预测
ML之LassoR&RidgeR:基于datasets糖尿病数据集利用LassoR和RidgeR算法(alpha调参)进行(9→1)回归预测
ML之LassoR&RidgeR:基于datasets糖尿病数据集利用LassoR和RidgeR算法(alpha调参)进行(9→1)回归预测
|
机器学习/深度学习 自然语言处理 算法
深度学习Loss合集:一文详解Contrastive Loss/Ranking Loss/Triplet Loss等区别与联系
深度学习Loss合集:一文详解Contrastive Loss/Ranking Loss/Triplet Loss等区别与联系
1574 0
深度学习Loss合集:一文详解Contrastive Loss/Ranking Loss/Triplet Loss等区别与联系

相关实验场景

更多