Spark MLlib中KMeans聚类算法的解析和应用

本文涉及的产品
EMR Serverless StarRocks,5000CU*H 48000GB*H
简介: 聚类算法是机器学习中的一种无监督学习算法,它在数据科学领域应用场景很广泛,比如基于用户购买行为、兴趣等来构建推荐系统。

本文转自公众号:大数据学习与分享
原文链接


聚类算法是机器学习中的一种无监督学习算法,它在数据科学领域应用场景很广泛,比如基于用户购买行为、兴趣等来构建推荐系统。

核心思想可以理解为,在给定的数据集中(数据集中的每个元素有可被观察的n个属性),使用聚类算法将数据集划分为k个子集,并且要求每个子集内部的元素之间的差异度尽可能低,而不同子集元素的差异度尽可能高。简而言之,就是通过聚类算法处理给定的数据集,将具有相同或类似的属性(特征)的数据划分为一组,并且不同组之间的属性相差会比较大。

K-Means算法是聚类算法中应用比较广泛的一种聚类算法,比较容易理解且易于实现。

"标准" K-Means算法

KMeans算法的基本思想是随机给定K个初始簇中心,按照最邻近原则把待分类样本点分到各个簇。然后按平均法重新计算各个簇的质心,从而确定新的簇心。一直迭代,直到簇心的移动距离小于某个给定的值或者满足已定条件。主要分为4个步骤:

1. 为要聚类的点寻找聚类中心,比如随机选择K个点作为初始聚类中心
2. 计算每个点到聚类中心的距离,将每个点划分到离该点最近的聚类中去
3. 计算每个聚类中所有点的坐标平均值,并将这个平均值作为新的聚类中心
4. 反复执行第2步和第3步,直到聚类中心不再改变或者聚类次数达到设定迭代上限或者达到指定的容错范围

示例图:

image.png

KMeans算法在做聚类分析的过程中主要有两个难题:初始聚类中心的选择和聚类个数K的选择。

Spark MLlib对KMeans的实现分析

Spark MLlib针对"标准"KMeans的问题,在实现自己的KMeans上主要做了如下核心优化:

1. 选择合适的初始中心点

Spark MLlib在初始中心点的选择上,有两种算法:

随机选择:依据给的种子seed,随机选择K个随机中心点
k-means||:默认的算法

val RANDOM = "random"
val K_MEANS_PARALLEL = "k-means||"

2. 计算样本属于哪一个中心点时对距离计算的优化

假设中心点是(a1,b1),要计算的点是(a2,b2),那么Spark MLlib采取的计算方法是(记为lowerBoundOfSqDist):

image.png

对比欧几里得距离(记为EuclideanDist):
image.png

可轻易证明lowerBoundOfSqDist小于或等于EuclideanDist,并且计算lowerBoundOfSqDist很方便,只需处理中心点和要计算的点的L2范数。那么在实际处理中就分两种情况:

  • 当lowerBoundOfSqDist大于"最近距离"(之前计算好的,记为bestdistance),那么可以推导欧式距离也大于bestdistance,不需要计算欧式距离,省去了很多计算工作
  • 当lowerBoundOfSqDist小于bestdistance,则会调用fastSquaredDistance进行距离的快速计算

关于fastSquaredDistance:

首先计算一个精度:
val precisionBound1 = 2.0 * EPSILON * sumSquaredNorm / (normDiff * normDiff + EPSILON)
if (precisionBound1 < precision) {
   // 精度满足squared distance期望的精度
   // val sumSquaredNorm = norm1 * norm1 + norm2 * norm2
   // 2.0 * dot(v1, v2)为2(a1*a2 + b1*b2)可以利用之前计算的L2范数
   sqDist = sumSquaredNorm - 2.0 * dot(v1, v2)
} else if (v1.isInstanceOf[SparseVector] || v2.isInstanceOf[SparseVector]) {
   val dotValue = dot(v1, v2)
   sqDist = math.max(sumSquaredNorm - 2.0 * dotValue, 0.0)
   val precisionBound2 = EPSILON * (sumSquaredNorm + 2.0 * math.abs(dotValue)) /
     (sqDist + EPSILON)
   if (precisionBound2 > precision) {
      sqDist = Vectors.sqdist(v1, v2)
   }
} else {
  sqDist = Vectors.sqdist(v1, v2)
}
//精度不满足要求时,则进行Vectors.sqdist(v1, v2)的处理,即原始的距离计算

