《python机器学习从入门到高级》之分类算法:下(含详细代码)

简介: 《python机器学习从入门到高级》之分类算法:下(含详细代码)
  • ✨本文收录于《python机器学习从入门到高级》专栏,此专栏主要记录如何使用python实现机器学习模型,尽量坚持每周持续更新,欢迎大家订阅!
  • 🌸个人主页:JoJo的数据分析历险记
  • 📝个人介绍:小编大四统计在读,目前保研到统计学top3高校继续攻读统计研究生
  • 💌如果文章对你有帮助,欢迎✌关注、👍点赞、✌收藏、👍订阅专栏

本专栏主要从==代码角度==介绍如何使用python实现机器学习算法,想要了解具体机器学习理论的小伙伴,可以看我的这个专栏:统计学习方法

🍁1.前言

在上一篇文章中,我们介绍了如何对mnist数据集建立一个二分类模型,我们当时解决的问题是给我一张图片,判断是否是数字7,但是我们不仅仅对数字7感兴趣,我们希望给我一张任意的图片,计算机能告诉我这张图片是数字几。这是一个多分类问题。一些算法(如SGD分类器、 随机森林分类器朴素贝叶斯分类器)能处理多个类。其他(如logistic回归)是严格的二元分类器。但是我们可以通过一些策略来实现使用二分类器进行多分类

  • OvR一种方法是对于0-9十个类别,我们对每个类建立一个二分类器。判断是否属于该类,具体实现方法是,给我一张图片,分别使用这十个分类器预测属于该类的概率。选择概率最大的那一类作为预测结果
  • OvO另一种方法是对于0-9十个类别,每一次选两个类别进行比较,比较属于哪一类的概率更大。对于minist数据集,则必须在所有45个分类器进行比较,看看哪个类赢的最多。OvO的主要优点是,每个分类器只需要在训练集的一部分进行训练,即选择需要区分的两个类的数据集。然而,对于大多数二进制分类算法,OvR是首选。

当我们使用二分类器来处理多分类任务时,sklearn会自动选择OvO或者OvR来处理。例如我们以支持向量机(SVM)为例

🍂 2.从二元分类到多分类

# 导入数据集
from sklearn.datasets import fetch_openml
mnist = fetch_openml('mnist_784', version=1, as_frame=False)
AI 代码解读
import numpy as np
X, y = mnist["data"], mnist["target"]
y = y.astype(np.uint)#更改y数据类型为整数
AI 代码解读
# 将数据划分为测试集和训练集
X_train,X_test,y_train,y_test = X[:6000],X[6000:],y[:6000],y[6000:]
AI 代码解读
from sklearn.svm import SVC
svm_clf = SVC(gamma="auto", random_state=123)
svm_clf.fit(X_train, y_train) # y_train
svm_clf.predict([X[0]])
AI 代码解读
array([5], dtype=uint32)


AI 代码解读

还记得,我们在分类算法上介绍的,第一张图片是数字5,预测正确.
其实SVC默认是采用了OvR策略,我们通过decision_function可以看到每一个样本有10个scores

some_digit_scores = svm_clf.decision_function([X[0]])
some_digit_scores
AI 代码解读
array([[ 1.8249344 ,  8.01830986,  0.81268669,  4.8465137 ,  5.87200033,
         9.29462954,  3.8465137 ,  6.94086295, -0.21310287,  2.83645231]])


AI 代码解读

可以看出,最大的是5

np.argmax(some_digit_scores)
AI 代码解读
5



AI 代码解读
# 查看一共有几类
svm_clf.classes_
AI 代码解读
array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=uint32)


AI 代码解读

注意:训练分类器时,它会将目标类列表按值排序存储在其classes_属性中。在这种情况下,classes_数组中每个类的索引都可以方便地匹配类本身。在本例中,索引5处的类恰好是类5

下面我们使用随机森林模型看看结果

from sklearn.ensemble import RandomForestClassifier
AI 代码解读
rf_clf = RandomForestClassifier(random_state=123)
rf_clf.fit(X_train, y_train) # y_train
rf_clf.predict([X[0]])
AI 代码解读
array([5], dtype=uint32)
AI 代码解读

🍃3.误差分析

首先看看混淆矩阵。需要使用Cross_val_predict函数进行预测,然后调用confusion_matrix()

from sklearn.metrics import confusion_matrix
from sklearn.model_selection import cross_val_predict
AI 代码解读

首先这里我将X进行标准化处理

from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train.astype(np.float64))
AI 代码解读
y_train_pred = cross_val_predict(svm_clf, X_train_scaled, y_train, cv=3)
conf_mx = confusion_matrix(y_train, y_train_pred)
conf_mx
AI 代码解读
array([[576,   0,   4,   2,   3,   2,   2,   0,   3,   0],
       [  0, 649,   9,   1,   3,   1,   0,   3,   4,   1],
       [  4,   5, 531,   7,   8,   2,   3,   9,  11,   1],
       [  0,   5,  28, 542,   2,  14,   1,   9,   5,   2],
       [  0,   2,  14,   0, 578,   1,   2,   6,   0,  20],
       [  3,   4,   9,  16,   7, 450,  10,   7,   3,   5],
       [  3,   2,  23,   0,   2,   7, 567,   2,   2,   0],
       [  2,   8,  14,   0,   7,   0,   0, 593,   0,  27],
       [  4,   7,  15,   8,   2,  15,   6,   2, 488,   4],
       [  4,   2,   9,   7,  13,   2,   0,  25,   3, 536]], dtype=int64)


