SparkML机器学习之聚类(K-Means、GMM、LDA)

简介:

聚类的概念

聚类就是对大量未知标注(无监督)的数据集,按照数据之间的相似度,将N个对象的数据集划分为K个划分(K个簇),使类别内的数据相似度较大,而类别间的数据相似较小。比如用户画像就是一种很常见的聚类算法的应用场景,基于用户行为特征或者元数据将用户分成不同的类。

常见聚类以及原理

K-means算法

也被称为k-均值,是一种最广泛使用的聚类算法,也是其他聚类算法的基础。来看下它的原理:

既然要划分为k个簇,因此算法首先随机的选择了k个对象,每个对象初始的代表了一个簇的中心。其他的对象就去计算它们与这k个中心的距离(这里距离就是相似度),离哪个最近就把它们丢给哪个簇。第一轮聚类结束后,重新计算每个簇的平均值,将N中全部元素按照新的中心重新聚类。这个过程不断的反复,使得每一个改进之后的划分都比之前效果好,直到准则函数收敛(聚类结果不再变化)。
image

举个例子,10个男生我要分为2个类,我先随机选择2个男生a b,那么对于c,我分别计算他与a b 的相似度,假设和a最相似,那么我丢到a这个类里,经过第一轮,得到(1类){a,c,e,f,i}(2类) {b,d,g,h,j},现在我重新计算1类和2类的均值,假设为1.5和5.1,那么对这10个男生再判断它们与1.5 5.1的距离,和哪个相近就去哪边。第二轮迭代使得聚类更加精确。反复这样,直到某一次的迭代我发现不再有变化,那么就ok了。

但是K-means有些缺点,其一是受初始值影响较大。下面这张图很好的解释了这个缺点,人眼一看就能看出来,如果是分为4个聚类,应该这么分为左边那样的,如果用K-means结果会是右边这样,明显不对,所以说受初始值的影响会比较大。

image

因为这个缺陷,所以有了Bisecting k-means(二分K均值)

Bisecting k-means(二分K均值)

主要是为了改进k-means算法随机选择初始中心的随机性造成聚类结果不确定性的问题,而Bisecting k-means算法受随机选择初始中心的影响比较小。
先将所有点作为一个簇,然后将该簇一分为二。之后选择其中一个簇【具有最大SSE值的一个簇】继续进行划分,二分这个簇以后得到的2个子簇,选择2个子簇的总SSE最小的划分方法,这样能够保证每次二分得到的2个簇是比较优的(也可能是最优的)。
SSE(Sum of Squared Error),也就是误差平方和,它计算的是拟合数据和原始数据对应点的误差的平方和,它是用来度量聚类效果的一个指标。SSE越接近于0,说明模型选择和拟合更好,数据预测也越成功。

上面讲的都是硬聚类,硬聚类即一定是属于某一个类,比如我有2个簇A和B,那么所有的对象要不属于A要不就属于B,不可能会有第三种可能。而软聚类,使用概率的方式,一个对象可以是60%属于A,40% 属于B,即不完全属于同一个分布,而是以不同的概率分属于不同的分布GMM(高斯混合模型)就是一种软聚类。

GMM(高斯混合模型)

它和K-Means的区别是,K-Means是算出每个数据点所属的簇,而GMM是计算出这些数据点分配到各个类别的概率
GMM算法步骤如下:
1.猜测有 K 个类别、即有K个高斯分布。
2.对每一个高斯分布赋均值 μ 和方差 Σ 。
3.对每一个样本,计算其在各个高斯分布下的概率。
image.png

4.每一个样本对某高斯分布的贡献可以由其下的概率表示。并把该样本对该高斯分布的贡献作为权重来计算加权的均值和方差以替代其原本的均值和方差。
5.重复3~4直到每一个高斯分布的均值和方差收敛。

SparkML聚类

SparkML中主要聚类有以下几种:

  • K-means
  • Latent Dirichlet allocation (LDA)
  • Bisecting k-means
  • Gaussian Mixture Model (GMM)

KMeans

package ml.test
import org.apache.spark.ml.clustering.KMeans
import org.apache.spark.sql.SparkSession
/**
  * Created by liuyanling on 2018/3/24   
  */
object KMeansDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").getOrCreate()
    val df = spark.read.format("libsvm").load("kmeans_data.txt")
    //setK设置要分为几个类 setSeed设置随机种子
    val kmeans = new KMeans().setK(3).setSeed(1L)
    //聚类模型
    val model = kmeans.fit(df)
    // 预测 即分配聚类中心
    model.transform(df).show(false)
    //聚类中心
    model.clusterCenters.foreach(println)
    // SSE误差平方和
    println("SSE:"+model.computeCost(df))
  }
}

