【人工智能】机器学习之暴力调参案例

简介: 暴力调参案例使用的数据集为 from sklearn.datasets import fetch_20newsgroups

暴力调参案例

使用的数据集为

from sklearn.datasets import fetch_20newsgroups

因为在线下载慢,可以提前下载保存到
在这里插入图片描述

首先引入所需库

import numpy as np
import pandas as pd
defaultencoding = 'utf-8'
import matplotlib as mpl
import matplotlib.pyplot as plt
from sklearn.naive_bayes import MultinomialNB
from sklearn.neighbors import KNeighborsClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import GridSearchCV
from sklearn.feature_selection import SelectKBest,chi2
import sklearn.metrics as metrics
from sklearn.datasets import fetch_20newsgroups
import sys

编码问题显示

if sys.getdefaultencoding() != defaultencoding:
    reload(sys)
    sys.setdefaultencoding(defaultencoding)
mpl.rcParams['font.sans-serif']=[u'simHei']
mpl.rcParams['axes.unicode_minus']=False

如果报错的话可以改为

import importlib,sys

if sys.getdefaultencoding() != defaultencoding:
    importlib.reload(sys)
    sys.setdefaultencoding(defaultencoding)
mpl.rcParams['font.sans-serif']=[u'simHei']
mpl.rcParams['axes.unicode_minus']=False

用来正常显示中文
mpl.rcParams['font.sans-serif']=[u'simHei']
用来正常正负号
mpl.rcParams['axes.unicode_minus']=False

获取数据

#data_home="./datas/"下载的新闻的保存地址subset='train'表示从训练集获取新闻categories获取哪些种类的新闻
datas=fetch_20newsgroups(data_home="./datas/",subset='train',categories=['alt.atheism', 'comp.graphics', 'comp.os.ms-windows.misc'])
datas_test=fetch_20newsgroups(data_home="./datas/",subset='test',categories=['alt.atheism', 'comp.graphics', 'comp.os.ms-windows.misc'])
train_x=datas.data#获取新闻X
train_y=datas.target#获取新闻Y
test_x=datas_test.data#获取测试集的x
test_y=datas_test.target#获取测试集的y

自动调参

import time
def setParam(algo,name):
    gridSearch = GridSearchCV(algo,param_grid=[],cv=5)
    m=0
    if hasattr(algo,"alpha"):
        n=np.logspace(-2,9,10)
        gridSearch.set_params(param_grid={"alpha":n})
        m=10
    if hasattr(algo,"max_depth"):
        depth=[2,7,10,14,20,30]
        gridSearch.set_params(param_grid={"max_depth":depth})
        m=len(depth)
    if hasattr(algo,"n_neighbors"):
        neighbors=[2,7,10]
        gridSearch.set_params(param_grid={"n_neighbors":neighbors})
        m=len(neighbors)
    t1=time.time()
    gridSearch.fit(train_x,train_y)
    test_y_hat=gridSearch.predict(test_x)
    train_y_hat=gridSearch.predict(train_x)
    t2=time.time()-t1
    print(name, gridSearch.best_estimator_)
    train_error=1-metrics.accuracy_score(train_y,train_y_hat)
    test_error=1-metrics.accuracy_score(test_y,test_y_hat)
    return name,t2/5*m,train_error,test_error

选择算法调参

朴素贝叶斯,随机森林,KNN

algorithm=[("mnb",MultinomialNB()),("random",RandomForestClassifier()),("knn",KNeighborsClassifier())]
for name,algo in algorithm:
    result=setParam(algo,name)
    results.append(result)

可视化

#把名称,花费时间,训练错误率,测试错误率分别存到单个数组
names,times,train_err,test_err=[[x[i] for x in results] for i in  range(0,4)]

axes=plt.axes()
axes.bar(np.arange(len(names)),times,color="red",label="耗费时间",width=0.1)
axes.bar(np.arange(len(names))+0.1,train_err,color="green",label="训练集错误",width=0.1)
axes.bar(np.arange(len(names))+0.2,test_err,color="blue",label="测试集错误",width=0.1)
plt.xticks(np.arange(len(names)), names)
plt.legend()
plt.show()

代码整合:

#coding=UTF-8
import numpy as np
import pandas as pd
defaultencoding = 'utf-8'
import matplotlib as mpl
import matplotlib.pyplot as plt
from sklearn.naive_bayes import MultinomialNB
from sklearn.neighbors import KNeighborsClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import GridSearchCV
from sklearn.feature_selection import SelectKBest,chi2
import sklearn.metrics as metrics
from sklearn.datasets import fetch_20newsgroups
import sys
import importlib,sys

