数据分析入门系列教程-K-Means原理

简介: 数据分析入门系列教程-K-Means原理

今天我们来学习 K-Means 算法,这是一种非监督学习。所谓的监督学习和非监督学习的区别就是样本中是否存在标签,对于有标签的样本做分析就是监督学习,而对没有标签的样本做分析就属于非监督学习。

K-Means 解决的是聚类的问题,就是把样本根据某些特征,按照某些中心点,聚类在一起,从而达到分类的效果。K 代表的是 K 类,Means 代表的是中心,所以该算法的本质其实就是确定 K 类的中心点,当我们找到中心点后,也就完成了聚类。

聚类的应用场景是非常多的,比如给用户群分类,对用户行为划分等待,特别是在没有标签的情况下,只能只用聚类的方式做分析。


K-Means原理


还记得我们在 KNN 算法中使用的 offer 案例吗


image.png

现在我们把这个例子稍微修改下

我们把例子中每个人是否能够获得 offer 这个信息去掉,即我们现在拥有的信息只是每个人的工作经验和当前工资,那么对于小 K 来说,该怎么判别他是否能够获得 offer 呢


K-Means 模型假设

将某一些数据分为不同的类别,在相同的类别中数据之间的距离应该都很近,也就是说离得越近的数据应该越相似,再进一步说明,数据之间的相似度与它们之间的欧式距离成反比。这就是 k-means 模型的假设。

有了这个假设,我们对将数据分为不同的类别的算法就更明确了,尽可能将离得近的数据划分为一个类别。那么现在的问题就转变为一个求中心和距离的问题了。

image.png

K-Means 算法

其实 K-Means 算法也是非常简单的一种算法,我们可以把它概括为两步

  1. 初始化

随机选择 K 个点,作为初始中心,每个点代表一个 group

  1. 迭代更新

1)计算每个点到所有中心点的距离,把最近的距离记录下来并把 group 赋给当前的点

2)针对每一个 group 里的点,计算其平均并把该值作为 group 的新的中心点

3)重复上面的两步,直到 group 不再变化为止,或者到达设置的最大迭代次数。

下面我们使用一组图片来具体体会下这个算法的过程

image.png

图片 a:有初始状态如图中的样本,现在随机选择两个初始点

图片 b:依次计算所有样本点到两个初始点的距离,根据距离的大小,把样本点分配到不同的中心点上。

图片 c:再把左上部分的点计算平均值,得到左上部分的新的中心点,右下类同,可以得到新的中心点,这就是一个迭代后的聚类情况

图片 d:开始第二次迭代,再把所有的点依次计算到中心点的距离,可以得到新的聚类情况

图片 e:再次计算新的聚类中点的平均值,得出新的中心点

图片 f:重复 d 的过程,中心点不再变化

图片 i:聚类完成


K-Means 算法实现


首先我们先导入下准备好的规整的数据

from copy import deepcopy
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
k_means_data = pd.read_csv('k-meansdata.csv')
print(k_means_data.shape)
print(k_means_data.head())
>>>
(3000, 2)
          V1         V2
0   2.072345  -3.241693
1  17.936710  15.784810
2   1.083576   7.319176
3  11.120670  14.406780
4  23.711550   2.557729

下面再把我们准备的数据可视化出来,整体观察下

data1 = k_means_data['V1'].values
data2 = k_means_data['V2'].values
X = np.array(list(zip(data1, data2)))
plt.scatter(X[:,0], X[:,1], s=6)


image.png

可以清楚的看出,当前的数据可以大致聚类成3类,那么下面我们就手写一个 k-Means 算法来完成这个聚类的过程

def self_kmeans(data, k):
    m, n = data.shape
    results = np.empty(m)
    cores = np.copy(data[np.random.randint(0, m, size=k)])
    coreChange = True
    while coreChange:
        for i in range(m):
            distance = np.linalg.norm(data[i] - cores, axis=1)
            result = np.argmin(distance)
            results[i] = result
        cores_old = deepcopy(cores)
        for i in range(k):
            points = [data[j] for j in range(m) if results[j] == i]
            cores[i] = np.mean(points, axis=0)        if (cores_old == cores).all():
            return results, cores

我们来逐行看下代码

