手撕SparkSQL五大JOIN的底层机制

简介: 手撕SparkSQL五大JOIN的底层机制

关联形式(Join Types)都有哪些

我个人习惯还是从源码里面定义入手,一方面如果有调整,大家知道怎么去查,另一方面来说,没有什么比起源码的更加官方的定义了。SparkSQL中的关于JOIN的定义位于

org.apache.spark.sql.catalyst.plans.JoinType,按照包的划分,JOIN其实是执行计划的一部分。

具体的定义可以在JoinType的伴生对象中apply方法有构造。

object JoinType {
  def apply(typ: String): JoinType = typ.toLowerCase(Locale.ROOT).replace("_", "") match {
    case "inner" => Inner
    case "outer" | "full" | "fullouter" => FullOuter
    case "leftouter" | "left" => LeftOuter
    case "rightouter" | "right" => RightOuter
    case "leftsemi" | "semi" => LeftSemi
    case "leftanti" | "anti" => LeftAnti
    case "cross" => Cross
    case _ =>
      val supported = Seq(
        "inner",
        "outer", "full", "fullouter", "full_outer",
        "leftouter", "left", "left_outer",
        "rightouter", "right", "right_outer",
        "leftsemi", "left_semi", "semi",
        "leftanti", "left_anti", "anti",
        "cross")

      throw new IllegalArgumentException(s"Unsupported join type '$typ'. " +
        "Supported join types include: " + supported.mkString("'", "', '", "'") + ".")
  }
}

当然,我们其实可以很清楚看出来为什么我们平时说left outer join 和left join 其实是一样的。

关联形式 关键字
内关联 inner
外关联 outer、full、fullouter、full_outer
左关联 leftouter、left、left_outer
右关联 rightouter、right、right_outer
左半关联 leftsemi、left_semi、semi
左逆关联 leftanti、left_anti、anti
交叉连接(笛卡尔积) cross

前面几个大家比较熟悉,我说一下后面三个,后面三个其实用得挺多的,只是我们经常是

用一些其他语法表达而已,因为后面我们演示数据,提前建立表

import spark.implicits._
    import org.apache.spark.sql.DataFrame

    // 学生表
    val seq = Seq((1, "小明", 28, "男","二班"), (2, "小丽", 22, "女","四班"), (3, "阿虎", 24, "男","三班"), (5, "张强", 18, "男","四班"))
    val students: DataFrame = seq.toDF("id", "name", "age", "gender","class")
    students.show(10,false);
    students.createTempView("students")
    // 班级表
    val seq2 = Seq(("三班",3),("四班",4),("三班",1))
    val classes:DataFrame = seq2.toDF("class_name", "id")
    classes.show(10,false)
    classes.createTempView("classes")
+---+----+---+------+-----+
|id |name|age|gender|class|
+---+----+---+------+-----+
|1  |小明|28 |男    |二班 |
|2  |小丽|22 |女    |四班 |
|3  |阿虎|24 |男    |三班 |
|5  |张强|18 |男    |四班 |
+---+----+---+------+-----+

+---+----+---+------+-----+
| id|name|age|gender|class|
+---+----+---+------+-----+
|  2|小丽| 22|    女| 四班|
|  3|阿虎| 24|    男| 三班|
|  5|张强| 18|    男| 四班|
+---+----+---+------+-----+

left_semi join

这两个刚好是一对,一个是要剔除,一个是要保留

我们经常有这种需求,我们要查询班级表中存在的学生姓名,我们会这样写SQL

select id,name,age,gender,class from students where class in (select class_name from classes group by class_name)

是一种存在需求,转换成我们的left_semi就是

select id,name,age,gender,class from students left semi join classes on students.class=classes.class_name

当然,也不会忘记了我们的代码形式

val leftsemiDF: DataFrame = students.join(classes, students("class") === classes("class_name"), "leftsemi")

三次运算结果都是一致的,因为班级信息里面没有二班,所以只有三班和4班的信息,这一类需求是我们实现exists 和 一些where 条件中in的时候大量使用


+---+----+---+------+-----+
| id|name|age|gender|class|
+---+----+---+------+-----+
|  2|小丽| 22|    女| 四班|
|  3|阿虎| 24|    男| 三班|
|  5|张强| 18|    男| 四班|
+---+----+---+------+-----+