if sys.getdefaultencoding() != defaultencoding:
    # reload(sys)
    importlib.reload(sys)
    sys.setdefaultencoding(defaultencoding)
mpl.rcParams['font.sans-serif']=[u'simHei']
mpl.rcParams['axes.unicode_minus']=False

#data_home="./datas/"下载的新闻的保存地址subset='train'表示从训练集获取新闻categories获取哪些种类的新闻
datas=fetch_20newsgroups(data_home="./datas/",subset='train',categories=['alt.atheism', 'comp.graphics', 'comp.os.ms-windows.misc'])
datas_test=fetch_20newsgroups(data_home="./datas/",subset='test',categories=['alt.atheism', 'comp.graphics', 'comp.os.ms-windows.misc'])
train_x=datas.data#获取新闻X
train_y=datas.target#获取新闻Y
test_x=datas_test.data#获取测试集的x
test_y=datas_test.target#获取测试集的y

tfidf=TfidfVectorizer(stop_words="english")
train_x=tfidf.fit_transform(train_x,train_y)#向量转化
test_x=tfidf.transform(test_x)#向量转化

print(train_x.shape)
best=SelectKBest(chi2,k=1000)#降维变成一千列

train_x = best.fit_transform(train_x,train_y)#转换
test_x = best.transform(test_x)

import time
def setParam(algo,name):
    gridSearch = GridSearchCV(algo,param_grid=[],cv=5)
    m=0
    if hasattr(algo,"alpha"):
        n=np.logspace(-2,9,10)
        gridSearch.set_params(param_grid={"alpha":n})
        m=10
    if hasattr(algo,"max_depth"):
        depth=[2,7,10,14,20,30]
        gridSearch.set_params(param_grid={"max_depth":depth})
        m=len(depth)
    if hasattr(algo,"n_neighbors"):
        neighbors=[2,7,10]
        gridSearch.set_params(param_grid={"n_neighbors":neighbors})
        m=len(neighbors)
    t1=time.time()
    gridSearch.fit(train_x,train_y)
    test_y_hat=gridSearch.predict(test_x)
    train_y_hat=gridSearch.predict(train_x)
    t2=time.time()-t1
    print(name, gridSearch.best_estimator_)
    train_error=1-metrics.accuracy_score(train_y,train_y_hat)
    test_error=1-metrics.accuracy_score(test_y,test_y_hat)
    return name,t2/5*m,train_error,test_error
results=[]
plt.figure()
algorithm=[("mnb",MultinomialNB()),("random",RandomForestClassifier()),("knn",KNeighborsClassifier())]
for name,algo in algorithm:
    result=setParam(algo,name)
    results.append(result)
#把名称,花费时间,训练错误率,测试错误率分别存到单个数组
names,times,train_err,test_err=[[x[i] for x in results] for i in  range(0,4)]

axes=plt.axes()
axes.bar(np.arange(len(names)),times,color="red",label="耗费时间",width=0.1)
axes.bar(np.arange(len(names))+0.1,train_err,color="green",label="训练集错误",width=0.1)
axes.bar(np.arange(len(names))+0.2,test_err,color="blue",label="测试集错误",width=0.1)
plt.xticks(np.arange(len(names)), names)
plt.legend()
plt.show()

结果:

在这里插入图片描述

在这里插入图片描述