AI 代码解读

这是有很多类。使用Matplotlibmatshow()函数查看混淆矩阵的图像表示通常更方便:

plt.matshow(conf_mx, cmap=plt.cm.gray)

plt.show()
AI 代码解读


png

这个混淆矩阵看起来不错,因为大多数图像都在主对角线上,这意味着它们被正确分类。5比其他数字略暗,这可能意味着数据集中5的图像较少,或者分类器在5上的性能不如其他数字。现在我们来比较错误率。

row_sums = conf_mx.sum(axis=1, keepdims=True)#计算数量
norm_conf_mx = conf_mx / row_sums#计算错误率的混淆矩阵
AI 代码解读
np.fill_diagonal(norm_conf_mx, 0)
plt.matshow(norm_conf_mx, cmap=plt.cm.gray)
plt.show()
AI 代码解读


png

注意,行代表正确的类,列代表预测的列,可以看出2这个数字这一列很亮,说明有很多其他类被误判为2,但是2这一行却又错判为其他类。通过分析混淆矩阵可以让我们深入了解改进分类器的方法。本例中可以先优化数字2,来减少其他数字对2的错判。例如,您可以尝试为看起来像(但不是)的数字收集更多的训练数据,以便分类器可以学习将它们与真实的2区分开来。或者你可以设计一些新的特性来帮助分类器,例如,编写一个算法来计算每个数字圆圈的数量(例如,8有两个,6有一个,5没有)。或者,你可以对图像进行预处理(例如,使用Scikit ImagePillowOpenCV),以使某些图案(例如闭合环)更加突出。

分析单个错误也是一种很好的方法,可以了解分类器正在做什么,以及它失败的原因,但这更困难、更耗时。例如,让我们绘制数字5和3

def plot_digits(instances, images_per_row=10, **options):
    size = 28
    images_per_row = min(len(instances), images_per_row)#每一行的数字
    n_rows = (len(instances) - 1) // images_per_row + 1

    
    n_empty = n_rows * images_per_row - len(instances)
    padded_instances = np.concatenate([instances, np.zeros((n_empty, size * size))], axis=0)
    image_grid = padded_instances.reshape((n_rows, images_per_row, size, size))

    big_image = image_grid.transpose(0, 2, 1, 3).reshape(n_rows * size,
                                                         images_per_row * size)
    
    plt.imshow(big_image, cmap = mpl.cm.binary, **options)
    plt.axis("off")
AI 代码解读
cl_a, cl_b = 3,5
X_aa = X_train[(y_train == cl_a) & (y_train_pred == cl_a)]
X_ab = X_train[(y_train == cl_a) & (y_train_pred == cl_b)]
X_ba = X_train[(y_train == cl_b) & (y_train_pred == cl_a)]
X_bb = X_train[(y_train == cl_b) & (y_train_pred == cl_b)]

plt.figure(figsize=(8,8))
plt.subplot(221); plot_digits(X_aa[:25], images_per_row=5)#每一行五个数字
plt.subplot(222); plot_digits(X_ab[:25], images_per_row=5)
plt.subplot(223); plot_digits(X_ba[:25], images_per_row=5)
plt.subplot(224); plot_digits(X_bb[:25], images_per_row=5)
plt.show()
AI 代码解读


png

上面一行第二张图是错把3误判为5,第二行第一幅图是错把5判为3的情况

🌷4. 多标签分类

到目前为止,每个分类器都是分给一个类,在某些情况下,我们可能希望一个分类器输出多个类,例如一个人脸识别器;如果它能识别一个图片多个人,那么这就是一个多标签分类器。下面我们照样以mnist数据集为例,
假设此时我们的目标一个是大于7的数,另一个是偶数。下面使用KNN算法为例

from sklearn.neighbors import KNeighborsClassifier
y_train_large = (y_train >= 7)
y_train_odd = (y_train % 2 == 0)
y_multilabel = np.c_[y_train_large, y_train_odd]

knn_clf = KNeighborsClassifier()
knn_clf.fit(X_train, y_multilabel)
AI 代码解读
KNeighborsClassifier()



AI 代码解读
knn_clf.predict([X[0]])
AI 代码解读
array([[False, False]])


AI 代码解读

我们知道第一个数是5,它小于7并且不是偶数,因此两个返回值都是False

本章的介绍到此介绍,如果文章对你有帮助,请多多点赞、收藏、评论、关注支持!!