left_anti join

left_anti其实是和semi相反的工作前面是实现了保留,这个则是去掉,我们其实简单换成left_anti 可以看到类似的效果,我们希望看到不在班级表里面的学生,我们会这样写sql

select id,name,age,gender,class from students where class not in (select class_name from classes group by class_name)

转换成我们的anti join则是

select id,name,age,gender,class from students left anti join classes on students.class=classes.class_name

当然,也有代码版本

val leftantiDF: DataFrame = students.join(classes, students("class") === classes("class_name"), "leftanti")

结果如下,只有2班的小明了:

+---+----+---+------+-----+
| id|name|age|gender|class|
+---+----+---+------+-----+
|  1|小明| 28|    男| 二班|
+---+----+---+------+-----+

cross join

这个其实是笛卡尔积,我们经常忘记写join条件的时候,这种就是笛卡尔积了,

因为这个操作是很容易把程序搞崩的,所以要加上配置

spark.sql.crossJoin.enabled=true

spark.conf.set("spark.sql.crossJoin.enabled", "true")
students.join(classes).show(10,false)

结果如下:

+---+----+---+------+-----+----------+---+
|id |name|age|gender|class|class_name|id |
+---+----+---+------+-----+----------+---+
|1  |小明|28 |男    |二班 |三班      |3  |
|1  |小明|28 |男    |二班 |四班      |4  |
|1  |小明|28 |男    |二班 |三班      |1  |
|2  |小丽|22 |女    |四班 |三班      |3  |
|2  |小丽|22 |女    |四班 |四班      |4  |
|2  |小丽|22 |女    |四班 |三班      |1  |
|3  |阿虎|24 |男    |三班 |三班      |3  |
|3  |阿虎|24 |男    |三班 |四班      |4  |
|3  |阿虎|24 |男    |三班 |三班      |1  |
|5  |张强|18 |男    |四班 |三班      |3  |
+---+----+---+------+-----+----------+---+

JOIN的实现机制

Spark中的JOIN对应BaseJoinExec的五个子类,他们分别是 BroadcastHashJoinExec、BroadcastNestedLoopJoinExec、ShuffledHashJoinExec、SortMergeJoinExec、CartesianProductExec,源码关系如下:

可能大家平时没有太关注这里头有啥联系,但是当我们把这些摆在一起的时候我们其实很明显发现Broadcast和Hash就出现了个,我们可以大胆猜测,这里头有必然的联系,其实Broadcast和Shuffled其实是数据分发的形式,SortMergeJoinExec其实也是通过Shuffle的分发,只是类取名字的时候没有写成ShuffledSortMergeJoinExec 看着是有点长吧,还一点就是走了归并排序其实就是不会走广播了。我们按照分发方式可以整理出一个小表格:

分发方式 关键字
无分发 CartesianProductExec
Broadcast BroadcastNestedLoopJoinExec
Broadcast BroadcastHashJoinExec
Shuffled ShuffledHashJoinExec
Shuffled SortMergeJoinExec

CartesianProductExec 无分发

为了搞明白计算原理,我们通过源码来研究研究。首其实可以想得到,笛卡尔积投影是直接把数据的所有行都按照交叉膨胀,这个事情直接在Map端组合完成分发就好了,代码里面其实也是这样子的,我们一起看看,我把CartesianProductExec计算部分拿出来。

override def compute(split: Partition, context: TaskContext): Iterator[(UnsafeRow, UnsafeRow)] = {
    val rowArray = new ExternalAppendOnlyUnsafeRowArray(inMemoryBufferThreshold, spillThreshold)
    val partition = split.asInstanceOf[CartesianPartition]
    rdd2.iterator(partition.s2, context).foreach(rowArray.add)
    // Create an iterator from rowArray
    def createIter(): Iterator[UnsafeRow] = rowArray.generateIterator()
    val resultIter =
      for (x <- rdd1.iterator(partition.s1, context);
           y <- createIter()) yield (x, y)
    CompletionIterator[(UnsafeRow, UnsafeRow), Iterator[(UnsafeRow, UnsafeRow)]](
      resultIter, rowArray.clear())
  }
}