我们定义了一个 self_kmeans 的函数,接收两个参数,第一个是样本数据,是矩阵的形式;第二个是需要聚类的数量 k

接着我们通过 shape 属性获取到矩阵数据的行和列的数值,再定义聚类结果 results,用于存储最终的聚类结果。

对于中心点 cores,是通过 numpy 的 randint 函数来做随即处理的。

接下来就就如 while 循环了,首先遍历所有的样本( for i in range(m)),并计算样本与中心点的距离,通过 norm 范数来计算。

所谓的范数,是线性代数领域的概念,是具有长度概念的函数,而向量二范数就是向量的长度。

接下来,通过计算距离,把距离最短的归为一类,再把当前的中心点保存下来。

下面更新中心点,遍历 K 值,求出每个类别里所有点的均值,作为新的中心点。

最后的 if 判断代码是用来退出 while 循环的,如果中心点不再改变,再退出函数,返回类别和中心点。

下面我们运行下函数,看看结果如何

result, core = self_kmeans(X, 3)
colors = ['r', 'g', 'b', 'y', 'c', 'm']
fig, ax = plt.subplots()
for i in range(3):
        points = np.array([X[j] for j in range(len(X)) if result[j] == i])
        ax.scatter(points[:, 0], points[:, 1], s=7, c=colors[i])
ax.scatter(cores[:, 0], cores[:, 1], marker='*', s=200, c='#050505')


image.png

可以看到,还是比较好的聚类区分出了三类数据点,当然这里我们准备的数据样本比较简单,如果我们使用随机的样本点呢,聚类结果会是怎么样的呢

我们把样本点 X 替换一下

X = np.random.random((200, 2))*10


image.png

对于这种完全随机的样本点,我们再带入 K-Means 函数,查看结果

第一次运行


image.png

第二次运行


image.png

第三次运行


image.png

可以明显的看出,每次运行的结果,都会有微小的差别,这说明不同的初始数据,最后产生的聚类结果也不尽相同。


k-means的几个问题


不同的初始化数据,是否会产生不同的结果:

就如同我们上面演示所示,不同的初始化数据,最终的聚类结果也不相同。

那么该如何更好的选择初始化数据呢,一般可以采用 k-means++ 的方式来处理

k-means++

它的工作流程大致如下:

1、从输入的数据点集合中随机选择一个点作为第一个聚类中心

2、对于数据集中的每一个点 x,计算它与最近聚类中心(指已选择的聚类中心)的距离 D(x)

3、选择一个新的数据点作为新的聚类中心,选择的原则是:D(x) 较大的点,被选取作为聚类中心的概率较大

4、重复2和3直到 k 个聚类中心被选出来

5、利用这 k 个初始的聚类中心来运行标准的 k-means 算法

我们可以看到,通过 k-means++ 的逻辑,可以很好的保证我们选择的初始数据点不会过于集中,这样就能保证更好的聚类效果,也能减少计算量。

K 值该如何选择:

对于 K 值的选择,主要使用手肘法

手肘法

手肘法的评判指标是误差平方和

image.png

其核心思想是随着聚类 K 的增大,样本划分会更加精细,每个类别的聚合程度会逐渐提高,那么误差平方和就会逐渐变小。当 K 小于真实的聚类数时,SS 的下降幅度会很大,而当 K 大于真实的聚类数时,SS 的下降幅度会变的很小,如果画成一条曲线,会类似于一个手肘的形状,故而称为手肘法。

类似图表如下

image.png

可以看到从 k = 4 开始曲线的变化逐渐变缓,所以 k = 4 应该是最佳的 k 值

对于手肘法和 k-means++ 的具体使用,我们会在下一节详细讲解。


优缺点


优点

算法简单,容易理解,而且适用于高维数据

缺点

对离群点、噪声点和孤立点很敏感,初始点的选择好坏,决定了聚类结果的好坏


总结


今天讲解了 K-Means 算法的原理和算法模型,K-Means 的两个核心点就是设置初始数据和迭代更新中心点。

接下来我们又手动实现了一个简单的 K-Means 算法,通过不同的数据验证,我们还发现不同的初始值,不会产生不同的聚类结果,所以初始值的选择是非常重要的。