Spark MLlib中KMeans相关源码分析

基于mllib包下的KMeans相关源码涉及的类和方法(ml包下与下面略有不同,比如涉及到的fit方法):

  1. KMeans类和伴生对象
  2. train方法:根据设置的KMeans聚类参数,构建KMeans聚类,并执行run方法进行训练
  3. run方法:主要调用runAlgorithm方法进行聚类中心点等的核心计算,返回KMeansModel
  4. initialModel:可以直接设置KMeansModel作为初始化聚类中心选择,也支持随机和k-means || 生成中心点
  5. predict:预测样本属于哪个"类"
  6. computeCost:通过计算数据集中所有的点到最近中心点的平方和来衡量聚类效果。一般同样的迭代次数,cost值越小,说明聚类效果越好。

注意:该方法在Spark 2.4.X版本已经过时,并且会在Spark 3.0.0被移除,具体取代方法可以查看ClusteringEvaluator

主要看一下train和runAlgorithm的核心源码:
def train(
      // 数据样本
      data: RDD[Vector],
      // 聚类数量
      k: Int,
      // 最大迭代次数
      maxIterations: Int,
      // 初始化中心,支持"random"或者"k-means||"
      initializationMode: String,
      // 初始化时的随机种子
      seed: Long): KMeansModel = {
  new KMeans().setK(k)
      .setMaxIterations(maxIterations)
      .setInitializationMode(initializationMode)
      .setSeed(seed)
      .run(data)
}
/**
   * Implementation of K-Means algorithm.
   */
  private def runAlgorithm( data: RDD[VectorWithNorm],
      instr: Option[Instrumentation]): KMeansModel = {

    val sc = data.sparkContext

    val initStartTime = System.nanoTime()

    val distanceMeasureInstance = DistanceMeasure.decodeFromString(this.distanceMeasure)

    val centers = initialModel match {
      case Some(kMeansCenters) =>
        kMeansCenters.clusterCenters.map(new VectorWithNorm(_))
      case None =>
        if (initializationMode == KMeans.RANDOM) {
          // random
          initRandom(data)
        } else {
          // k-means||
          initKMeansParallel(data, distanceMeasureInstance)
        }
    }
    val initTimeInSeconds = (System.nanoTime() - initStartTime) / 1e9
    logInfo(f"Initialization with $initializationMode took $initTimeInSeconds%.3f seconds.")

    var converged = false
    var cost = 0.0
    var iteration = 0

    val iterationStartTime = System.nanoTime()

    instr.foreach(_.logNumFeatures(centers.head.vector.size))

    // Execute iterations of Lloyd's algorithm until converged
    // Kmeans迭代执行,计算每个样本属于哪个中心点,中心点累加的样本值以及计数。然后根据中心点的所有样本数据进行中心点的更新,并且比较更新前的数值,根据两者距离判断是否完成
    //迭代次数小于最大迭代次数,并行计算的中心点还没有收敛
    while (iteration < maxIterations && !converged) {
      // 损失值累加器
      val costAccum = sc.doubleAccumulator
      // 广播中心点
      val bcCenters = sc.broadcast(centers)

      // Find the new centers
      val collected = data.mapPartitions { points =>
        // 当前聚类中心
        val thisCenters = bcCenters.value
        // 中心点的维度
        val dims = thisCenters.head.vector.size

        val sums = Array.fill(thisCenters.length)(Vectors.zeros(dims))
        val counts = Array.fill(thisCenters.length)(0L)

        points.foreach { point =>
          // 通过当前的聚类中心点,找出最近的聚类中心点
          // findClosest是为了计算bestDistance,参考上述Spark对距离计算的优化
          val (bestCenter, cost) = distanceMeasureInstance.findClosest(thisCenters, point)
          costAccum.add(cost)
          distanceMeasureInstance.updateClusterSum(point, sums(bestCenter))
          counts(bestCenter) += 1
        }

        counts.indices.filter(counts(_) > 0).map(j => (j, (sums(j), counts(j)))).iterator
      }.reduceByKey { case ((sum1, count1), (sum2, count2)) =>
        axpy(1.0, sum2, sum1)
        (sum1, count1 + count2)
      }.collectAsMap()

      if (iteration == 0) {
        instr.foreach(_.logNumExamples(collected.values.map(_._2).sum))
      }

      val newCenters = collected.mapValues { case (sum, count) =>
        distanceMeasureInstance.centroid(sum, count)
      }

      bcCenters.destroy(blocking = false)

      // Update the cluster centers and costs
      converged = true
      newCenters.foreach { case (j, newCenter) =>
        if (converged &&
          !distanceMeasureInstance.isCenterConverged(centers(j), newCenter, epsilon)) {
          // 距离大于,则说明中心点位置改变
          converged = false
        }
        // 更新中心点
        centers(j) = newCenter
      }

      cost = costAccum.value
      iteration += 1
    }

    val iterationTimeInSeconds = (System.nanoTime() - iterationStartTime) / 1e9
    logInfo(f"Iterations took $iterationTimeInSeconds%.3f seconds.")

    if (iteration == maxIterations) {
      logInfo(s"KMeans reached the max number of iterations: $maxIterations.")
    } else {
      logInfo(s"KMeans converged in $iteration iterations.")
    }

    logInfo(s"The cost is $cost.")

    new KMeansModel(centers.map(_.vector), distanceMeasure, cost, iteration)
  }