输出结果为:

+-----+-------------------------+----------+
|label|features                 |prediction|
+-----+-------------------------+----------+
|0.0  |(3,[],[])                |1         |
|1.0  |(3,[0,1,2],[0.1,0.1,0.1])|1         |
|2.0  |(3,[0,1,2],[0.2,0.2,0.2])|2         |
|3.0  |(3,[0,1,2],[9.0,9.0,9.0])|0         |
|4.0  |(3,[0,1,2],[9.1,9.1,9.1])|0         |
|5.0  |(3,[0,1,2],[9.2,9.2,9.2])|0         |
+-----+-------------------------+----------+

[9.1,9.1,9.1]
[0.05,0.05,0.05]
[0.2,0.2,0.2]

SSE:0.07499999999994544

附:kmeans_data.txt

0 1:0.0 2:0.0 3:0.0
1 1:0.1 2:0.1 3:0.1
2 1:0.2 2:0.2 3:0.2
3 1:9.0 2:9.0 3:9.0
4 1:9.1 2:9.1 3:9.1
5 1:9.2 2:9.2 3:9.2

BisectingKMeans二分K均值

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").getOrCreate()
    val df = spark.read.format("libsvm").load("kmeans_data.txt")
    val kmeans = new BisectingKMeans().setK(3).setSeed(1)
    //聚类模型
    val model = kmeans.fit(df)
    // 预测 即分配聚类中心
    model.transform(df).show(false)
    //聚类中心
    model.clusterCenters.foreach(println)
  }

输出结果为:

+-----+-------------------------+----------+
|label|features                 |prediction|
+-----+-------------------------+----------+
|0.0  |(3,[],[])                |0         |
|1.0  |(3,[0,1,2],[0.1,0.1,0.1])|0         |
|2.0  |(3,[0,1,2],[0.2,0.2,0.2])|0         |
|3.0  |(3,[0,1,2],[9.0,9.0,9.0])|1         |
|4.0  |(3,[0,1,2],[9.1,9.1,9.1])|1         |
|5.0  |(3,[0,1,2],[9.2,9.2,9.2])|2         |
+-----+-------------------------+----------+

[0.1,0.1,0.1]
[9.05,9.05,9.05]
[9.2,9.2,9.2]

可以发现,使用kmeans和BisectingKMeans,聚类结果一般是不一样的。

GMM高斯混合模型

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").getOrCreate()
    val df = spark.read.format("libsvm").load("kmeans_data.txt")
    val gmm = new GaussianMixture().setK(3).setSeed(0)
    val model = gmm.fit(df)
    for (i <- 0 until model.getK) {
      //weight是各组成成分的权重
      //cov是样本协方差矩阵
      //mean是均值
      println(s"Gaussian $i:\nweight=${model.weights(i)}\n" +
        s"mu=${model.gaussians(i).mean}\nsigma=\n${model.gaussians(i).cov}\n")
    }
    //也可以使用这种方式输出:model.gaussiansDF.show(false)
  }

输出结果为:

Gaussian 0:
weight=0.28757327891867634
mu=[0.09999633894391767,0.09999633894391767,0.09999633894391767]
sigma=
0.006666589009587926  0.006666589009587926  0.006666589009587926  
0.006666589009587926  0.006666589009587926  0.006666589009587926  
0.006666589009587926  0.006666589009587926  0.006666589009587926  

Gaussian 1:
weight=0.2124267210813245
mu=[0.1000049561651758,0.1000049561651758,0.1000049561651758]
sigma=
0.006666771753107945  0.006666771753107945  0.006666771753107945  
0.006666771753107945  0.006666771753107945  0.006666771753107945  
0.006666771753107945  0.006666771753107945  0.006666771753107945  

Gaussian 2:
weight=0.49999999999999917
mu=[9.099999999999984,9.099999999999984,9.099999999999984]
sigma=
0.006666666666831146  0.006666666666831146  0.006666666666831146  
0.006666666666831146  0.006666666666831146  0.006666666666831146  
0.006666666666831146  0.006666666666831146  0.006666666666831146  

LDA主题模型