我们要抠关键部分,其实核心点就是,x,y各自其实就是作为两边的元素,我们理解一下我们操作集合的时候yide的操作。

 val resultIter =
      for (x <- rdd1.iterator(partition.s1, context);
           y <- createIter()) yield (x, y)

这个其实就是scala中实现二重循环的操作,spark上面的逻辑可以类比下面的逻辑,我把学生join笛卡尔积的操作还原如下:

 val students=Array((1,"小明"),(2,"小丽"))
    val classes= Array("三班","二班")
    val result= for(student <-students; clazz <- classes) yield (student,clazz)
    result.foreach(print)

结果如下:

((1,小明),三班)((1,小明),二班)((2,小丽),三班)((2,小丽),二班)

发现没有,其实就是我们需要的结果了,注意哦,Spark源码就是这样的,是不是信心倍增。

BroadcastNestedLoopJoinExec

这个其实就是字面上的含义,广播+嵌套循环实现join,我们一直在说,广播其实是一种分发方式,在我们之前的文章也有说到,其实广播来说,我们在rdd执行的时候,就直接可以当成拿到本地变量而已,我还是把核心代码拿出来:

protected override def doExecute(): RDD[InternalRow] = {
    val broadcastedRelation = broadcast.executeBroadcast[Array[InternalRow]]()

    val resultRdd = (joinType, buildSide) match {
      case (_: InnerLike, _) =>
        innerJoin(broadcastedRelation)
      case (LeftOuter, BuildRight) | (RightOuter, BuildLeft) =>
        outerJoin(broadcastedRelation)
      case (LeftSemi, _) =>
        leftExistenceJoin(broadcastedRelation, exists = true)
      case (LeftAnti, _) =>
        leftExistenceJoin(broadcastedRelation, exists = false)
      case (_: ExistenceJoin, _) =>
        existenceJoin(broadcastedRelation)
      case _ =>
        /**
         * LeftOuter with BuildLeft
         * RightOuter with BuildRight
         * FullOuter
         */
        defaultJoin(broadcastedRelation)
    }

    val numOutputRows = longMetric("numOutputRows")
    resultRdd.mapPartitionsWithIndexInternal { (index, iter) =>
      val resultProj = genResultProjection
      resultProj.initialize(index)
      iter.map { r =>
        numOutputRows += 1
        resultProj(r)
      }
    }

最前面 broadcastedRelation,其实就是从广播的变量中获取到数据,这个就是广播的操作了,剩下的就是NestedLoop的事情了,我们注意到代中的(joinType, buildSide) match条件,是按照操作类型不同去实现,我们一起看看innser join 的操作:

private def innerJoin(relation: Broadcast[Array[InternalRow]]): RDD[InternalRow] = {
    streamed.execute().mapPartitionsInternal { streamedIter =>
      val buildRows = relation.value
      val joinedRow = new JoinedRow

      streamedIter.flatMap { streamedRow =>
        val joinedRows = buildRows.iterator.map(r => joinedRow(streamedRow, r))
        if (condition.isDefined) {
          joinedRows.filter(boundCondition)
        } else {
          joinedRows
        }
      }
    }
  }

这个操作的意思是把集合打平,关键点就是streamedIter.flatMap()中的操作,这部分就是嵌套循环,这个也就是为什么叫做NestedLoop的原因,其实就是嵌套循环的意思,我把这个操作按照我们数据集等价实现一遍:

val students=Array((1,"小明"),(2,"小丽"))
val classes= Array("三班","二班")
students.flatMap(student=>{
     classes.map(clazz=>{
       print(student._1,student._2,clazz)
     })
   })

结果如下:

(1,小明,三班)(1,小明,二班)(2,小丽,三班)(2,小丽,二班)

我们很明显看得出时间复杂度是M*N,看到这里。

BroadcastHashJoinExec

有了前面的基础,对于Broadcast类型的操作,我们可以进一步归纳,三部曲

1、从广播变量中取值

2、完成关联操作

3、输出结果

有了这些操作,我们可以预判代码了

  protected override def doExecute(): RDD[InternalRow] = {
  val numOutputRows = longMetric("numOutputRows")
    val broadcastRelation = buildPlan.executeBroadcast[HashedRelation]()
    if (isNullAwareAntiJoin) {
        //Anti join是尝试解决子查询反嵌套(Subquery Unnesting)中和NULl值相关的各种问题,这里不展开,也是
        //为了代码少一些
    } else {
      streamedPlan.execute().mapPartitions { streamedIter =>
        val hashed = broadcastRelation.value.asReadOnlyCopy()
        TaskContext.get().taskMetrics().incPeakExecutionMemory(hashed.estimatedSize)
        join(streamedIter, hashed, numOutputRows)
      }
    }
    }

广播方式就是从广播变量中获取我们要的结果

val broadcastRelation = buildPlan.executeBroadcastHashedRelation

我们需要关注else 里面的内容,val hashed = broadcastRelation.value.asReadOnlyCopy()其实就是拿到了我们hash的清单,处理之后会在后面的join操作中进一步处理,我们需要深入JOIN内部

protected def join(
      streamedIter: Iterator[InternalRow],
      hashed: HashedRelation,
      numOutputRows: SQLMetric): Iterator[InternalRow] = {

    val joinedIter = joinType match {
      case _: InnerLike =>
        innerJoin(streamedIter, hashed)
      case LeftOuter | RightOuter =>
        outerJoin(streamedIter, hashed)
      case LeftSemi =>
        semiJoin(streamedIter, hashed)
      case LeftAnti =>
        antiJoin(streamedIter, hashed)
      case _: ExistenceJoin =>
        existenceJoin(streamedIter, hashed)
      case x =>
        throw new IllegalArgumentException(
          s"HashJoin should not take $x as the JoinType")
    }

    val resultProj = createResultProjection
    joinedIter.map { r =>
      numOutputRows += 1
      resultProj(r)
    }
  }

JOIN内部的主流程是按照JOIN类型各自处理,返回JOIN之后的结果,我们还是以Inner为例,分析一下实现的过程

private def innerJoin(
      streamIter: Iterator[InternalRow],
      hashedRelation: HashedRelation): Iterator[InternalRow] = {
    val joinRow = new JoinedRow
    val joinKeys = streamSideKeyGenerator()

    if (hashedRelation == EmptyHashedRelation) {
      Iterator.empty
    } else if (hashedRelation.keyIsUnique) {
      streamIter.flatMap { srow =>
        joinRow.withLeft(srow)
        val matched = hashedRelation.getValue(joinKeys(srow))
        if (matched != null) {
          Some(joinRow.withRight(matched)).filter(boundCondition)
        } else {
          None
        }
      }
    } else {
      streamIter.flatMap { srow =>
        joinRow.withLeft(srow)
        val matches = hashedRelation.get(joinKeys(srow))
        if (matches != null) {
          matches.map(joinRow.withRight).filter(boundCondition)
        } else {
          Seq.empty
        }
      }
    }
  }

我们关注到核心操作,从hash表中取到我们想要的元素,实现join,同样的,用我们的数据集实现一下

  import scala.collection.mutable.HashMap
    val students = Array((1, "小明", "002"), (2, "小丽", "003"))
    val hashClass = HashMap[String, String]("003" -> "三班", "002" -> "二班")
    students.map(student => (student, hashClass.get(student._3))).foreach(print)

因为本身的逻辑就行整个遍历一遍重新组合结果,时间复杂度是O(N)

((1,小明,002),Some(二班))((2,小丽,003),Some(三班))

阶段性总结

前面介绍的三种JOIN方式其实已经可以完全实现所有的JOIN操作了,但是这些操作有一个特点,作为主表我们可以分成不同的Partition上面执行,但是从表我们其实是清一色作为Executor本地方式执行的,因为我们的Task是分布在很多集群上运行的,所有我们为了让所有的节点都有这份数据,所有是往所有节点都分发一次,这也是为啥叫做广播的原因。

这些方式在数据量不大的时候是很高效的,这个数据量的规模可以是10万级到百万级不等,也就是说其实可以控制的,源码中的定义如下,我们可以看到其实后面给我们有一个默认值10MB

    val AUTO_BROADCASTJOIN_THRESHOLD = buildConf("spark.sql.autoBroadcastJoinThreshold")
    .doc("...解释信息")
    .version("1.1.0")
    .bytesConf(ByteUnit.BYTE)
    .createWithDefaultString("10MB")

有这么一句描述 By setting this value to -1 broadcasting can be disabled,就是给-1可以关闭这个数值,这种10M的概念一般就是几万以内,和你本身字段数量还有数据内容有关系,如果不满足可以调整。

我们做如下设置,就可以不打开广播了

spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")

还一个情况是大家网络上看到说广播方式数据量小的时候其实适合,大的时候慢。其实这里在生产中的实际情况是,数据量比较大的时候他是跑不出来,不是慢的问题,因为一个任务有时候执行下来本身就是要大半个小时,但是跑一个小时之后来个失败,这个才是忍受不了的,生产环境下游又有依赖,所以这个时候我们不再是追求快了,而是都在求神拜佛就行只要跑出来就可以,哪怕是再等等,但是别失败。这才是实际的情况。基于此,我们才会引入ShuffledHashJoinExec和SortMergeJoinExec的计算方式,大家注意,这两种方式是为了实现更大规模数据量的JOIN而产生的,就整体时间效率上远不如前面的方式,但是最大的好处是可以出结果呀,加上企业实际生产过程中数据量其实都是很庞大的,所以这两种方式才是在生产上大量存在的操作方式。

ShuffledHashJoinExec

ShuffledHashJoinExec方式其实在join操作中获取从表信息还是从Hash中获取,这里的差别在于,我们本来需要做广播的表太大了,所以我们需要把广播的表通Shuffle的方式把一个大表分解成小表生成hash,是怎么个原理呢。首先我们看实现源码:

protected override def doExecute(): RDD[InternalRow] = {
    val numOutputRows = longMetric("numOutputRows")
    streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, buildIter) =>
      val hashed = buildHashedRelation(buildIter)
      joinType match {
        case FullOuter => fullOuterJoin(streamIter, hashed, numOutputRows)
        case _ => join(streamIter, hashed, numOutputRows)
      }
    }
  }

