Spark MLlib中KMeans聚类算法的解析和应用

本文涉及的产品
EMR Serverless StarRocks,5000CU*H 48000GB*H
简介: 聚类算法是机器学习中的一种无监督学习算法,它在数据科学领域应用场景很广泛,比如基于用户购买行为、兴趣等来构建推荐系统。

本文转自公众号:大数据学习与分享
原文链接


聚类算法是机器学习中的一种无监督学习算法,它在数据科学领域应用场景很广泛,比如基于用户购买行为、兴趣等来构建推荐系统。

核心思想可以理解为,在给定的数据集中(数据集中的每个元素有可被观察的n个属性),使用聚类算法将数据集划分为k个子集,并且要求每个子集内部的元素之间的差异度尽可能低,而不同子集元素的差异度尽可能高。简而言之,就是通过聚类算法处理给定的数据集,将具有相同或类似的属性(特征)的数据划分为一组,并且不同组之间的属性相差会比较大。

K-Means算法是聚类算法中应用比较广泛的一种聚类算法,比较容易理解且易于实现。

"标准" K-Means算法

KMeans算法的基本思想是随机给定K个初始簇中心,按照最邻近原则把待分类样本点分到各个簇。然后按平均法重新计算各个簇的质心,从而确定新的簇心。一直迭代,直到簇心的移动距离小于某个给定的值或者满足已定条件。主要分为4个步骤:

1. 为要聚类的点寻找聚类中心,比如随机选择K个点作为初始聚类中心
2. 计算每个点到聚类中心的距离,将每个点划分到离该点最近的聚类中去
3. 计算每个聚类中所有点的坐标平均值,并将这个平均值作为新的聚类中心
4. 反复执行第2步和第3步,直到聚类中心不再改变或者聚类次数达到设定迭代上限或者达到指定的容错范围

示例图:

image.png

KMeans算法在做聚类分析的过程中主要有两个难题:初始聚类中心的选择和聚类个数K的选择。

Spark MLlib对KMeans的实现分析

Spark MLlib针对"标准"KMeans的问题,在实现自己的KMeans上主要做了如下核心优化:

1. 选择合适的初始中心点

Spark MLlib在初始中心点的选择上,有两种算法:

随机选择:依据给的种子seed,随机选择K个随机中心点
k-means||:默认的算法

val RANDOM = "random"
val K_MEANS_PARALLEL = "k-means||"

2. 计算样本属于哪一个中心点时对距离计算的优化

假设中心点是(a1,b1),要计算的点是(a2,b2),那么Spark MLlib采取的计算方法是(记为lowerBoundOfSqDist):

image.png

对比欧几里得距离(记为EuclideanDist):
image.png

可轻易证明lowerBoundOfSqDist小于或等于EuclideanDist,并且计算lowerBoundOfSqDist很方便,只需处理中心点和要计算的点的L2范数。那么在实际处理中就分两种情况:

  • 当lowerBoundOfSqDist大于"最近距离"(之前计算好的,记为bestdistance),那么可以推导欧式距离也大于bestdistance,不需要计算欧式距离,省去了很多计算工作
  • 当lowerBoundOfSqDist小于bestdistance,则会调用fastSquaredDistance进行距离的快速计算

关于fastSquaredDistance:

首先计算一个精度:
val precisionBound1 = 2.0 * EPSILON * sumSquaredNorm / (normDiff * normDiff + EPSILON)
if (precisionBound1 < precision) {
   // 精度满足squared distance期望的精度
   // val sumSquaredNorm = norm1 * norm1 + norm2 * norm2
   // 2.0 * dot(v1, v2)为2(a1*a2 + b1*b2)可以利用之前计算的L2范数
   sqDist = sumSquaredNorm - 2.0 * dot(v1, v2)
} else if (v1.isInstanceOf[SparseVector] || v2.isInstanceOf[SparseVector]) {
   val dotValue = dot(v1, v2)
   sqDist = math.max(sumSquaredNorm - 2.0 * dotValue, 0.0)
   val precisionBound2 = EPSILON * (sumSquaredNorm + 2.0 * math.abs(dotValue)) /
     (sqDist + EPSILON)
   if (precisionBound2 > precision) {
      sqDist = Vectors.sqdist(v1, v2)
   }
} else {
  sqDist = Vectors.sqdist(v1, v2)
}
//精度不满足要求时,则进行Vectors.sqdist(v1, v2)的处理,即原始的距离计算

Spark MLlib中KMeans相关源码分析

