网格搜索的原理以及实战以及相关API(GridSearchCV)

简介: 网格搜索的原理以及实战以及相关API(GridSearchCV)

前言


网格搜索是调参侠常用的一种调参手段


一、含义&优缺点&简单实现


含义:手动的给出一个模型中想要改动的所有参数,让程序帮助我们使用穷举法把所有参数组合运行一遍,选出最好的参数组合。一般和交叉验证搭配使用,因为使用交叉验证可以使得评分更加严谨。

优点:并行计算,速度很快

缺点:当参数量很多时,非常耗费计算资源。


1-1、网格搜索简单实现


代码介绍:使用鸢尾花数据集,嵌套两层for循环来遍历两个参数列表,在训练集上训练之后,用模型在测试集上找到最好的分数并且输出对应参数以及分数。


from sklearn.datasets import load_iris
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
iris = load_iris()
X_train,X_test,y_train,y_test = train_test_split(iris.data,iris.target,random_state=0)
print("Size of training set:{} size of testing set:{}".format(X_train.shape[0],X_test.shape[0]))
####   grid search start
best_score = 0
for gamma in [0.001,0.01,0.1,1,10,100]:
    for C in [0.001,0.01,0.1,1,10,100]:
        svm = SVC(gamma=gamma,C=C)#对于每种参数可能的组合,进行一次训练;
        svm.fit(X_train,y_train)
        score = svm.score(X_test,y_test)
        if score > best_score:#找到表现最好的参数
            best_score = score
            best_parameters = {'gamma':gamma,'C':C}
####   grid search end
print("Best score:{:.2f}".format(best_score))
print("Best parameters:{}".format(best_parameters))


输出


Size of training set:112 size of testing set:38

Best score:0.97

Best parameters:{‘gamma’: 0.001, ‘C’: 100}


1-2、带有交叉验证的网格搜索


代码介绍:同样的使用鸢尾花数据集,使用两层for循环来赋值参数,不同的是每一层循环内使用对应的参数来做训练,并且使用交叉验证函数cross_val_score来得到一个训练的平均分数,循环结束,得到最好的参数,重新在训练集上训练,并且在测试集上得到分数。

from sklearn.datasets import load_iris
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
from sklearn.model_selection import cross_val_score
iris = load_iris()
X_trainval,X_test,y_trainval,y_test = train_test_split(iris.data,iris.target,random_state=0)
best_score = 0.0
for gamma in [0.001,0.01,0.1,1,10,100]:
    for C in [0.001,0.01,0.1,1,10,100]:
        svm = SVC(gamma=gamma,C=C)
        scores = cross_val_score(svm,X_trainval,y_trainval,cv=5) #5折交叉验证
        score = scores.mean() #取平均数
        if score > best_score:
            best_score = score
            best_parameters = {"gamma":gamma,"C":C}
svm = SVC(**best_parameters)
svm.fit(X_trainval,y_trainval)
test_score = svm.score(X_test,y_test)
print("Best score on validation set:{:.2f}".format(best_score))
print("Best parameters:{}".format(best_parameters))
print("Score on testing set:{:.2f}".format(test_score))

输出

Best score on validation set:0.97

Best parameters:{‘gamma’: 0.1, ‘C’: 10}

Score on testing set:0.97


二、GridSearchCV(网格搜索&交叉验证)

2-1、GridSearchCV简介


含义:(sklearn类的一个方法)GridSearchCV既包含了网格搜索,又包含了交叉验证。只要输入参数列表,就可以保证在指定的参数范围内找到精度最高的参数,适合小型数据集,但是缺点是要遍历所有可能的参数组合的话,在面对大数据集和多参数的情况下,将会非常耗时。

补充:当数据量较大时,可以选择使用坐标下降法,即拿对模型影响较大的参数依次调优。

网格搜索:使用不同的参数组合来找到在验证集上精度最高的参数。

k折交叉验证:k折交叉验证将所有数据集分成k份,不重复地每次取其中一份做测试集,用其余k-1份做训练集训练模型,之后计算该模型在测试集上的得分,将k次的得分取平均得到最后的得分。


2-2、GridSearchCV方法

GridSearchCV参数说明