目录
相关文章
|
2天前
|
机器学习/深度学习 人工智能 TensorFlow
人工智能平台PAI产品使用合集之ev必须在特定的scope下定义吗
阿里云人工智能平台PAI是一个功能强大、易于使用的AI开发平台,旨在降低AI开发门槛,加速创新,助力企业和开发者高效构建、部署和管理人工智能应用。其中包含了一系列相互协同的产品与服务,共同构成一个完整的人工智能开发与应用生态系统。以下是对PAI产品使用合集的概述,涵盖数据处理、模型开发、训练加速、模型部署及管理等多个环节。
|
1天前
|
机器学习/深度学习 人工智能 算法
【机器学习】探究Q-Learning通过学习最优策略来解决AI序列决策问题
【机器学习】探究Q-Learning通过学习最优策略来解决AI序列决策问题
|
1天前
|
机器学习/深度学习 人工智能 测试技术
自动化测试中AI与机器学习的融合应用
【4月更文挑战第29天】 随着技术的不断进步,人工智能(AI)和机器学习(ML)在软件测试中的应用越来越广泛。本文将探讨AI和ML如何改变自动化测试领域,提高测试效率和质量。我们将讨论AI和ML的基本概念,以及它们如何应用于自动化测试,包括智能测试用例生成,缺陷预测,测试执行优化等方面。最后,我们还将讨论AI和ML在自动化测试中的挑战和未来发展趋势。
|
2天前
|
机器学习/深度学习 存储 人工智能
人工智能平台PAI产品使用合集之是否可以在模型部署发布后以http接口形式提供给业务开发人员使用
阿里云人工智能平台PAI是一个功能强大、易于使用的AI开发平台,旨在降低AI开发门槛,加速创新,助力企业和开发者高效构建、部署和管理人工智能应用。其中包含了一系列相互协同的产品与服务,共同构成一个完整的人工智能开发与应用生态系统。以下是对PAI产品使用合集的概述,涵盖数据处理、模型开发、训练加速、模型部署及管理等多个环节。
|
2天前
|
机器学习/深度学习 人工智能 运维
人工智能平台PAI产品使用合集之机器学习PAI可以通过再建一个done分区或者使用instance.status来进行部署吗
阿里云人工智能平台PAI是一个功能强大、易于使用的AI开发平台,旨在降低AI开发门槛,加速创新,助力企业和开发者高效构建、部署和管理人工智能应用。其中包含了一系列相互协同的产品与服务,共同构成一个完整的人工智能开发与应用生态系统。以下是对PAI产品使用合集的概述,涵盖数据处理、模型开发、训练加速、模型部署及管理等多个环节。
|
2天前
|
机器学习/深度学习 人工智能 网络协议
人工智能平台PAI 操作报错合集之在本地运行Alink Server时没有遇到问题。但是,当您尝试在PAI上运行时出现了错误。如何解决
阿里云人工智能平台PAI (Platform for Artificial Intelligence) 是阿里云推出的一套全面、易用的机器学习和深度学习平台,旨在帮助企业、开发者和数据科学家快速构建、训练、部署和管理人工智能模型。在使用阿里云人工智能平台PAI进行操作时,可能会遇到各种类型的错误。以下列举了一些常见的报错情况及其可能的原因和解决方法。
|
2天前
|
机器学习/深度学习 人工智能 分布式计算
人工智能平台PAI 操作报错合集之在本地构建easyrec docker镜像时遇到了无法连接docker服务如何解决
阿里云人工智能平台PAI (Platform for Artificial Intelligence) 是阿里云推出的一套全面、易用的机器学习和深度学习平台,旨在帮助企业、开发者和数据科学家快速构建、训练、部署和管理人工智能模型。在使用阿里云人工智能平台PAI进行操作时,可能会遇到各种类型的错误。以下列举了一些常见的报错情况及其可能的原因和解决方法。
|
2天前
|
机器学习/深度学习 人工智能 运维
人工智能平台PAI 操作报错合集之请问Alink的算法中的序列异常检测组件,是对数据进行分组后分别在每个组中执行异常检测,而不是将数据看作时序数据进行异常检测吧
阿里云人工智能平台PAI (Platform for Artificial Intelligence) 是阿里云推出的一套全面、易用的机器学习和深度学习平台,旨在帮助企业、开发者和数据科学家快速构建、训练、部署和管理人工智能模型。在使用阿里云人工智能平台PAI进行操作时,可能会遇到各种类型的错误。以下列举了一些常见的报错情况及其可能的原因和解决方法。
|
3天前
|
机器学习/深度学习 人工智能 算法
将 Visual Basic 与人工智能结合:机器学习的初步探索
【4月更文挑战第27天】本文探讨了Visual Basic(VB)在人工智能,尤其是机器学习领域的应用。VB作为易学易用的编程语言,结合机器学习可为开发者提供简单的人工智能实现途径。通过第三方库、调用外部程序或自行开发算法,VB能实现图像识别、文本分类和预测分析等功能。尽管面临性能、人才短缺和技术更新的挑战,但随着技术发展,VB在人工智能领域的潜力不容忽视,有望创造更多创新应用。
|
15天前
|
机器学习/深度学习 Python 索引
fast.ai 机器学习笔记(二)(4)
fast.ai 机器学习笔记(二)
24 0
fast.ai 机器学习笔记(二)(4)