基于mllib包下的KMeans相关源码涉及的类和方法(ml包下与下面略有不同,比如涉及到的fit方法):

  1. KMeans类和伴生对象
  2. train方法:根据设置的KMeans聚类参数,构建KMeans聚类,并执行run方法进行训练
  3. run方法:主要调用runAlgorithm方法进行聚类中心点等的核心计算,返回KMeansModel
  4. initialModel:可以直接设置KMeansModel作为初始化聚类中心选择,也支持随机和k-means || 生成中心点
  5. predict:预测样本属于哪个"类"
  6. computeCost:通过计算数据集中所有的点到最近中心点的平方和来衡量聚类效果。一般同样的迭代次数,cost值越小,说明聚类效果越好。

注意:该方法在Spark 2.4.X版本已经过时,并且会在Spark 3.0.0被移除,具体取代方法可以查看ClusteringEvaluator

主要看一下train和runAlgorithm的核心源码:
def train(
      // 数据样本
      data: RDD[Vector],
      // 聚类数量
      k: Int,
      // 最大迭代次数
      maxIterations: Int,
      // 初始化中心,支持"random"或者"k-means||"
      initializationMode: String,
      // 初始化时的随机种子
      seed: Long): KMeansModel = {
  new KMeans().setK(k)
      .setMaxIterations(maxIterations)
      .setInitializationMode(initializationMode)
      .setSeed(seed)
      .run(data)
}
/**
   * Implementation of K-Means algorithm.
   */
  private def runAlgorithm( data: RDD[VectorWithNorm],
      instr: Option[Instrumentation]): KMeansModel = {

    val sc = data.sparkContext

    val initStartTime = System.nanoTime()

    val distanceMeasureInstance = DistanceMeasure.decodeFromString(this.distanceMeasure)

    val centers = initialModel match {
      case Some(kMeansCenters) =>
        kMeansCenters.clusterCenters.map(new VectorWithNorm(_))
      case None =>
        if (initializationMode == KMeans.RANDOM) {
          // random
          initRandom(data)
        } else {
          // k-means||
          initKMeansParallel(data, distanceMeasureInstance)
        }
    }
    val initTimeInSeconds = (System.nanoTime() - initStartTime) / 1e9
    logInfo(f"Initialization with $initializationMode took $initTimeInSeconds%.3f seconds.")

    var converged = false
    var cost = 0.0
    var iteration = 0

    val iterationStartTime = System.nanoTime()

    instr.foreach(_.logNumFeatures(centers.head.vector.size))

    // Execute iterations of Lloyd's algorithm until converged
    // Kmeans迭代执行,计算每个样本属于哪个中心点,中心点累加的样本值以及计数。然后根据中心点的所有样本数据进行中心点的更新,并且比较更新前的数值,根据两者距离判断是否完成
    //迭代次数小于最大迭代次数,并行计算的中心点还没有收敛
    while (iteration < maxIterations && !converged) {
      // 损失值累加器
      val costAccum = sc.doubleAccumulator
      // 广播中心点
      val bcCenters = sc.broadcast(centers)

      // Find the new centers
      val collected = data.mapPartitions { points =>
        // 当前聚类中心
        val thisCenters = bcCenters.value
        // 中心点的维度
        val dims = thisCenters.head.vector.size

        val sums = Array.fill(thisCenters.length)(Vectors.zeros(dims))
        val counts = Array.fill(thisCenters.length)(0L)

        points.foreach { point =>
          // 通过当前的聚类中心点,找出最近的聚类中心点
          // findClosest是为了计算bestDistance,参考上述Spark对距离计算的优化
          val (bestCenter, cost) = distanceMeasureInstance.findClosest(thisCenters, point)
          costAccum.add(cost)
          distanceMeasureInstance.updateClusterSum(point, sums(bestCenter))
          counts(bestCenter) += 1
        }

        counts.indices.filter(counts(_) > 0).map(j => (j, (sums(j), counts(j)))).iterator
      }.reduceByKey { case ((sum1, count1), (sum2, count2)) =>
        axpy(1.0, sum2, sum1)
        (sum1, count1 + count2)
      }.collectAsMap()

      if (iteration == 0) {
        instr.foreach(_.logNumExamples(collected.values.map(_._2).sum))
      }

      val newCenters = collected.mapValues { case (sum, count) =>
        distanceMeasureInstance.centroid(sum, count)
      }

      bcCenters.destroy(blocking = false)

      // Update the cluster centers and costs
      converged = true
      newCenters.foreach { case (j, newCenter) =>
        if (converged &&
          !distanceMeasureInstance.isCenterConverged(centers(j), newCenter, epsilon)) {
          // 距离大于,则说明中心点位置改变
          converged = false
        }
        // 更新中心点
        centers(j) = newCenter
      }

      cost = costAccum.value
      iteration += 1
    }

    val iterationTimeInSeconds = (System.nanoTime() - iterationStartTime) / 1e9
    logInfo(f"Iterations took $iterationTimeInSeconds%.3f seconds.")

    if (iteration == maxIterations) {
      logInfo(s"KMeans reached the max number of iterations: $maxIterations.")
    } else {
      logInfo(s"KMeans converged in $iteration iterations.")
    }

    logInfo(s"The cost is $cost.")

    new KMeansModel(centers.map(_.vector), distanceMeasure, cost, iteration)
  }