ShuffledHashJoinExec的后面部分其实是复用了前面的BroadcastHashJoinExec的操作,实现来说还是需要获取到一个HashedRelation,这个也是我们的构建HashTable的部分,后面就是会在本地执行join操作了,主要差别是来自获取HashedRelation的时候,前者是从广播变量中获取,这里不再是了:

def buildHashedRelation(iter: Iterator[InternalRow]): HashedRelation = {
    val buildDataSize = longMetric("buildDataSize")
    val buildTime = longMetric("buildTime")
    val start = System.nanoTime()
    val context = TaskContext.get()
    val relation = HashedRelation(
      iter,
      buildBoundKeys,
      taskMemoryManager = context.taskMemoryManager(),
      // Full outer join needs support for NULL key in HashedRelation.
      allowsNullKey = joinType == FullOuter)
    buildTime += NANOSECONDS.toMillis(System.nanoTime() - start)
    buildDataSize += relation.estimatedSize
    // This relation is usually used until the end of task.
    context.addTaskCompletionListener[Unit](_ => relation.close())
    relation
  }

核心部分其实是来自,这部分其实就是一个Shuffle的迭代结果,也就是说我们获取这部分Hash的时候是按照buildBoundKeys作为范围内的获取,不再是整个表的范围了

val relation = HashedRelation(
     iter,
     buildBoundKeys,
     taskMemoryManager = context.taskMemoryManager(),
     // Full outer join needs support for NULL key in HashedRelation.
     allowsNullKey = joinType == FullOuter)