Spark MLlib的KMeans应用示例

1.准备数据

诺丹姆吉本斯主教中学(Notre Dame-Bishop Gibbons School)      71     0       0    283047.0        13289.0
海景基督高中(Ocean View Christian Academy)      45     0     0       276403.0       13289.0 
卡弗里学院(Calvary Baptist Academy)        58       0       0       227567.0       13289.0
...

2.示例代码

//将加载的rdd数据,每一条变成一个向量,整个数据集变成矩阵
val parsedata = rdd.map { case Row(schoolid, schoolname, locationid, school_type, zs, fee, byj) =>
   //"特征因子":学校位置id,学校类型,住宿方式,学费,备用金
   val features = Array[Double](locationid.toString.toDouble, school_type.toString.toDouble, zs.toString.toDouble, fee.toString.toDouble, byj.toString.toDouble)
    //将数组变成机器学习中的向量
    Vectors.dense(features)
  }.cache() //默认缓存到内存中,可以调用persist()指定缓存到哪

  //用kmeans对样本向量进行训练得到模型
  //聚类中心
  val numclusters = List(3, 6, 9)
  //指定最大迭代次数
  val numIters = List(10, 15, 20)
  var bestModel: Option[KMeansModel] = None
  var bestCluster = 0
  var bestIter = 0
  val bestRmse = Double.MaxValue
  for (c <- numclusters; i <- numIters) {
    val model = KMeans.train(parsedata, c, i)
    //集内均方差总和(WSSSE),一般可以通过增加类簇的个数 k 来减小误差,一般越小越好(有可能出现过拟合)
    val d = model.computeCost(parsedata)
    println("选择K:" + (c, i, d))
    if (d < bestRmse) {
      bestModel = Some(model)
      bestCluster = c
      bestIter = i
    }
  }
  println("best:" + (bestCluster, bestIter, bestModel.get.computeCost(parsedata)))
  //用模型对我们的数据进行预测
  val resrdd = df.map { case Row(schoolid, schoolname, locationid, school_type, zs, fee, byj) =>
  //提取到每一行的特征值
  val features = Array[Double](locationid.toString.toDouble, school_type.toString.toDouble, zs.toString.toDouble, fee.toString.toDouble, byj.toString.toDouble)
   //将特征值转换成特征向量
   val linevector = Vectors.dense(features)
   //将向量输入model中进行预测,得到预测值
   val prediction = bestModel.get.predict(linevector)

   //返回每一行结果((sid,sname),所属类别)
   ((schoolid.toString, schoolname.toString), prediction)
 }

 //中心点
 /*val centers: Array[linalg.Vector] = model.clusterCenters
 centers.foreach(println)*/

 //按照所属"类别"分组,并根据"类别"排序,然后保存到数据库
 // saveData2Mysql是封装好的保存数据到mysql的方法
 resrdd.groupBy(_._2).sortBy(_._1).foreachPartition(saveData2Mysql(_))

上述示例只是一个简单的demo,实际应用中会更复杂,牵涉到数据的预处理,比如对数据进行量化、归一化,以及如何调参以获取最优训练模型。


阿里巴巴开源大数据技术团队成立Apache Spark中国技术社区,定期推送精彩案例,技术专家直播,问答区近万人Spark技术同学在线提问答疑,只为营造纯粹的Spark氛围,欢迎钉钉扫码加入!image.png
对开源大数据和感兴趣的同学可以加小编微信(下图二维码,备注“进群”)进入技术交流微信群。image.png
Apache Spark技术交流社区公众号,微信扫一扫关注
image.png