sklearn.model_selection.GridSearchCV(
  # 选择使用的分类器 
  estimator, 
  # 需要最优化的参数的取值,值为字典或者列表。
  param_grid, 
  *, 
  # 模型评价标准。
  scoring=None, 
  # 使用处理器的个数,默认为1,当为-1时,表示使用所有处理器。
  n_jobs=None, 
  # 默认为True,为True时,默认为各个样本fold概率分布一致。
  iid='deprecated', 
  # 默认为True,即在搜索参数结束后,用最佳参数结果再次fit一遍全部数据集。
  refit=True, 
  # 交叉验证参数,默认为5,即使用五折交叉验证。
  cv=None, 
  verbose=0, 
  pre_dispatch='2*n_jobs', 
  error_score=nan, 
  return_train_score=False
)


GridSearchCV属性以及方法说明


  • cv_results_ : dict of numpy (masked) ndarrays具有键作为列标题和值作为列的dict,可以导入到DataFrame中。注意,“params”键用于存储所有参数候选项的参数设置列表。
  • best_estimator_ : 最优模型以及对应的参数,如果refit = False,则不可用。
  • best_score_ :观察到的最好的评分。
  • best_parmas_ : 给出最佳结果的参数设置
  • best_index_ : int 对应于最佳候选参数设置的索引(cv_results_数组)search.cv_results _ [‘params’] [search.best_index_]中的dict给出了最佳模型的参数设置,给出了最高的平均分数(search.best_score_)
  • grid.fit(): 运行网格搜索
  • predict: 使用找到的最佳参数在估计器上调用预测
  • grid.score(): 模型在测试集上表现最好的分数。


2-3、GridSearchCV实战

from sklearn.model_selection import GridSearchCV
from sklearn.datasets import load_iris
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
from sklearn.model_selection import cross_val_score
#把要调整的参数以及其候选值 列出来;
param_grid = {"gamma":[0.001,0.01,0.1,1,10,100],
             "C":[0.001,0.01,0.1,1,10,100]}
print("Parameters:{}".format(param_grid))
grid_search = GridSearchCV(SVC(),param_grid,cv=5) #实例化一个GridSearchCV类
X_train,X_test,y_train,y_test = train_test_split(iris.data,iris.target,random_state=10)
grid_search.fit(X_train,y_train) #训练,找到最优的参数,同时使用最优的参数实例化一个新的SVC estimator。
print("Test set score:{:.2f}".format(grid_search.score(X_test,y_test)))
print("Best parameters:{}".format(grid_search.best_params_))
print("Best score on train set:{:.2f}".format(grid_search.best_score_))


输出

Parameters:{‘gamma’: [0.001, 0.01, 0.1, 1, 10, 100], ‘C’: [0.001, 0.01, 0.1, 1, 10, 100]}

Test set score:0.97

Best parameters:{‘C’: 10, ‘gamma’: 0.1}

Best score on train set:0.98


参考文章:


sklearn中的GridSearchCV方法详解.

机器学习(四)——模型调参利器 gridSearchCV(网格搜索).

Python机器学习笔记:Grid SearchCV(网格搜索).

sklearn官网.


总结


今天是周日! 斗破更新了,休息的时候我要马上去看!😄


