1 聚类算法说明
1.1 1. 引入
聚类算法一种无监督的(unsupervised)经典机器学习算法。先从感性上认识一下什么是聚类。聚 类的核心思想就是将具有相似特征的事物给 “聚” 在一起,也就是说 “聚” 是一个动词。俗话说 人以群分,物以类聚说得就是这个道理。如图所示为三种类型的数据样本,其中每种颜色都表示一个类别。而聚类算法的目的就是就是将各 个类别的样本点分开,也就是将同一种类别的样本点聚在一起。聚类的目的就是将相似度较高的样 本点放到一个簇中。
1.2 2.Kmeans 聚类
聚类算法中最常用 Kmeans 聚类算法。
如图左右所示为同一数据集的两种不同聚类结果,其中同种颜色表示聚类后被划分到了同一个簇 中,黑色圆点为聚类后的簇中心。从可视化结果来看,左图的聚类结果跟定好于右图的聚类结果。也就是说,我们可以通过最小化目标函数 d=d1+d2+d3+...+d10 来得到最优解。
1.2.1 原理
Kmeans 聚类算法也被称为 Kmeans 均值聚类,其主要原理为:① 首先随机选择 Kmeans 个样本点作为 Kmeans 个簇的初始簇中心;② 然后计算每个样本点与这个 Kmeans 个簇中心的相似度大小,并将 该样本点划分到与之相似度 最大的簇中心所对应的簇中; ③ 根据现有的簇中样本,重新计算每个簇的簇中心; ④ 循环迭代步骤 ②③,直到目标函数收敛,即簇中心不再发生变化。
如图所示为一个聚类过程中的示例,左上角为正确标签下的样本可视化结果(每种颜色表示一个类 别),其中三个黑色圆点为随机初始化的三个簇中心;当 iter=1 时表算法第一次迭代后的结果,可 以看到此时的算法将左边的两个簇都划分到了一个簇中,而右下角的一个簇被分成了两个簇;然后 依次进行反复迭代,当第四次迭代完成后,可以发现三个簇中心基本上已经位于三个簇中了,被错 分的样本也在逐渐减少;当进行完第五次迭代后,可以发现基本上已经完成了对整个样本的聚类处 理,只需要再迭代几次即可收敛。以上就是 Kmeans 聚类算法在整个聚类过程中的变化情况。
1.2.2 k 值选择
经过上面的介绍,我们已经知道了 Kmeans 聚类算法的基本原理。但现在有个问题就是,我们怎么 来确定聚类的 Kmeans 值呢?也就是说我们需要将数据集聚成多少个簇?如果已经很明确数据集 中存在多少个簇,那么就直接指定 Kmeans 值即可;如果并不知道数据集中有多少个簇,则需要结 合另外一些办法来进行选取,例如看轮廓系数、结果的稳定性等等。预设不同的簇数分类结果原始数据:
分类结果:
结论:预设 3 簇
结论:预设 4 簇
结论:预设 5 簇结论:预设 4 簇的时候其平均轮廓系数最高,所以分 4 簇是最优的,与数据集相 匹配。轮廓系数定义为:
1.2.3 sklearn 建模
安装:conda install scikit-learn
导入:from sklearn.cluster import KMeans
: # 目标:通过客户信息为客户分类,确定哪些客户为有价值客户 import numpy as np import pandas as pd from sklearn.cluster import KMeans df=pd.read_csv('客户.csv',encoding='gb18030') df
[8]: 客户年龄 平均每次消费金额 平均消费次数 0 23 317 10 1 22 147 13 2 24 172 17 3 27 194 67 4 37 789 35 5 25 190 1 6 29 281 10 7 27 142 12 8 28 186 8 9 23 226 1 10 22 287 32 11 32 499 3 12 25 181 90 13 26 172 1 14 24 190 16 15 27 271 31 16 40 382 25
x=df[['平均消费次数','平均每次消费金额']].values print(x) kms=KMeans(n_clusters=3) y=kms.fit_predict(x) print(y) df['价值']=y print(df)
[[ 10 317] [ 13 147] [ 17 172] [ 67 194] [ 35 789] [ 1 190] [ 10 281] [ 12 142] [ 8 186] [ 1 226] [ 32 287] [ 3 499] [ 90 181] [ 1 172] [ 16 190] [ 31 271] [ 25 382]] [2 0 0 0 1 0 2 0 0 0 2 2 0 0 0 2 2] 客户年龄 平均每次消费金额 平均消费次数 价值 0 23 317 10 2 1 22 147 13 0 2 24 172 17 0 3 27 194 67 0 4 37 789 35 1 5 25 190 1 0 6 29 281 10 2 7 27 142 12 0 8 28 186 8 0 9 23 226 1 0 10 22 287 32 2 11 32 499 3 2 12 25 181 90 0 13 26 172 1 0 14 24 190 16 0 15 27 271 31 2 16 40 382 25 2
# 绘图 import matplotlib.pyplot as plt plt.rcParams['font.sans-serif']=['SimHei'] plt.rcParams['axes.unicode_minus']=False plt.figure(figsize=(10,10)) plt.scatter(x[:,0],x[:,1],c=y,s=20) plt.title('三分类') plt.show()