对于初始值的选择,我们一般采用 k-means++ 的方式来处理,尽量使得 K 个初始点之间的距离最大;而对于 K 值的选择,可以使用手肘法,通过观察曲线的拐点来选择最佳的 K 值。

本文涉及的代码和数据可以在这里下载

https://github.com/zhouwei713/DataAnalyse/tree/master/K-Means


微信图片_20220520181940.png


练习题


今天讲的 K-Means 和前面学习的 K-NN 有什么异同呢?

相关文章
|
7天前
|
SQL 数据挖掘 Python
R中单细胞RNA-seq数据分析教程 (1)
R中单细胞RNA-seq数据分析教程 (1)
27 5
R中单细胞RNA-seq数据分析教程 (1)
|
1月前
|
数据可视化 数据挖掘 大数据
Python 数据分析入门:从零开始处理数据集
Python 数据分析入门:从零开始处理数据集
|
22天前
|
数据采集 机器学习/深度学习 数据可视化
深入浅出:用Python进行数据分析的入门指南
【10月更文挑战第21天】 在信息爆炸的时代,掌握数据分析技能就像拥有一把钥匙,能够解锁隐藏在庞大数据集背后的秘密。本文将引导你通过Python语言,学习如何从零开始进行数据分析。我们将一起探索数据的收集、处理、分析和可视化等步骤,并最终学会如何利用数据讲故事。无论你是编程新手还是希望提升数据分析能力的专业人士,这篇文章都将为你提供一条清晰的学习路径。
|
1月前
|
数据挖掘 索引 Python
Python数据分析篇--NumPy--入门
Python数据分析篇--NumPy--入门
33 0
|
1月前
|
机器学习/深度学习 数据采集 数据可视化
Python中的简单数据分析:入门指南
【10月更文挑战第2天】Python中的简单数据分析:入门指南
33 0
|
3月前
|
数据采集 数据可视化 数据挖掘
数据分析大神养成记:Python+Pandas+Matplotlib助你飞跃!
在数字化时代,数据分析至关重要,而Python凭借其强大的数据处理能力和丰富的库支持,已成为该领域的首选工具。Python作为基石,提供简洁语法和全面功能,适用于从数据预处理到高级分析的各种任务。Pandas库则像是神兵利器,其DataFrame结构让表格型数据的处理变得简单高效,支持数据的增删改查及复杂变换。配合Matplotlib这一数据可视化的魔法棒,能以直观图表展现数据分析结果。掌握这三大神器,你也能成为数据分析领域的高手!
79 2
|
3月前
|
机器学习/深度学习 数据采集 数据可视化
基于爬虫和机器学习的招聘数据分析与可视化系统,python django框架,前端bootstrap,机器学习有八种带有可视化大屏和后台
本文介绍了一个基于Python Django框架和Bootstrap前端技术,集成了机器学习算法和数据可视化的招聘数据分析与可视化系统,该系统通过爬虫技术获取职位信息,并使用多种机器学习模型进行薪资预测、职位匹配和趋势分析,提供了一个直观的可视化大屏和后台管理系统,以优化招聘策略并提升决策质量。
178 4
|
3月前
|
机器学习/深度学习 算法 数据挖掘
2023 年第二届钉钉杯大学生大数据挑战赛初赛 初赛 A:智能手机用户监测数据分析 问题二分类与回归问题Python代码分析
本文介绍了2023年第二届钉钉杯大学生大数据挑战赛初赛A题的Python代码分析,涉及智能手机用户监测数据分析中的聚类分析和APP使用情况的分类与回归问题。
84 0
2023 年第二届钉钉杯大学生大数据挑战赛初赛 初赛 A:智能手机用户监测数据分析 问题二分类与回归问题Python代码分析
|
10天前
|
SQL 数据挖掘 Python
数据分析编程:SQL,Python or SPL?
数据分析编程用什么,SQL、python or SPL?话不多说,直接上代码,对比明显,明眼人一看就明了:本案例涵盖五个数据分析任务:1) 计算用户会话次数;2) 球员连续得分分析;3) 连续三天活跃用户数统计;4) 新用户次日留存率计算;5) 股价涨跌幅分析。每个任务基于相应数据表进行处理和计算。
|
1月前
|
机器学习/深度学习 数据采集 数据可视化
数据分析之旅:用Python探索世界
数据分析之旅:用Python探索世界
28 2