相关文章
|
2月前
|
时间序列异常检测:MSET-SPRT组合方法的原理和Python代码实现
MSET-SPRT是一种结合多元状态估计技术(MSET)与序贯概率比检验(SPRT)的混合框架,专为高维度、强关联数据流的异常检测设计。MSET通过历史数据建模估计系统预期状态,SPRT基于统计推断判定偏差显著性,二者协同实现精准高效的异常识别。本文以Python为例,展示其在模拟数据中的应用,证明其在工业监控、设备健康管理及网络安全等领域的可靠性与有效性。
615 13
时间序列异常检测:MSET-SPRT组合方法的原理和Python代码实现
揭秘Python的__init__.py:从入门到精通的包管理艺术
__init__.py是Python包管理中的核心文件,既是包的身份标识,也是模块化设计的关键。本文从其历史演进、核心功能(如初始化、模块曝光控制和延迟加载)、高级应用场景(如兼容性适配、类型提示和插件架构)到最佳实践与常见陷阱,全面解析了__init__.py的作用与使用技巧。通过合理设计,开发者可构建优雅高效的包结构,助力Python代码质量提升。
60 10
Python中利用遗传算法探索迷宫出路
本文探讨了如何利用Python和遗传算法解决迷宫问题。迷宫建模通过二维数组实现,0表示通路,1为墙壁,'S'和'E'分别代表起点与终点。遗传算法的核心包括个体编码(路径方向序列)、适应度函数(评估路径有效性)、选择、交叉和变异操作。通过迭代优化,算法逐步生成更优路径,最终找到从起点到终点的最佳解决方案。文末还展示了结果可视化方法及遗传算法的应用前景。
基于 Python 哈希表算法的局域网网络监控工具:实现高效数据管理的核心技术
在当下数字化办公的环境中,局域网网络监控工具已成为保障企业网络安全、确保其高效运行的核心手段。此类工具通过对网络数据的收集、分析与管理,赋予企业实时洞察网络活动的能力。而在其运行机制背后,数据结构与算法发挥着关键作用。本文聚焦于 PHP 语言中的哈希表算法,深入探究其在局域网网络监控工具中的应用方式及所具备的优势。
74 7
[oeasy]python083_类_对象_成员方法_method_函数_function_isinstance
本文介绍了Python中类、对象、成员方法及函数的概念。通过超市商品分类的例子,形象地解释了“类型”的概念,如整型(int)和字符串(str)是两种不同的数据类型。整型对象支持数字求和,字符串对象支持拼接。使用`isinstance`函数可以判断对象是否属于特定类型,例如判断变量是否为整型。此外,还探讨了面向对象编程(OOP)与面向过程编程的区别,并简要介绍了`type`和`help`函数的用法。最后总结指出,不同类型的对象有不同的运算和方法,如字符串有`find`和`index`方法,而整型没有。更多内容可参考文末提供的蓝桥、GitHub和Gitee链接。
60 11
【重磅发布】AllData数据中台核心功能:机器学习算法平台
杭州奥零数据科技有限公司成立于2023年,专注于数据中台业务,维护开源项目AllData并提供商业版解决方案。AllData提供数据集成、存储、开发、治理及BI展示等一站式服务,支持AI大模型应用,助力企业高效利用数据价值。
员工电脑监控场景下 Python 红黑树算法的深度解析
在当代企业管理范式中,员工电脑监控业已成为一种广泛采用的策略性手段,其核心目标在于维护企业信息安全、提升工作效能并确保合规性。借助对员工电脑操作的实时监测机制,企业能够敏锐洞察潜在风险,诸如数据泄露、恶意软件侵袭等威胁。而员工电脑监控系统的高效运作,高度依赖于底层的数据结构与算法架构。本文旨在深入探究红黑树(Red - Black Tree)这一数据结构在员工电脑监控领域的应用,并通过 Python 代码实例详尽阐释其实现机制。
49 7
Python入门修炼:开启你在大数据世界的第一个脚本
Python入门修炼:开启你在大数据世界的第一个脚本
77 6
Python创意爱心代码大全:从入门到高级的7种实现方式
本文分享了7种用Python实现爱心效果的方法,从简单的字符画到复杂的3D动画,涵盖多种技术和库。内容包括:基础字符爱心(一行代码实现)、Turtle动态绘图、Matplotlib数学函数绘图、3D旋转爱心、Pygame跳动动画、ASCII艺术终端显示以及Tkinter交互式GUI应用。每种方法各具特色,适合不同技术水平的读者学习和实践,是表达创意与心意的绝佳工具。
611 0
基于 Python 迪杰斯特拉算法的局域网计算机监控技术探究
信息技术高速演进的当下,局域网计算机监控对于保障企业网络安全、优化资源配置以及提升整体运行效能具有关键意义。通过实时监测网络状态、追踪计算机活动,企业得以及时察觉潜在风险并采取相应举措。在这一复杂的监控体系背后,数据结构与算法发挥着不可或缺的作用。本文将聚焦于迪杰斯特拉(Dijkstra)算法,深入探究其在局域网计算机监控中的应用,并借助 Python 代码示例予以详细阐释。
58 6

热门文章

最新文章

AI助理

你好,我是AI助理

可以解答问题、推荐解决方案等