Spark MLlib的KMeans应用示例

1.准备数据

诺丹姆吉本斯主教中学(Notre Dame-Bishop Gibbons School)      71     0       0    283047.0        13289.0
海景基督高中(Ocean View Christian Academy)      45     0     0       276403.0       13289.0 
卡弗里学院(Calvary Baptist Academy)        58       0       0       227567.0       13289.0
...

2.示例代码

//将加载的rdd数据,每一条变成一个向量,整个数据集变成矩阵
val parsedata = rdd.map { case Row(schoolid, schoolname, locationid, school_type, zs, fee, byj) =>
   //"特征因子":学校位置id,学校类型,住宿方式,学费,备用金
   val features = Array[Double](locationid.toString.toDouble, school_type.toString.toDouble, zs.toString.toDouble, fee.toString.toDouble, byj.toString.toDouble)
    //将数组变成机器学习中的向量
    Vectors.dense(features)
  }.cache() //默认缓存到内存中,可以调用persist()指定缓存到哪

  //用kmeans对样本向量进行训练得到模型
  //聚类中心
  val numclusters = List(3, 6, 9)
  //指定最大迭代次数
  val numIters = List(10, 15, 20)
  var bestModel: Option[KMeansModel] = None
  var bestCluster = 0
  var bestIter = 0
  val bestRmse = Double.MaxValue
  for (c <- numclusters; i <- numIters) {
    val model = KMeans.train(parsedata, c, i)
    //集内均方差总和(WSSSE),一般可以通过增加类簇的个数 k 来减小误差,一般越小越好(有可能出现过拟合)
    val d = model.computeCost(parsedata)
    println("选择K:" + (c, i, d))
    if (d < bestRmse) {
      bestModel = Some(model)
      bestCluster = c
      bestIter = i
    }
  }
  println("best:" + (bestCluster, bestIter, bestModel.get.computeCost(parsedata)))
  //用模型对我们的数据进行预测
  val resrdd = df.map { case Row(schoolid, schoolname, locationid, school_type, zs, fee, byj) =>
  //提取到每一行的特征值
  val features = Array[Double](locationid.toString.toDouble, school_type.toString.toDouble, zs.toString.toDouble, fee.toString.toDouble, byj.toString.toDouble)
   //将特征值转换成特征向量
   val linevector = Vectors.dense(features)
   //将向量输入model中进行预测,得到预测值
   val prediction = bestModel.get.predict(linevector)

   //返回每一行结果((sid,sname),所属类别)
   ((schoolid.toString, schoolname.toString), prediction)
 }

 //中心点
 /*val centers: Array[linalg.Vector] = model.clusterCenters
 centers.foreach(println)*/

 //按照所属"类别"分组,并根据"类别"排序,然后保存到数据库
 // saveData2Mysql是封装好的保存数据到mysql的方法
 resrdd.groupBy(_._2).sortBy(_._1).foreachPartition(saveData2Mysql(_))

上述示例只是一个简单的demo,实际应用中会更复杂,牵涉到数据的预处理,比如对数据进行量化、归一化,以及如何调参以获取最优训练模型。


阿里巴巴开源大数据技术团队成立Apache Spark中国技术社区,定期推送精彩案例,技术专家直播,问答区近万人Spark技术同学在线提问答疑,只为营造纯粹的Spark氛围,欢迎钉钉扫码加入!image.png
对开源大数据和感兴趣的同学可以加小编微信(下图二维码,备注“进群”)进入技术交流微信群。image.png
Apache Spark技术交流社区公众号,微信扫一扫关注
image.png

