关联形式(Join Types)都有哪些
我个人习惯还是从源码里面定义入手,一方面如果有调整,大家知道怎么去查,另一方面来说,没有什么比起源码的更加官方的定义了。SparkSQL中的关于JOIN的定义位于
org.apache.spark.sql.catalyst.plans.JoinType,按照包的划分,JOIN其实是执行计划的一部分。
具体的定义可以在JoinType的伴生对象中apply方法有构造。
object JoinType { def apply(typ: String): JoinType = typ.toLowerCase(Locale.ROOT).replace("_", "") match { case "inner" => Inner case "outer" | "full" | "fullouter" => FullOuter case "leftouter" | "left" => LeftOuter case "rightouter" | "right" => RightOuter case "leftsemi" | "semi" => LeftSemi case "leftanti" | "anti" => LeftAnti case "cross" => Cross case _ => val supported = Seq( "inner", "outer", "full", "fullouter", "full_outer", "leftouter", "left", "left_outer", "rightouter", "right", "right_outer", "leftsemi", "left_semi", "semi", "leftanti", "left_anti", "anti", "cross") throw new IllegalArgumentException(s"Unsupported join type '$typ'. " + "Supported join types include: " + supported.mkString("'", "', '", "'") + ".") } }
当然,我们其实可以很清楚看出来为什么我们平时说left outer join 和left join 其实是一样的。
关联形式 | 关键字 |
内关联 | inner |
外关联 | outer、full、fullouter、full_outer |
左关联 | leftouter、left、left_outer |
右关联 | rightouter、right、right_outer |
左半关联 | leftsemi、left_semi、semi |
左逆关联 | leftanti、left_anti、anti |
交叉连接(笛卡尔积) | cross |
前面几个大家比较熟悉,我说一下后面三个,后面三个其实用得挺多的,只是我们经常是
用一些其他语法表达而已,因为后面我们演示数据,提前建立表
import spark.implicits._ import org.apache.spark.sql.DataFrame // 学生表 val seq = Seq((1, "小明", 28, "男","二班"), (2, "小丽", 22, "女","四班"), (3, "阿虎", 24, "男","三班"), (5, "张强", 18, "男","四班")) val students: DataFrame = seq.toDF("id", "name", "age", "gender","class") students.show(10,false); students.createTempView("students") // 班级表 val seq2 = Seq(("三班",3),("四班",4),("三班",1)) val classes:DataFrame = seq2.toDF("class_name", "id") classes.show(10,false) classes.createTempView("classes")
+---+----+---+------+-----+ |id |name|age|gender|class| +---+----+---+------+-----+ |1 |小明|28 |男 |二班 | |2 |小丽|22 |女 |四班 | |3 |阿虎|24 |男 |三班 | |5 |张强|18 |男 |四班 | +---+----+---+------+-----+ +---+----+---+------+-----+ | id|name|age|gender|class| +---+----+---+------+-----+ | 2|小丽| 22| 女| 四班| | 3|阿虎| 24| 男| 三班| | 5|张强| 18| 男| 四班| +---+----+---+------+-----+
left_semi join
这两个刚好是一对,一个是要剔除,一个是要保留
我们经常有这种需求,我们要查询班级表中存在的学生姓名,我们会这样写SQL
select id,name,age,gender,class from students where class in (select class_name from classes group by class_name)
是一种存在需求,转换成我们的left_semi就是
select id,name,age,gender,class from students left semi join classes on students.class=classes.class_name
当然,也不会忘记了我们的代码形式
val leftsemiDF: DataFrame = students.join(classes, students("class") === classes("class_name"), "leftsemi")
三次运算结果都是一致的,因为班级信息里面没有二班,所以只有三班和4班的信息,这一类需求是我们实现exists 和 一些where 条件中in的时候大量使用
+---+----+---+------+-----+ | id|name|age|gender|class| +---+----+---+------+-----+ | 2|小丽| 22| 女| 四班| | 3|阿虎| 24| 男| 三班| | 5|张强| 18| 男| 四班| +---+----+---+------+-----+
left_anti join
left_anti其实是和semi相反的工作前面是实现了保留,这个则是去掉,我们其实简单换成left_anti 可以看到类似的效果,我们希望看到不在班级表里面的学生,我们会这样写sql
select id,name,age,gender,class from students where class not in (select class_name from classes group by class_name)
转换成我们的anti join则是
select id,name,age,gender,class from students left anti join classes on students.class=classes.class_name
当然,也有代码版本
val leftantiDF: DataFrame = students.join(classes, students("class") === classes("class_name"), "leftanti")
结果如下,只有2班的小明了:
+---+----+---+------+-----+ | id|name|age|gender|class| +---+----+---+------+-----+ | 1|小明| 28| 男| 二班| +---+----+---+------+-----+
cross join
这个其实是笛卡尔积,我们经常忘记写join条件的时候,这种就是笛卡尔积了,
因为这个操作是很容易把程序搞崩的,所以要加上配置
spark.sql.crossJoin.enabled=true
spark.conf.set("spark.sql.crossJoin.enabled", "true") students.join(classes).show(10,false)
结果如下:
+---+----+---+------+-----+----------+---+ |id |name|age|gender|class|class_name|id | +---+----+---+------+-----+----------+---+ |1 |小明|28 |男 |二班 |三班 |3 | |1 |小明|28 |男 |二班 |四班 |4 | |1 |小明|28 |男 |二班 |三班 |1 | |2 |小丽|22 |女 |四班 |三班 |3 | |2 |小丽|22 |女 |四班 |四班 |4 | |2 |小丽|22 |女 |四班 |三班 |1 | |3 |阿虎|24 |男 |三班 |三班 |3 | |3 |阿虎|24 |男 |三班 |四班 |4 | |3 |阿虎|24 |男 |三班 |三班 |1 | |5 |张强|18 |男 |四班 |三班 |3 | +---+----+---+------+-----+----------+---+
JOIN的实现机制
Spark中的JOIN对应BaseJoinExec的五个子类,他们分别是 BroadcastHashJoinExec、BroadcastNestedLoopJoinExec、ShuffledHashJoinExec、SortMergeJoinExec、CartesianProductExec,源码关系如下:
可能大家平时没有太关注这里头有啥联系,但是当我们把这些摆在一起的时候我们其实很明显发现Broadcast和Hash就出现了个,我们可以大胆猜测,这里头有必然的联系,其实Broadcast和Shuffled其实是数据分发的形式,SortMergeJoinExec其实也是通过Shuffle的分发,只是类取名字的时候没有写成ShuffledSortMergeJoinExec 看着是有点长吧,还一点就是走了归并排序其实就是不会走广播了。我们按照分发方式可以整理出一个小表格:
分发方式 | 关键字 |
无分发 | CartesianProductExec |
Broadcast | BroadcastNestedLoopJoinExec |
Broadcast | BroadcastHashJoinExec |
Shuffled | ShuffledHashJoinExec |
Shuffled | SortMergeJoinExec |
CartesianProductExec 无分发
为了搞明白计算原理,我们通过源码来研究研究。首其实可以想得到,笛卡尔积投影是直接把数据的所有行都按照交叉膨胀,这个事情直接在Map端组合完成分发就好了,代码里面其实也是这样子的,我们一起看看,我把CartesianProductExec计算部分拿出来。
override def compute(split: Partition, context: TaskContext): Iterator[(UnsafeRow, UnsafeRow)] = { val rowArray = new ExternalAppendOnlyUnsafeRowArray(inMemoryBufferThreshold, spillThreshold) val partition = split.asInstanceOf[CartesianPartition] rdd2.iterator(partition.s2, context).foreach(rowArray.add) // Create an iterator from rowArray def createIter(): Iterator[UnsafeRow] = rowArray.generateIterator() val resultIter = for (x <- rdd1.iterator(partition.s1, context); y <- createIter()) yield (x, y) CompletionIterator[(UnsafeRow, UnsafeRow), Iterator[(UnsafeRow, UnsafeRow)]]( resultIter, rowArray.clear()) } }
我们要抠关键部分,其实核心点就是,x,y各自其实就是作为两边的元素,我们理解一下我们操作集合的时候yide的操作。
val resultIter = for (x <- rdd1.iterator(partition.s1, context); y <- createIter()) yield (x, y)
这个其实就是scala中实现二重循环的操作,spark上面的逻辑可以类比下面的逻辑,我把学生join笛卡尔积的操作还原如下:
val students=Array((1,"小明"),(2,"小丽")) val classes= Array("三班","二班") val result= for(student <-students; clazz <- classes) yield (student,clazz) result.foreach(print)
结果如下:
((1,小明),三班)((1,小明),二班)((2,小丽),三班)((2,小丽),二班)
发现没有,其实就是我们需要的结果了,注意哦,Spark源码就是这样的,是不是信心倍增。
BroadcastNestedLoopJoinExec
这个其实就是字面上的含义,广播+嵌套循环实现join,我们一直在说,广播其实是一种分发方式,在我们之前的文章也有说到,其实广播来说,我们在rdd执行的时候,就直接可以当成拿到本地变量而已,我还是把核心代码拿出来:
protected override def doExecute(): RDD[InternalRow] = { val broadcastedRelation = broadcast.executeBroadcast[Array[InternalRow]]() val resultRdd = (joinType, buildSide) match { case (_: InnerLike, _) => innerJoin(broadcastedRelation) case (LeftOuter, BuildRight) | (RightOuter, BuildLeft) => outerJoin(broadcastedRelation) case (LeftSemi, _) => leftExistenceJoin(broadcastedRelation, exists = true) case (LeftAnti, _) => leftExistenceJoin(broadcastedRelation, exists = false) case (_: ExistenceJoin, _) => existenceJoin(broadcastedRelation) case _ => /** * LeftOuter with BuildLeft * RightOuter with BuildRight * FullOuter */ defaultJoin(broadcastedRelation) } val numOutputRows = longMetric("numOutputRows") resultRdd.mapPartitionsWithIndexInternal { (index, iter) => val resultProj = genResultProjection resultProj.initialize(index) iter.map { r => numOutputRows += 1 resultProj(r) } }
最前面 broadcastedRelation,其实就是从广播的变量中获取到数据,这个就是广播的操作了,剩下的就是NestedLoop的事情了,我们注意到代中的(joinType, buildSide) match条件,是按照操作类型不同去实现,我们一起看看innser join 的操作:
private def innerJoin(relation: Broadcast[Array[InternalRow]]): RDD[InternalRow] = { streamed.execute().mapPartitionsInternal { streamedIter => val buildRows = relation.value val joinedRow = new JoinedRow streamedIter.flatMap { streamedRow => val joinedRows = buildRows.iterator.map(r => joinedRow(streamedRow, r)) if (condition.isDefined) { joinedRows.filter(boundCondition) } else { joinedRows } } } }
这个操作的意思是把集合打平,关键点就是streamedIter.flatMap()中的操作,这部分就是嵌套循环,这个也就是为什么叫做NestedLoop的原因,其实就是嵌套循环的意思,我把这个操作按照我们数据集等价实现一遍:
val students=Array((1,"小明"),(2,"小丽")) val classes= Array("三班","二班") students.flatMap(student=>{ classes.map(clazz=>{ print(student._1,student._2,clazz) }) })
结果如下:
(1,小明,三班)(1,小明,二班)(2,小丽,三班)(2,小丽,二班)
我们很明显看得出时间复杂度是M*N,看到这里。
BroadcastHashJoinExec
有了前面的基础,对于Broadcast类型的操作,我们可以进一步归纳,三部曲
1、从广播变量中取值
2、完成关联操作
3、输出结果
有了这些操作,我们可以预判代码了
protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") val broadcastRelation = buildPlan.executeBroadcast[HashedRelation]() if (isNullAwareAntiJoin) { //Anti join是尝试解决子查询反嵌套(Subquery Unnesting)中和NULl值相关的各种问题,这里不展开,也是 //为了代码少一些 } else { streamedPlan.execute().mapPartitions { streamedIter => val hashed = broadcastRelation.value.asReadOnlyCopy() TaskContext.get().taskMetrics().incPeakExecutionMemory(hashed.estimatedSize) join(streamedIter, hashed, numOutputRows) } } }
广播方式就是从广播变量中获取我们要的结果
val broadcastRelation = buildPlan.executeBroadcastHashedRelation
我们需要关注else 里面的内容,val hashed = broadcastRelation.value.asReadOnlyCopy()其实就是拿到了我们hash的清单,处理之后会在后面的join操作中进一步处理,我们需要深入JOIN内部
protected def join( streamedIter: Iterator[InternalRow], hashed: HashedRelation, numOutputRows: SQLMetric): Iterator[InternalRow] = { val joinedIter = joinType match { case _: InnerLike => innerJoin(streamedIter, hashed) case LeftOuter | RightOuter => outerJoin(streamedIter, hashed) case LeftSemi => semiJoin(streamedIter, hashed) case LeftAnti => antiJoin(streamedIter, hashed) case _: ExistenceJoin => existenceJoin(streamedIter, hashed) case x => throw new IllegalArgumentException( s"HashJoin should not take $x as the JoinType") } val resultProj = createResultProjection joinedIter.map { r => numOutputRows += 1 resultProj(r) } }
JOIN内部的主流程是按照JOIN类型各自处理,返回JOIN之后的结果,我们还是以Inner为例,分析一下实现的过程
private def innerJoin( streamIter: Iterator[InternalRow], hashedRelation: HashedRelation): Iterator[InternalRow] = { val joinRow = new JoinedRow val joinKeys = streamSideKeyGenerator() if (hashedRelation == EmptyHashedRelation) { Iterator.empty } else if (hashedRelation.keyIsUnique) { streamIter.flatMap { srow => joinRow.withLeft(srow) val matched = hashedRelation.getValue(joinKeys(srow)) if (matched != null) { Some(joinRow.withRight(matched)).filter(boundCondition) } else { None } } } else { streamIter.flatMap { srow => joinRow.withLeft(srow) val matches = hashedRelation.get(joinKeys(srow)) if (matches != null) { matches.map(joinRow.withRight).filter(boundCondition) } else { Seq.empty } } } }
我们关注到核心操作,从hash表中取到我们想要的元素,实现join,同样的,用我们的数据集实现一下
import scala.collection.mutable.HashMap val students = Array((1, "小明", "002"), (2, "小丽", "003")) val hashClass = HashMap[String, String]("003" -> "三班", "002" -> "二班") students.map(student => (student, hashClass.get(student._3))).foreach(print)
因为本身的逻辑就行整个遍历一遍重新组合结果,时间复杂度是O(N)
((1,小明,002),Some(二班))((2,小丽,003),Some(三班))
阶段性总结
前面介绍的三种JOIN方式其实已经可以完全实现所有的JOIN操作了,但是这些操作有一个特点,作为主表我们可以分成不同的Partition上面执行,但是从表我们其实是清一色作为Executor本地方式执行的,因为我们的Task是分布在很多集群上运行的,所有我们为了让所有的节点都有这份数据,所有是往所有节点都分发一次,这也是为啥叫做广播的原因。
这些方式在数据量不大的时候是很高效的,这个数据量的规模可以是10万级到百万级不等,也就是说其实可以控制的,源码中的定义如下,我们可以看到其实后面给我们有一个默认值10MB
val AUTO_BROADCASTJOIN_THRESHOLD = buildConf("spark.sql.autoBroadcastJoinThreshold") .doc("...解释信息") .version("1.1.0") .bytesConf(ByteUnit.BYTE) .createWithDefaultString("10MB")
有这么一句描述 By setting this value to -1 broadcasting can be disabled,就是给-1可以关闭这个数值,这种10M的概念一般就是几万以内,和你本身字段数量还有数据内容有关系,如果不满足可以调整。
我们做如下设置,就可以不打开广播了
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")
还一个情况是大家网络上看到说广播方式数据量小的时候其实适合,大的时候慢。其实这里在生产中的实际情况是,数据量比较大的时候他是跑不出来,不是慢的问题,因为一个任务有时候执行下来本身就是要大半个小时,但是跑一个小时之后来个失败,这个才是忍受不了的,生产环境下游又有依赖,所以这个时候我们不再是追求快了,而是都在求神拜佛就行只要跑出来就可以,哪怕是再等等,但是别失败。这才是实际的情况。基于此,我们才会引入ShuffledHashJoinExec和SortMergeJoinExec的计算方式,大家注意,这两种方式是为了实现更大规模数据量的JOIN而产生的,就整体时间效率上远不如前面的方式,但是最大的好处是可以出结果呀,加上企业实际生产过程中数据量其实都是很庞大的,所以这两种方式才是在生产上大量存在的操作方式。
ShuffledHashJoinExec
ShuffledHashJoinExec方式其实在join操作中获取从表信息还是从Hash中获取,这里的差别在于,我们本来需要做广播的表太大了,所以我们需要把广播的表通Shuffle的方式把一个大表分解成小表生成hash,是怎么个原理呢。首先我们看实现源码:
protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, buildIter) => val hashed = buildHashedRelation(buildIter) joinType match { case FullOuter => fullOuterJoin(streamIter, hashed, numOutputRows) case _ => join(streamIter, hashed, numOutputRows) } } }
ShuffledHashJoinExec的后面部分其实是复用了前面的BroadcastHashJoinExec的操作,实现来说还是需要获取到一个HashedRelation,这个也是我们的构建HashTable的部分,后面就是会在本地执行join操作了,主要差别是来自获取HashedRelation的时候,前者是从广播变量中获取,这里不再是了:
def buildHashedRelation(iter: Iterator[InternalRow]): HashedRelation = { val buildDataSize = longMetric("buildDataSize") val buildTime = longMetric("buildTime") val start = System.nanoTime() val context = TaskContext.get() val relation = HashedRelation( iter, buildBoundKeys, taskMemoryManager = context.taskMemoryManager(), // Full outer join needs support for NULL key in HashedRelation. allowsNullKey = joinType == FullOuter) buildTime += NANOSECONDS.toMillis(System.nanoTime() - start) buildDataSize += relation.estimatedSize // This relation is usually used until the end of task. context.addTaskCompletionListener[Unit](_ => relation.close()) relation }
核心部分其实是来自,这部分其实就是一个Shuffle的迭代结果,也就是说我们获取这部分Hash的时候是按照buildBoundKeys作为范围内的获取,不再是整个表的范围了
val relation = HashedRelation( iter, buildBoundKeys, taskMemoryManager = context.taskMemoryManager(), // Full outer join needs support for NULL key in HashedRelation. allowsNullKey = joinType == FullOuter)
当然,从理解上还是简单的,因为Spark在实现设计上非常巧妙,我们还是可以当成一个Map来看待,只不过这个时候的数据范围是局部地区。
我们同样做点实战例子来理解计算过程,因为比起前面还是多一些步骤,我们画图理解一下:
学生信息作为主表目标就是找到和班级编号匹配的主表,在此之前,班级编号做一次partition操作,会把数据分布在不同分区里面去,与此同时,班级信息也进行partition,这样相同班级编号的信息会落到同一个partition中去,针对单独的partition就是实现和之前的HashJoin一样的操作了。要注意的是partition就是Shuffle实现的,我们都知道shuffle的时候是需要定义一个hashpartition的操作,所以这个操作其实是有两次的hash,第一次把数据进行分区,第二次,实现关联操作。我们把整个过程用代码模拟出来,我这里分区操作不会那么复杂,仅仅按照%2的方式分发:
//对班级编号进行分区 def partition(key:String):Integer={ val keyNum= Integer.valueOf(key) keyNum%2 }
接下来是模拟Shuffle的操作,因为我们是需要把班级信息通过shuffle的方式分发,我们实现了把多个HashMap分发到了partitions内部
def shuffleHash(classes:Array[(String, String)]):mutable.HashMap[Integer,mutable.HashMap[String,(String, String)]]={ val partitions=new mutable.HashMap[Integer,mutable.HashMap[String,(String, String)]](); for(clazz <- classes){ val hashKey=partition(clazz._1) val map=partitions.getOrElse(hashKey,mutable.HashMap[String,(String, String)]()) map += (clazz._1->clazz) partitions += hashKey->map } partitions }
最后,我们实现真正的JOIN操作:
val students = Array[(Integer,String,String)]((1, "小黑", "001"),(2, "小明", "002"), (3, "小丽", "003"),(4, "小红", "004")) val classes: Array[(String, String)] = Array(("001", "一班"), ("002", "一班"), ("003", "三班"), ("004", "四班")) val partitions= shuffleHash(classes) students.map(student=>{ val partitionId= partition(student._3) val classHashMap= partitions.get(partitionId).get val clazz= classHashMap.get(student._3) println("分区:"+partitionId,"学生:"+student,"班级:"+clazz.get) })
查看结果:
(分区:1,学生:(1,小黑,001),班级:(001,一班)) (分区:0,学生:(2,小明,002),班级:(002,一班)) (分区:1,学生:(3,小丽,003),班级:(003,三班)) (分区:0,学生:(4,小红,004),班级:(004,四班))
这个也能理解,为什么Shuffle操作一定要把相同的数据分发到相同的分区里面去,其实这个就是一个分治的算法,是不是感觉也没那么神秘了~~
SortMergeJoinExec
我们其实注意到,算法的不断演进就是为了实现不同数据规模的情况,SortMergeJoinExec的话我们同样适用,也许是名字会太长的关系,如果我们补充完全,按照数据生成方式,我们可以命名为ShuffledSortMergeJoinExec,这就是说需要通过Shuffle的方式生成,同时在合并的时候走的是SortMerge,我们可以从类的继承关系上看出来。
case class SortMergeJoinExec( leftKeys: Seq[Expression], rightKeys: Seq[Expression], joinType: JoinType, condition: Option[Expression], left: SparkPlan, right: SparkPlan, isSkewJoin: Boolean = false) extends ShuffledJoin {
这个是因为,不管是前面的Hash方式做Shuffle,我们都用到了一个classHashMap的类,这个操作时间复杂度O(1),快得很,但是架不住数据多呀,数据量比较庞大的时候classHashMap其实是转不下的,所以我还是那句话,这种时候其实是跑不出来,因为内存会打爆,并不是真的就慢的原因,因为按照慢来说至少等一等可以,但是如果超过了本身这种处理方式的能力的话,整个计算进行不下去,我才迫使我们不得不找到新的计算方式,可能性能上确实没那么快,但是至少跑得出来。
val classHashMap= partitions.get(partitionId).get 受内存局限 val clazz= classHashMap.get(student._3)
SortMerge这个是大家的印象应该是既熟悉又陌生,因为大量的作业都是用这种方式的,所以老是可以看到,但是陌生是因为也不知道里头什么个机制,我们今天来倒腾倒腾。前面我也说过Shuffle是给我们把数据按照相同的HashPartition算法分发到相同的分区中去,这个和前面操作是完全一样的,SortMerge就是解决怎么把两边的数据JOIN在一起的问题,所以我们重点关注一下SortMerge的操作。还是要关注源码中的执行操作,我们需要注意到,这一次的输入内容和前面是有差别的leftIter, rightIter是两个迭代器的输入操作了,不再是Hash,所以本身这种操作上算法是解决两个list的合并操作
protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") val spillThreshold = getSpillThreshold val inMemoryThreshold = getInMemoryThreshold left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => val boundCondition: (InternalRow) => Boolean = { condition.map { cond => Predicate.create(cond, left.output ++ right.output).eval _ }.getOrElse { (r: InternalRow) => true } } ... }
doExecute()下面的代码其实是比较长的,这里刚好一起学习学习这种长代码的研究思路,
首先,我们要知道doExecute(): RDD[InternalRow] 一定需要返回一个RDD,这个是我们需要的结果,那么在接下来的代码就一定有地方在返回,代码比较长的时候我们先折叠一下长的部分:
折叠之后我们就可以看清楚执行的主干了,整个的逻辑是对应不同的JOIN类型做了一个
new RowIterator()的操作,而RowIterator.toScala就可以返回我们RDD[InternalRow] 了,顺着这个思路,我们其实可以梳理出不同JOIN对应的而RowIterator类型
JOIN类型 | RowIterator实现 |
InnerLike | new RowIterator 重写了advanceNext和 getRow |
LeftOuter | LeftOuterIterator |
RightOuter | RightOuterIterator |
FullOuter | FullOuterIterator |
LeftSemi | new RowIterator 重写了advanceNext和 getRow |
LeftAnti | new RowIterator 重写了advanceNext和 getRow |
ExistenceJoin | new RowIterator 重写了advanceNext和 getRow |
ok到了这一步,我们对整个的返回就很清楚了,所以整个实现的触发入口其实是
RowIterator.toScala我们进一步查看RowIterator的实现
abstract class RowIterator { def advanceNext(): Boolean def getRow: InternalRow def toScala: Iterator[InternalRow] = new RowIteratorToScala(this) }
RowIterator是一个抽象类,toScala里面其实是new RowIteratorToScala(this)做了这件事情,具体RowIteratorToScala的实现如下:
private final class RowIteratorToScala(val rowIter: RowIterator) extends Iterator[InternalRow] { private [this] var hasNextWasCalled: Boolean = false private [this] var _hasNext: Boolean = false override def hasNext: Boolean = { // Idempotency: if (!hasNextWasCalled) { _hasNext = rowIter.advanceNext() hasNextWasCalled = true } _hasNext } override def next(): InternalRow = { if (!hasNext) throw QueryExecutionErrors.noSuchElementExceptionError() hasNextWasCalled = false rowIter.getRow } }
看到这里,我们其实很清晰了,整个过程其实就是在实现一个Iterator操作,我们都知道,RDD本身就是一个Iterator在做RDD的迭代时候,核心方法就是触发hasNext和 next方法,这两个方法也是Iterator的入口,我们很清楚看到在hasNext调用了advanceNext操作,而在
next中我们从传入的rowIter: RowIterator调用了getRow操作,这两个方法就是在前面做了重写的地方,整个实现结构我们就串联起来了。接下来我们只需要关注advanceNext()的实现:
while (rightMatchesIterator != null) { if (!rightMatchesIterator.hasNext) { if (smjScanner.findNextInnerJoinRows()) { currentRightMatches = smjScanner.getBufferedMatches currentLeftRow = smjScanner.getStreamedRow rightMatchesIterator = currentRightMatches.generateIterator() } else { currentRightMatches = null currentLeftRow = null rightMatchesIterator = null return false } } joinRow(currentLeftRow, rightMatchesIterator.next()) if (boundCondition(joinRow)) { numOutputRows += 1 return true } } false }
看到这里,我们可以清楚看到整个merge逻辑了,rightMatchesIterator和smjScanner就是我们做join之后左右两边的Iterator, 而joinRow(currentLeftRow, rightMatchesIterator.next())就是把当前获取到的数据行拼接起来,这里是因为前提是我们的数据行在两边都是已经做了排序的,所以只需要把迭代器往前面移动即可,并不需要做Hash的时候的查找操作。还是按照我们的优良传统,我来一波小代码实现一把:
def merge(students: Array[(Integer,String,String)],classes: Array[(String, String)]):Unit= { val stuIterator=students.iterator val claIterator=classes.iterator var curRow=claIterator.next()//当前行 var stuRow=stuIterator.next() //获取当前行的数据 while (stuRow !=null ) { var needNext=true if(stuRow._3==curRow._1){ //学生的班号和班级编号相等,就join起来 println("学生:"+stuRow,"班级:"+curRow) }else{ //匹配不上的情况,班号往后移动 curRow=claIterator.next() needNext=false } if(needNext ){ if(stuIterator.hasNext){ stuRow=stuIterator.next() }else{ stuRow=null } } } }
merge实现的是对已经排序之后的集合进行排序,所以我们在输入的时候要保证一下顺序
val students = Array[(Integer,String,String)]((1, "小黑", "001"),(2, "小明", "002"), (3, "小丽", "003"),(4, "小红", "004")) val classes: Array[(String, String)] = Array(("001", "一班"), ("002", "二班"), ("003", "三班"), ("004", "四班")) merge(students,classes)
结果如下:
(学生:(1,小黑,001),班级:(001,一班)) (学生:(2,小明,002),班级:(002,二班)) (学生:(3,小丽,003),班级:(003,三班)) (学生:(4,小红,004),班级:(004,四班))
我们图形化展示一下
总结
小代码虽然有点糙,但是那个是精华^^