当然,从理解上还是简单的,因为Spark在实现设计上非常巧妙,我们还是可以当成一个Map来看待,只不过这个时候的数据范围是局部地区。

我们同样做点实战例子来理解计算过程,因为比起前面还是多一些步骤,我们画图理解一下:

学生信息作为主表目标就是找到和班级编号匹配的主表,在此之前,班级编号做一次partition操作,会把数据分布在不同分区里面去,与此同时,班级信息也进行partition,这样相同班级编号的信息会落到同一个partition中去,针对单独的partition就是实现和之前的HashJoin一样的操作了。要注意的是partition就是Shuffle实现的,我们都知道shuffle的时候是需要定义一个hashpartition的操作,所以这个操作其实是有两次的hash,第一次把数据进行分区,第二次,实现关联操作。我们把整个过程用代码模拟出来,我这里分区操作不会那么复杂,仅仅按照%2的方式分发:

//对班级编号进行分区
  def partition(key:String):Integer={
    val keyNum= Integer.valueOf(key)
    keyNum%2
  }

接下来是模拟Shuffle的操作,因为我们是需要把班级信息通过shuffle的方式分发,我们实现了把多个HashMap分发到了partitions内部

 def shuffleHash(classes:Array[(String, String)]):mutable.HashMap[Integer,mutable.HashMap[String,(String, String)]]={
      val partitions=new mutable.HashMap[Integer,mutable.HashMap[String,(String, String)]]();
      for(clazz <- classes){
        val hashKey=partition(clazz._1)
        val map=partitions.getOrElse(hashKey,mutable.HashMap[String,(String, String)]())
        map += (clazz._1->clazz)
        partitions += hashKey->map
      }
      partitions
  }