LDA是一个三层贝叶斯概率模型,包含词、主题和文档三层结构。
LDA可以用来生成一篇文档,生成时,每个词都是通过“以一定概率选择了某个主题,并从这个主题中以一定概率选择某个词语”,这样反复进行,就可以生成一篇文档;反过来,LDA又是一种非监督机器学习技术,可以识别出大规模文档集或语料库中的主题。

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").getOrCreate()
    val df = spark.read.format("libsvm").load("lda_data.txt")
    //训练LDA模型、这一堆文本里我需要有10个topic
    val lda = new LDA().setK(10).setMaxIter(10)
    val model = lda.fit(df)
    //计算整个语料库对数可能性的下限
    val ll = model.logLikelihood(df)
    //Perplexity(困惑度)的上限,用于评估LDA主题模型好坏,判断改进的参数或者算法的建模能力。
    //通过观察Perplexity指标随topic个数的变化能够帮助我们选择合适的topic个数值,越低越好
    val lp = model.logPerplexity(df)
    println(s"Likelihood: $ll")
    println(s"Perplexity: $lp")
    // 输出主题,参数4是指定每个topic返回的词数量(已经按照权重降序排列)
    // 即这四个词与主题最相关,对主题的贡献程度最大,贡献程度分别为。。。。
    val topics = model.describeTopics(4)
    println("输出与主题最相关的四个词以及它们的权重(对主题的贡献程度):")
    topics.show(false)
    // Shows the result
    val transformed = model.transform(df)
    transformed.show(false)
  }

输出结果:

Likelihood: -841.0539703557463
Perplexity: 3.235555050621537
输出与主题最相关的四个词以及它们的权重(对主题的贡献程度):
+-----+-------------+------------------------------------------------------------------------------------+
|topic|termIndices  |termWeights                                                                         |
+-----+-------------+------------------------------------------------------------------------------------+
|0    |[2, 5, 7, 9] |[0.10606440859601529, 0.10570106168080187, 0.10430389617431601, 0.09677466095389772]|
|1    |[1, 6, 2, 5] |[0.10185076997491191, 0.09816928141878781, 0.09632454354036399, 0.09533709162736335]|
|2    |[10, 6, 9, 1]|[0.21830191650133673, 0.13864436129454022, 0.130631061595757, 0.12280252973166123]  |
|3    |[0, 4, 8, 5] |[0.10270701955806716, 0.098428481533562, 0.09815661242071609, 0.09625859107744991]  |
|4    |[9, 6, 4, 0] |[0.10452964428601186, 0.10414908178146716, 0.10103987045642693, 0.09653933325158909]|
|5    |[1, 10, 0, 6]|[0.10214945376665654, 0.10129060012341293, 0.09513643667808531, 0.09484723303591766]|
|6    |[3, 7, 4, 5] |[0.11638316687887292, 0.09901763170594163, 0.09795372072037434, 0.09538797685003378]|
|7    |[4, 0, 2, 7] |[0.1085545365386738, 0.10334275138802261, 0.10034943368678806, 0.09586142922666488] |
|8    |[0, 7, 8, 9] |[0.11008008210214115, 0.09919723498734867, 0.09810902425212233, 0.09598429155133426]|
|9    |[9, 6, 8, 7] |[0.10106110089499898, 0.10013295826865794, 0.09769277851352344, 0.09637374368101154]|
+-----+-------------+------------------------------------------------------------------------------------+