相关文章
|
11天前
|
存储 缓存 搜索推荐
Lazada淘宝详情API的价值与应用解析
在电商行业,数据是驱动业务增长的核心。Lazada作为东南亚知名电商平台,其商品详情API对电商行业影响深远。本文探讨了Lazada商品详情API的重要性,包括提供全面准确的商品信息、增强平台竞争力、促进销售转化、支持用户搜索和发现需求、数据驱动决策、竞品分析、用户行为研究及提升购物体验。文章还介绍了如何通过Lazada提供的API接口、编写代码及使用第三方工具实现实时数据获取。
29 3
|
15天前
|
存储 算法 Java
解析HashSet的工作原理,揭示Set如何利用哈希算法和equals()方法确保元素唯一性,并通过示例代码展示了其“无重复”特性的具体应用
在Java中,Set接口以其独特的“无重复”特性脱颖而出。本文通过解析HashSet的工作原理,揭示Set如何利用哈希算法和equals()方法确保元素唯一性,并通过示例代码展示了其“无重复”特性的具体应用。
35 3
|
1天前
|
机器学习/深度学习 人工智能 安全
TPAMI:安全强化学习方法、理论与应用综述,慕工大、同济、伯克利等深度解析
【10月更文挑战第27天】强化学习(RL)在实际应用中展现出巨大潜力,但其安全性问题日益凸显。为此,安全强化学习(SRL)应运而生。近日,来自慕尼黑工业大学、同济大学和加州大学伯克利分校的研究人员在《IEEE模式分析与机器智能汇刊》上发表了一篇综述论文,系统介绍了SRL的方法、理论和应用。SRL主要面临安全性定义模糊、探索与利用平衡以及鲁棒性与可靠性等挑战。研究人员提出了基于约束、基于风险和基于监督学习等多种方法来应对这些挑战。
9 2
|
3天前
|
分布式计算 Java 开发工具
阿里云MaxCompute-XGBoost on Spark 极限梯度提升算法的分布式训练与模型持久化oss的实现与代码浅析
本文介绍了XGBoost在MaxCompute+OSS架构下模型持久化遇到的问题及其解决方案。首先简要介绍了XGBoost的特点和应用场景,随后详细描述了客户在将XGBoost on Spark任务从HDFS迁移到OSS时遇到的异常情况。通过分析异常堆栈和源代码,发现使用的`nativeBooster.saveModel`方法不支持OSS路径,而使用`write.overwrite().save`方法则能成功保存模型。最后提供了完整的Scala代码示例、Maven配置和提交命令,帮助用户顺利迁移模型存储路径。
|
5天前
|
测试技术 开发者 Python
深入浅出:Python中的装饰器解析与应用###
【10月更文挑战第22天】 本文将带你走进Python装饰器的世界,揭示其背后的魔法。我们将一起探索装饰器的定义、工作原理、常见用法以及如何自定义装饰器,让你的代码更加简洁高效。无论你是Python新手还是有一定经验的开发者,相信这篇文章都能为你带来新的启发和收获。 ###
8 1
|
9天前
|
传感器 监控 安全
|
9天前
|
数据中心
|
9天前
|
人工智能 资源调度 数据可视化
【AI应用落地实战】智能文档处理本地部署——可视化文档解析前端TextIn ParseX实践
2024长沙·中国1024程序员节以“智能应用新生态”为主题,吸引了众多技术大咖。合合信息展示了“智能文档处理百宝箱”的三大工具:可视化文档解析前端TextIn ParseX、向量化acge-embedding模型和文档解析测评工具markdown_tester,助力智能文档处理与知识管理。
|
10天前
|
存储 Java API
详细解析HashMap、TreeMap、LinkedHashMap等实现类,帮助您更好地理解和应用Java Map。
【10月更文挑战第19天】深入剖析Java Map:不仅是高效存储键值对的数据结构,更是展现设计艺术的典范。本文从基本概念、设计艺术和使用技巧三个方面,详细解析HashMap、TreeMap、LinkedHashMap等实现类,帮助您更好地理解和应用Java Map。
29 3
|
10天前
|
存储 算法 搜索推荐
这些算法在实际应用中有哪些具体案例呢
【10月更文挑战第19天】这些算法在实际应用中有哪些具体案例呢
20 1

推荐镜像

更多