相关文章
|
3月前
|
分布式计算 数据处理 Apache
Spark和Flink的区别是什么?如何选择?都应用在哪些行业?
【10月更文挑战第10天】Spark和Flink的区别是什么?如何选择?都应用在哪些行业?
347 1
|
3月前
|
算法 前端开发 数据处理
小白学python-深入解析一位字符判定算法
小白学python-深入解析一位字符判定算法
56 0
|
3月前
|
存储 算法 Java
解析HashSet的工作原理,揭示Set如何利用哈希算法和equals()方法确保元素唯一性,并通过示例代码展示了其“无重复”特性的具体应用
在Java中,Set接口以其独特的“无重复”特性脱颖而出。本文通过解析HashSet的工作原理,揭示Set如何利用哈希算法和equals()方法确保元素唯一性,并通过示例代码展示了其“无重复”特性的具体应用。
61 3
|
3天前
|
存储 算法 安全
基于红黑树的局域网上网行为控制C++ 算法解析
在当今网络环境中,局域网上网行为控制对企业和学校至关重要。本文探讨了一种基于红黑树数据结构的高效算法,用于管理用户的上网行为,如IP地址、上网时长、访问网站类别和流量使用情况。通过红黑树的自平衡特性,确保了高效的查找、插入和删除操作。文中提供了C++代码示例,展示了如何实现该算法,并强调其在网络管理中的应用价值。
|
27天前
|
机器学习/深度学习 人工智能 算法
深入解析图神经网络:Graph Transformer的算法基础与工程实践
Graph Transformer是一种结合了Transformer自注意力机制与图神经网络(GNNs)特点的神经网络模型,专为处理图结构数据而设计。它通过改进的数据表示方法、自注意力机制、拉普拉斯位置编码、消息传递与聚合机制等核心技术,实现了对图中节点间关系信息的高效处理及长程依赖关系的捕捉,显著提升了图相关任务的性能。本文详细解析了Graph Transformer的技术原理、实现细节及应用场景,并通过图书推荐系统的实例,展示了其在实际问题解决中的强大能力。
149 30
|
7天前
|
存储 监控 算法
企业内网监控系统中基于哈希表的 C# 算法解析
在企业内网监控系统中,哈希表作为一种高效的数据结构,能够快速处理大量网络连接和用户操作记录,确保网络安全与效率。通过C#代码示例展示了如何使用哈希表存储和管理用户的登录时间、访问IP及操作行为等信息,实现快速的查找、插入和删除操作。哈希表的应用显著提升了系统的实时性和准确性,尽管存在哈希冲突等问题,但通过合理设计哈希函数和冲突解决策略,可以确保系统稳定运行,为企业提供有力的安全保障。
|
1月前
|
存储 算法
深入解析PID控制算法:从理论到实践的完整指南
前言 大家好,今天我们介绍一下经典控制理论中的PID控制算法,并着重讲解该算法的编码实现,为实现后续的倒立摆样例内容做准备。 众所周知,掌握了 PID ,就相当于进入了控制工程的大门,也能为更高阶的控制理论学习打下基础。 在很多的自动化控制领域。都会遇到PID控制算法,这种算法具有很好的控制模式,可以让系统具有很好的鲁棒性。 基本介绍 PID 深入理解 (1)闭环控制系统:讲解 PID 之前,我们先解释什么是闭环控制系统。简单说就是一个有输入有输出的系统,输入能影响输出。一般情况下,人们也称输出为反馈,因此也叫闭环反馈控制系统。比如恒温水池,输入就是加热功率,输出就是水温度;比如冷库,
257 15
|
3月前
|
搜索推荐 算法
插入排序算法的平均时间复杂度解析
【10月更文挑战第12天】 插入排序是一种简单直观的排序算法,通过不断将未排序元素插入到已排序部分的合适位置来完成排序。其平均时间复杂度为$O(n^2)$,适用于小规模或部分有序的数据。尽管效率不高,但在特定场景下仍具优势。
|
2月前
|
算法 Linux 定位技术
Linux内核中的进程调度算法解析####
【10月更文挑战第29天】 本文深入剖析了Linux操作系统的心脏——内核中至关重要的组成部分之一,即进程调度机制。不同于传统的摘要概述,我们将通过一段引人入胜的故事线来揭开进程调度算法的神秘面纱,展现其背后的精妙设计与复杂逻辑,让读者仿佛跟随一位虚拟的“进程侦探”,一步步探索Linux如何高效、公平地管理众多进程,确保系统资源的最优分配与利用。 ####
75 4
|
2月前
|
缓存 负载均衡 算法
Linux内核中的进程调度算法解析####
本文深入探讨了Linux操作系统核心组件之一——进程调度器,着重分析了其采用的CFS(完全公平调度器)算法。不同于传统摘要对研究背景、方法、结果和结论的概述,本文摘要将直接揭示CFS算法的核心优势及其在现代多核处理器环境下如何实现高效、公平的资源分配,同时简要提及该算法如何优化系统响应时间和吞吐量,为读者快速构建对Linux进程调度机制的认知框架。 ####