Python3入门机器学习 - GridSearch探索最佳超参数与交叉验证-阿里云开发者社区

开发者社区> 人工智能> 正文

Python3入门机器学习 - GridSearch探索最佳超参数与交叉验证

简介: 这次我们依旧使用digits数据集 准备数据 %%time import sklearn.datasets import numpy as np from sklearn.

这次我们依旧使用digits数据集
准备数据

%%time
import sklearn.datasets
import numpy as np
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

digits = sklearn.datasets.load_digits()

X = digits.data
y = digits.target

X_train,X_test,y_train,y_test = train_test_split(X,y)

引入GridSearchCV,准备params参数集合进行测试

from sklearn.model_selection import GridSearchCV

params = [
    {
        'weights':['distance'],
        'n_neighbors':[i for i in range(1,11)],
        'p':[i for i in range(1,6)]    
    },
    {
        'weights':['uniform'],
        'n_neighbors':[i for i in range(1,11)]
    }
]

knn_clf = KNeighborsClassifier()

grid_search = GridSearchCV(knn_clf,params)//传入knn算法对象和参数集合
/** grid_search对象 **/
grid_search.fit(X_train,y_train)  //传入数据集,这行代码我感觉运行了一年
grid_search.best_estimator_   //显示最佳参数模型
grid_search.best_params_  //显示最佳的超参数
grid_search.best_score_  //显示最佳的正确率

评价模型好坏的标准在grid_search中更为复杂,CV交叉验证。并不仅仅根据正确率的大小评价好坏。


创建GridSearchCV()对象时的部分参数

  • estimator 该参数是你要使用的模型算法
  • param_grid 要实验的超参数集合,如上文中的params
  • n_jobs 该值根据你的CPU核心数而定,传入-1自动适配当前CPU核心数
  • verbose 每次循环执行时输出当前循环的信息,常用2




验证数据集与交叉验证

验证数据集


img_4ff9cf544cbc4e6ba6e682fd5cbba0a8.png

将数据分为3份,其中一份作为验证数据集调整超参数,这样可以避免拟合数据的随机性导致的模型误差。


交叉验证


img_37beaac271535d7507ffde1d8530b517.png

将训练数据分为m份,每次选一份做验证数据集,其余训练模型,最大限度保证模型的准确性


使用scikitlearn中的交叉验证
from sklearn.model_selection import cross_val_score

knn_clf = KNeighborsClassifier()
cross_val_score(knn_clf,X_train,y_train)
使用交叉验证获得最佳超参数
best_k = -1
best_score = -1
best_p=-1
for p in range(1,6):
    for k in range(1,11):
        knn_clf = KNeighborsClassifier(n_neighbors=k,weights="distance",p=p)
        scores = cross_val_score(knn_clf, X_train,y_train)
        score = np.mean(scores)
        if(score>best_score):
            best_k = k
            best_score = score
            best_p = p
            
print("best_k = ",best_k)
print("best_score = ",best_score)
print("best_p = ",best_p)

版权声明:本文首发在云栖社区,遵循云栖社区版权声明:本文内容由互联网用户自发贡献,版权归用户作者所有,云栖社区不为本文内容承担相关法律责任。云栖社区已升级为阿里云开发者社区。如果您发现本文中有涉嫌抄袭的内容,欢迎发送邮件至:developer2020@service.aliyun.com 进行举报,并提供相关证据,一经查实,阿里云开发者社区将协助删除涉嫌侵权内容。

分享:
人工智能
使用钉钉扫一扫加入圈子
+ 订阅

了解行业+人工智能最先进的技术和实践,参与行业+人工智能实践项目

其他文章