最后,我们实现真正的JOIN操作:

val students = Array[(Integer,String,String)]((1, "小黑", "001"),(2, "小明", "002"), (3, "小丽", "003"),(4, "小红", "004"))
    val classes: Array[(String, String)] = Array(("001", "一班"), ("002", "一班"), ("003", "三班"), ("004", "四班"))

   val partitions= shuffleHash(classes)
    students.map(student=>{
     val partitionId= partition(student._3)
     val  classHashMap= partitions.get(partitionId).get
     val clazz= classHashMap.get(student._3)
      println("分区:"+partitionId,"学生:"+student,"班级:"+clazz.get)
    })

查看结果:

(分区:1,学生:(1,小黑,001),班级:(001,一班))
(分区:0,学生:(2,小明,002),班级:(002,一班))
(分区:1,学生:(3,小丽,003),班级:(003,三班))
(分区:0,学生:(4,小红,004),班级:(004,四班))

这个也能理解,为什么Shuffle操作一定要把相同的数据分发到相同的分区里面去,其实这个就是一个分治的算法,是不是感觉也没那么神秘了~~

SortMergeJoinExec

我们其实注意到,算法的不断演进就是为了实现不同数据规模的情况,SortMergeJoinExec的话我们同样适用,也许是名字会太长的关系,如果我们补充完全,按照数据生成方式,我们可以命名为ShuffledSortMergeJoinExec,这就是说需要通过Shuffle的方式生成,同时在合并的时候走的是SortMerge,我们可以从类的继承关系上看出来。