+-----+---------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|label|features                                                       |topicDistribution                                                                                                                                                                                                       |
+-----+---------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|0.0  |(11,[0,1,2,4,5,6,7,10],[1.0,2.0,6.0,2.0,3.0,1.0,1.0,3.0])      |[0.7020492743453446,0.004825993770645003,0.2593430279085318,0.004825963614085213,0.0048259356594697244,0.004825986805023069,0.004825959562158929,0.004826029371171158,0.004825898616321139,0.004825930347249381]        |
|1.0  |(11,[0,1,3,4,7,10],[1.0,3.0,1.0,3.0,2.0,1.0])                  |[0.008051227952034095,0.00805068078020596,0.9275409153537755,0.008051404074962873,0.008050494944341442,0.008050970211048909,0.008051192103744834,0.008051174355007377,0.008051123052877147,0.008050817172001883]        |
|2.0  |(11,[0,1,2,5,6,8,9],[1.0,4.0,1.0,4.0,9.0,1.0,2.0])             |[0.004196160998853775,0.004196351498154125,0.9622355518772447,0.0041960987674483745,0.004195948652190585,0.004195985718066058,0.004196035128639296,0.0041960011176528375,0.00419585892325802,0.004196007318492326]      |
|3.0  |(11,[0,1,3,6,8,9,10],[2.0,1.0,3.0,5.0,2.0,3.0,9.0])            |[0.0037108249083737366,0.003710870004756753,0.9666019491390152,0.0037108898639785916,0.003710932316515782,0.0037109398110594777,0.003710889692814606,0.0037108416716129995,0.0037109726803076204,0.0037108899115651296] |
|4.0  |(11,[0,1,2,3,4,6,9,10],[3.0,1.0,1.0,9.0,3.0,2.0,1.0,3.0])      |[0.0040206150698117475,0.004020657148286404,0.9638139184735203,0.004020683139979867,0.004020734091778302,0.004020674583024828,0.004020782022237563,0.0040206568444815,0.0040206713675701175,0.00402060725930927]        |
|5.0  |(11,[0,1,3,4,5,6,7,8,9],[4.0,2.0,3.0,4.0,5.0,1.0,1.0,1.0,4.0]) |[0.0037112930006027146,0.0037112713017637172,0.3775560310450831,0.5927537995162975,0.003711281238052303,0.003711254221319455,0.0037112951793086854,0.0037112986549696415,0.0037112620387514264,0.003711213803851488]    |
|6.0  |(11,[0,1,3,6,8,9,10],[2.0,1.0,3.0,5.0,2.0,2.0,9.0])            |[0.0038593994507872802,0.0038594504955414776,0.9652648960492173,0.0038594706691946765,0.0038594932727853293,0.0038594785604599284,0.0038594622091066354,0.003859418239947566,0.0038594708043909586,0.003859460248568964]|
|7.0  |(11,[0,1,2,3,4,5,6,9,10],[1.0,1.0,1.0,9.0,2.0,1.0,2.0,1.0,3.0])|[0.0043866961466148225,0.004386703747213623,0.9605195811989302,0.004386713724658389,0.004386726334830402,0.004386684736152265,0.00438695465308044,0.004386652225319148,0.0043866443293358506,0.0043866429038647205]     |
|8.0  |(11,[0,1,3,4,5,6,7],[4.0,4.0,3.0,4.0,2.0,1.0,3.0])             |[0.00438677418165635,0.004386828345825393,0.004928678499127466,0.9599768874208335,0.004386777636200246,0.0043868515295607344,0.004386859902729271,0.004386803952404989,0.004386800659948354,0.004386737871713616]       |
|9.0  |(11,[0,1,2,4,6,8,9,10],[2.0,8.0,2.0,3.0,2.0,2.0,7.0,2.0])      |[0.00332681244951176,0.003326822533688302,0.9700586410381743,0.0033268123008501223,0.0033268596356941897,0.003326828624565359,0.0033267707086431955,0.003326842280807512,0.0033268054589628416,0.003326804969102545]    |
|10.0 |(11,[0,1,2,3,5,6,9,10],[1.0,1.0,1.0,9.0,2.0,2.0,3.0,3.0])      |[0.00419586618191812,0.004195849140887235,0.9622374586787058,0.0041958390363593,0.00419580989169427,0.0041958000694942875,0.004196080417873127,0.004195750742202915,0.004195747737696082,0.004195798103168887]          |
|11.0 |(11,[0,1,4,5,6,7,9],[4.0,1.0,4.0,5.0,1.0,3.0,1.0])             |[0.004826026955676383,0.0048259715597386765,0.005421055835370809,0.9559710760555791,0.004825952242828419,0.004825974033041617,0.0048259715984279505,0.004826101409594881,0.004825983107633105,0.004825887202109255]     |
+-----+---------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+

附:lda_data.txt