相关文章
|
5天前
|
JSON 数据管理 关系型数据库
【Dataphin V3.9】颠覆你的数据管理体验!API数据源接入与集成优化,如何让企业轻松驾驭海量异构数据,实现数据价值最大化?全面解析、实战案例、专业指导,带你解锁数据整合新技能!
【8月更文挑战第15天】随着大数据技术的发展,企业对数据处理的需求不断增长。Dataphin V3.9 版本提供更灵活的数据源接入和高效 API 集成能力,支持 MySQL、Oracle、Hive 等多种数据源,增强 RESTful 和 SOAP API 支持,简化外部数据服务集成。例如,可轻松从 RESTful API 获取销售数据并存储分析。此外,Dataphin V3.9 还提供数据同步工具和丰富的数据治理功能,确保数据质量和一致性,助力企业最大化数据价值。
19 1
|
28天前
|
前端开发 API 数据库
告别繁琐,拥抱简洁!Python RESTful API 设计实战,让 API 调用如丝般顺滑!
【7月更文挑战第23天】在Python的Flask框架下构建RESTful API,为在线商店管理商品、订单及用户信息。以商品管理为例,设计简洁API端点,如GET `/products`获取商品列表,POST `/products`添加商品,PUT和DELETE则分别用于更新和删除商品。使用SQLAlchemy ORM与SQLite数据库交互,确保数据一致性。实战中还应加入数据验证、错误处理和权限控制,使API既高效又安全,便于前端或其他服务无缝对接。
48 9
|
1天前
|
编译器 API Android开发
Android经典实战之Kotlin Multiplatform 中,如何处理不同平台的 API 调用
本文介绍Kotlin Multiplatform (KMP) 中使用 `expect` 和 `actual` 关键字处理多平台API调用的方法。通过共通代码集定义预期API,各平台提供具体实现,编译器确保正确匹配,支持依赖注入、枚举类处理等,实现跨平台代码重用与原生性能。附带示例展示如何定义跨平台函数与类。
7 0
|
1月前
|
JavaScript 应用服务中间件 API
Node.js搭建REST API实战:从基础到部署
【7月更文挑战第18天】通过以上步骤,你可以将你的Node.js REST API从开发环境顺利迁移到生产环境,并利用各种工具和技术来确保应用的稳定性、安全性和可扩展性。
|
1月前
|
前端开发 API 开发者
Python Web开发者必看!AJAX、Fetch API实战技巧,让前后端交互如丝般顺滑!
【7月更文挑战第13天】在Web开发中,AJAX和Fetch API是实现页面无刷新数据交换的关键。在Flask博客系统中,通过创建获取评论的GET路由,我们可以展示使用AJAX和Fetch API的前端实现。AJAX通过XMLHttpRequest发送请求,处理响应并在成功时更新DOM。Fetch API则使用Promise简化异步操作,代码更现代。这两个工具都能实现不刷新页面查看评论,Fetch API的语法更简洁,错误处理更直观。掌握这些技巧能提升Python Web项目的用户体验和开发效率。
44 7
|
1月前
|
安全 Java API
Nest.js 实战 (三):使用 Swagger 优雅地生成 API 文档
这篇文章介绍了Swagger,它是一组开源工具,围绕OpenAPI规范帮助设计、构建、记录和使用RESTAPI。文章主要讨论了Swagger的主要工具,包括SwaggerEditor、SwaggerUI、SwaggerCodegen等。然后介绍了如何在Nest框架中集成Swagger,展示了安装依赖、定义DTO和控制器等步骤,以及如何使用Swagger装饰器。文章最后总结说,集成Swagger文档可以自动生成和维护API文档,规范API标准化和一致性,但会增加开发者工作量,需要保持注释和装饰器的准确性。
Nest.js 实战 (三):使用 Swagger 优雅地生成 API 文档
|
18天前
|
JavaScript 前端开发 中间件
打造卓越后端:构建高效API的最佳实践与实战代码示例——解锁高性能Web服务的秘密
【8月更文挑战第2天】构建高效后端API:最佳实践与代码示例
35 0
|
26天前
|
API 开发者 Python
淘宝商品详情API接口开发实战
在电商领域,获取淘宝商品详情是关键需求。需先注册淘宝开放平台账号并创建应用,获取AppKey与AppSecret;随后申请商品服务API权限。利用Python,通过AppKey和AppSecret获取Access Token,进而调用商品详情API,需替换示例代码中的`your_app_key`, `your_app_secret`, `your_access_token`, 和`item_id`。注意遵守平台限制,处理可能的错误及合理规划调用策略以避免违规。[示例代码](https://)展示了从获取Access Token到调用商品详情API的全过程。
|
29天前
|
存储 JSON API
实战派教程!Python Web开发中RESTful API的设计哲学与实现技巧,一网打尽!
【7月更文挑战第22天】构建RESTful API实战:**使用Python Flask设计图书管理API,遵循REST原则,通过GET/POST/PUT/DELETE操作处理/books及/books/<id>。示例代码展示资源定义、请求响应交互。关键点包括HTTP状态码的使用、版本控制、错误处理和文档化。本文深入探讨设计哲学与实现技巧,助力理解RESTful API开发。
30 0