case class SortMergeJoinExec(
    leftKeys: Seq[Expression],
    rightKeys: Seq[Expression],
    joinType: JoinType,
    condition: Option[Expression],
    left: SparkPlan,
    right: SparkPlan,
    isSkewJoin: Boolean = false) extends ShuffledJoin {

这个是因为,不管是前面的Hash方式做Shuffle,我们都用到了一个classHashMap的类,这个操作时间复杂度O(1),快得很,但是架不住数据多呀,数据量比较庞大的时候classHashMap其实是转不下的,所以我还是那句话,这种时候其实是跑不出来,因为内存会打爆,并不是真的就慢的原因,因为按照慢来说至少等一等可以,但是如果超过了本身这种处理方式的能力的话,整个计算进行不下去,我才迫使我们不得不找到新的计算方式,可能性能上确实没那么快,但是至少跑得出来。

val  classHashMap= partitions.get(partitionId).get 受内存局限
     val clazz= classHashMap.get(student._3)

SortMerge这个是大家的印象应该是既熟悉又陌生,因为大量的作业都是用这种方式的,所以老是可以看到,但是陌生是因为也不知道里头什么个机制,我们今天来倒腾倒腾。前面我也说过Shuffle是给我们把数据按照相同的HashPartition算法分发到相同的分区中去,这个和前面操作是完全一样的,SortMerge就是解决怎么把两边的数据JOIN在一起的问题,所以我们重点关注一下SortMerge的操作。还是要关注源码中的执行操作,我们需要注意到,这一次的输入内容和前面是有差别的leftIter, rightIter是两个迭代器的输入操作了,不再是Hash,所以本身这种操作上算法是解决两个list的合并操作

protected override def doExecute(): RDD[InternalRow] = {
    val numOutputRows = longMetric("numOutputRows")
    val spillThreshold = getSpillThreshold
    val inMemoryThreshold = getInMemoryThreshold
    left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
      val boundCondition: (InternalRow) => Boolean = {
        condition.map { cond =>
          Predicate.create(cond, left.output ++ right.output).eval _
        }.getOrElse {
          (r: InternalRow) => true
        }
      }
      ...
      }

doExecute()下面的代码其实是比较长的,这里刚好一起学习学习这种长代码的研究思路,

首先,我们要知道doExecute(): RDD[InternalRow] 一定需要返回一个RDD,这个是我们需要的结果,那么在接下来的代码就一定有地方在返回,代码比较长的时候我们先折叠一下长的部分:

折叠之后我们就可以看清楚执行的主干了,整个的逻辑是对应不同的JOIN类型做了一个

new RowIterator()的操作,而RowIterator.toScala就可以返回我们RDD[InternalRow] 了,顺着这个思路,我们其实可以梳理出不同JOIN对应的而RowIterator类型

JOIN类型 RowIterator实现
InnerLike new RowIterator 重写了advanceNext和 getRow
LeftOuter LeftOuterIterator
RightOuter RightOuterIterator
FullOuter FullOuterIterator
LeftSemi new RowIterator 重写了advanceNext和 getRow
LeftAnti new RowIterator 重写了advanceNext和 getRow
ExistenceJoin new RowIterator 重写了advanceNext和 getRow

ok到了这一步,我们对整个的返回就很清楚了,所以整个实现的触发入口其实是

RowIterator.toScala我们进一步查看RowIterator的实现

abstract class RowIterator {
  def advanceNext(): Boolean
  def getRow: InternalRow
  def toScala: Iterator[InternalRow] = new RowIteratorToScala(this)
}

RowIterator是一个抽象类,toScala里面其实是new RowIteratorToScala(this)做了这件事情,具体RowIteratorToScala的实现如下:

private final class RowIteratorToScala(val rowIter: RowIterator) extends Iterator[InternalRow] {
 private [this] var hasNextWasCalled: Boolean = false
 private [this] var _hasNext: Boolean = false
 override def hasNext: Boolean = {
   // Idempotency:
   if (!hasNextWasCalled) {
     _hasNext = rowIter.advanceNext()
     hasNextWasCalled = true
   }
   _hasNext
 }
 override def next(): InternalRow = {
   if (!hasNext) throw QueryExecutionErrors.noSuchElementExceptionError()
   hasNextWasCalled = false
   rowIter.getRow
 }
}

看到这里,我们其实很清晰了,整个过程其实就是在实现一个Iterator操作,我们都知道,RDD本身就是一个Iterator在做RDD的迭代时候,核心方法就是触发hasNext和 next方法,这两个方法也是Iterator的入口,我们很清楚看到在hasNext调用了advanceNext操作,而在

next中我们从传入的rowIter: RowIterator调用了getRow操作,这两个方法就是在前面做了重写的地方,整个实现结构我们就串联起来了。接下来我们只需要关注advanceNext()的实现:

while (rightMatchesIterator != null) {
                if (!rightMatchesIterator.hasNext) {
                  if (smjScanner.findNextInnerJoinRows()) {
                    currentRightMatches = smjScanner.getBufferedMatches
                    currentLeftRow = smjScanner.getStreamedRow
                    rightMatchesIterator = currentRightMatches.generateIterator()
                  } else {
                    currentRightMatches = null
                    currentLeftRow = null
                    rightMatchesIterator = null
                    return false
                  }
                }
                joinRow(currentLeftRow, rightMatchesIterator.next())
                if (boundCondition(joinRow)) {
                  numOutputRows += 1
                  return true
                }
              }
              false
            }

看到这里,我们可以清楚看到整个merge逻辑了,rightMatchesIterator和smjScanner就是我们做join之后左右两边的Iterator, 而joinRow(currentLeftRow, rightMatchesIterator.next())就是把当前获取到的数据行拼接起来,这里是因为前提是我们的数据行在两边都是已经做了排序的,所以只需要把迭代器往前面移动即可,并不需要做Hash的时候的查找操作。还是按照我们的优良传统,我来一波小代码实现一把:

def merge(students: Array[(Integer,String,String)],classes: Array[(String, String)]):Unit= {
    val stuIterator=students.iterator
    val claIterator=classes.iterator
    var curRow=claIterator.next()//当前行
    var stuRow=stuIterator.next()  //获取当前行的数据
    while (stuRow !=null ) {
      var needNext=true
       if(stuRow._3==curRow._1){
         //学生的班号和班级编号相等,就join起来
         println("学生:"+stuRow,"班级:"+curRow)
       }else{
         //匹配不上的情况,班号往后移动
         curRow=claIterator.next()
         needNext=false
       }
      if(needNext ){
        if(stuIterator.hasNext){
          stuRow=stuIterator.next()
        }else{
          stuRow=null
        }
      }
    }
  }

merge实现的是对已经排序之后的集合进行排序,所以我们在输入的时候要保证一下顺序

val students = Array[(Integer,String,String)]((1, "小黑", "001"),(2, "小明", "002"), (3, "小丽", "003"),(4, "小红", "004"))
    val classes: Array[(String, String)] = Array(("001", "一班"), ("002", "二班"), ("003", "三班"), ("004", "四班"))
    merge(students,classes)

结果如下:

(学生:(1,小黑,001),班级:(001,一班))
(学生:(2,小明,002),班级:(002,二班))
(学生:(3,小丽,003),班级:(003,三班))
(学生:(4,小红,004),班级:(004,四班))

我们图形化展示一下

总结

小代码虽然有点糙,但是那个是精华^^

目录
相关文章
|
2月前
|
存储 Java 开发者
Stream原理与执行流程探析
本文简单讲述了Stream原理,并以一段比较简单常见的stream操作代码为例进行讲解。
|
6月前
|
存储 分布式计算 算法
【底层服务/编程功底系列】「大数据算法体系」带你深入分析MapReduce算法 — Shuffle的执行过程
【底层服务/编程功底系列】「大数据算法体系」带你深入分析MapReduce算法 — Shuffle的执行过程
95 0
|
3月前
|
存储 Java
【Java集合类面试二十九】、说一说HashSet的底层结构
HashSet的底层结构是基于HashMap实现的,使用一个初始容量为16和负载因子为0.75的HashMap,其中HashSet元素作为HashMap的key,而value是一个静态的PRESENT对象。
|
新零售 大数据 云计算
二二复制公排开发功能丨二二复制公排系统开发(开发原理)丨二二复制公排源码详细
 新零售的另一个新层次是互联网+技术(大数据、云计算、移动支付等)它可以连接线上和线下,实现全面覆盖,并通过技术提高零售能力。使企业能够更清晰地获得消费者的形象,同时刺激消费者的消费,创造更好的消费者体验。
|
存储 安全 Java
java集合类史上最细讲解 - List篇
从上面的集合框架图可以看到,Java 集合框架主要包括两种类型的容器,一种是集合(Collection),存储一个元素集合,另一种是图(Map),存储键/值对映射。Collection 接口又有 3 种子类型,List、Set 和 Queue,再下面是一些抽象类,最后是具体实现类,常用的有 ArrayList、LinkedList、HashSet、LinkedHashSet、HashMap、LinkedHashMap 等等。
117 0
java集合类史上最细讲解 - List篇
|
安全 Java API
【Java技术指南】「技术盲点」也许你不了解的Map.merge的用法指南
【Java技术指南】「技术盲点」也许你不了解的Map.merge的用法指南
177 0
|
SQL 缓存 关系型数据库
Join原理(2)--连接原理(四十)
Join原理(2)--连接原理(四十)
|
算法 Java Android开发
抽丝剥茧聊Kotlin协程之Job是如何建立结构化并发的双向传播机制关系的
抽丝剥茧聊Kotlin协程之Job是如何建立结构化并发的双向传播机制关系的
抽丝剥茧聊Kotlin协程之Job是如何建立结构化并发的双向传播机制关系的
|
人工智能 算法 Java
为自己搭建一个分布式 IM 系统二【从查找算法聊起】(上)
把一些影响较大的 bug 以及需求比较迫切的 feature 调整了,本次更新的 v1.0.1 版本: 客户端超时自动下线。 新增 AI 模式。 聊天记录查询。 在线用户前缀模糊匹配。 下面谈下几个比较重点的功能。 客户端超时自动下线 这个功能涉及到客户端和服务端的心跳设计,比较有意思,也踩了几个坑;所以准备留到下次单独来聊。