0 1:1 2:2 3:6 4:0 5:2 6:3 7:1 8:1 9:0 10:0 11:3
1 1:1 2:3 3:0 4:1 5:3 6:0 7:0 8:2 9:0 10:0 11:1
2 1:1 2:4 3:1 4:0 5:0 6:4 7:9 8:0 9:1 10:2 11:0
3 1:2 2:1 3:0 4:3 5:0 6:0 7:5 8:0 9:2 10:3 11:9
4 1:3 2:1 3:1 4:9 5:3 6:0 7:2 8:0 9:0 10:1 11:3
5 1:4 2:2 3:0 4:3 5:4 6:5 7:1 8:1 9:1 10:4 11:0
6 1:2 2:1 3:0 4:3 5:0 6:0 7:5 8:0 9:2 10:2 11:9
7 1:1 2:1 3:1 4:9 5:2 6:1 7:2 8:0 9:0 10:1 11:3
8 1:4 2:4 3:0 4:3 5:4 6:2 7:1 8:3 9:0 10:0 11:0
9 1:2 2:8 3:2 4:0 5:3 6:0 7:2 8:0 9:2 10:7 11:2
10 1:1 2:1 3:1 4:9 5:0 6:2 7:2 8:0 9:0 10:3 11:3
11 1:4 2:1 3:0 4:0 5:4 6:5 7:1 8:3 9:0 10:1 11:0
相关文章
|
2月前
|
机器学习/深度学习 算法 数据挖掘
K-means聚类算法是机器学习中常用的一种聚类方法,通过将数据集划分为K个簇来简化数据结构
K-means聚类算法是机器学习中常用的一种聚类方法,通过将数据集划分为K个簇来简化数据结构。本文介绍了K-means算法的基本原理,包括初始化、数据点分配与簇中心更新等步骤,以及如何在Python中实现该算法,最后讨论了其优缺点及应用场景。
117 4
|
3月前
|
机器学习/深度学习 算法 数据可视化
机器学习的核心功能:分类、回归、聚类与降维
机器学习领域的基本功能类型通常按照学习模式、预测目标和算法适用性来分类。这些类型包括监督学习、无监督学习、半监督学习和强化学习。
69 0
|
5月前
|
机器学习/深度学习 算法 数据中心
【机器学习】面试问答:PCA算法介绍?PCA算法过程?PCA为什么要中心化处理?PCA为什么要做正交变化?PCA与线性判别分析LDA降维的区别?
本文介绍了主成分分析(PCA)算法,包括PCA的基本概念、算法过程、中心化处理的必要性、正交变换的目的,以及PCA与线性判别分析(LDA)在降维上的区别。
109 4
|
5月前
|
机器学习/深度学习 数据采集 算法
【机器学习】K-Means聚类的执行过程?优缺点?有哪些改进的模型?
K-Means聚类的执行过程、优缺点,以及改进模型,包括K-Means++和ISODATA算法,旨在解决传统K-Means算法在确定初始K值、收敛到局部最优和对噪声敏感等问题上的局限性。
73 2
|
5月前
|
机器学习/深度学习 算法 数据挖掘
【机器学习】聚类算法中的距离度量有哪些及公式表示?
聚类算法中常用的距离度量方法及其数学表达式,包括欧式距离、曼哈顿距离、切比雪夫距离、闵可夫斯基距离、余弦相似度等多种距离和相似度计算方式。
467 1
|
5月前
|
机器学习/深度学习 算法 数据挖掘
【机器学习】Python详细实现基于欧式Euclidean、切比雪夫Chebyshew、曼哈顿Manhattan距离的Kmeans聚类
文章详细实现了基于不同距离度量(欧氏、切比雪夫、曼哈顿)的Kmeans聚类算法,并提供了Python代码,展示了使用曼哈顿距离计算距离矩阵并输出k=3时的聚类结果和轮廓系数评价指标。
112 1
|
6月前
|
机器学习/深度学习 算法 数据可视化
Fisher模型在统计学和机器学习领域通常指的是Fisher线性判别分析(Fisher's Linear Discriminant Analysis,简称LDA)
Fisher模型在统计学和机器学习领域通常指的是Fisher线性判别分析(Fisher's Linear Discriminant Analysis,简称LDA)
|
5月前
|
机器学习/深度学习 数据可视化 搜索推荐
【python机器学习】python电商数据K-Means聚类分析可视化(源码+数据集+报告)【独一无二】
【python机器学习】python电商数据K-Means聚类分析可视化(源码+数据集+报告)【独一无二】
228 0
|
7月前
|
机器学习/深度学习 分布式计算 算法
在机器学习项目中,选择算法涉及问题类型识别(如回归、分类、聚类、强化学习)
【6月更文挑战第28天】在机器学习项目中,选择算法涉及问题类型识别(如回归、分类、聚类、强化学习)、数据规模与特性(大数据可能适合分布式算法或深度学习)、性能需求(准确性、速度、可解释性)、资源限制(计算与内存)、领域知识应用以及实验验证(交叉验证、模型比较)。迭代过程包括数据探索、模型构建、评估和优化,结合业务需求进行决策。
64 0
|
7月前
|
机器学习/深度学习 算法 数据可视化
技术心得记录:机器学习笔记之聚类算法层次聚类HierarchicalClustering
技术心得记录:机器学习笔记之聚类算法层次聚类HierarchicalClustering
69 0