背景
本文基于spark 3.2.0
由于codegen涉及到的知识点比较多,我们先来说清楚code"""""",我们暂且叫做code代码块
scala 字符串插值
要想搞清楚spark的code代码块,就得现搞清楚scala 字符串插值。
scala 字符串插值是2.10.0版本引用进来的新语法规则,可以直接允许使用者将变量引用直接插入到字符串中,如下:
val name = 'LI' println(s"My name is $name") 输出: My name is LI
这种资料很多,大家自行查阅资料理解。
code代码块
因为这块代码比较复杂,直接拿出例子来运行:
直接找到spark CastSuite.scala 第215行如下:
test("cast string to boolean II") { checkEvaluation(cast("abc", BooleanType), null)
之后在javaCode.scala 输出对应的想要debug的值,如下:
*/ def code(args: Any*): Block = { sc.checkLengths(args) if (sc.parts.length == 0) { EmptyBlock } else { args.foreach { case _: ExprValue | _: Inline | _: Block => case _: Boolean | _: Byte | _: Int | _: Long | _: Float | _: Double | _: String => case other => throw QueryExecutionErrors.cannotInterpolateClassIntoCodeBlockError(other) } val (codeParts, blockInputs) = foldLiteralArgs(sc.parts, args) // scalasytle:off println(s"code: $codeParts") println(s"blockInputs: $blockInputs") // scalasytle:on CodeBlock(codeParts, blockInputs) } }
这样,运行后我们会发现,如下结果:
code: ArrayBuffer( if (org.apache.spark.sql.catalyst.util.StringUtils.isTrueString(, )) { , = true; } else if (org.apache.spark.sql.catalyst.util.StringUtils.isFalseString(, )) { , = false; } else { isNull_0 = true; } ) blockInputs: ArrayBuffer(((UTF8String) references[0] /* literal */), value_0, ((UTF8String) references[0] /* literal */), value_0) result: if (org.apache.spark.sql.catalyst.util.StringUtils.isTrueString(((UTF8String) references[0] /* literal */))) { value_0 = true; } else if (org.apache.spark.sql.catalyst.util.StringUtils.isFalseString(((UTF8String) references[0] /* literal */))) { value_0 = false; } else { isNull_0 = true; } ...
而这段代码刚好和Cast.scala中的 castToBooleanCode方法是一一对应的的:
private[this] def castToBooleanCode(from: DataType): CastFunction = from match { case StringType => val stringUtils = inline"${StringUtils.getClass.getName.stripSuffix("$")}" (c, evPrim, evNull) => val castFailureCode = if (ansiEnabled) { s"throw QueryExecutionErrors.invalidInputSyntaxForBooleanError($c);" } else { s"$evNull = true;" } val result = code""" if ($stringUtils.isTrueString($c)) { $evPrim = true; } else if ($stringUtils.isFalseString($c)) { $evPrim = false; } else { $castFailureCode } """ // scalastyle:off println(s"result: $result") // scalastyle:on result
也就是说spark自定义的ExprValue类型的值被替换了(其实是Inline/Block/ExprValue这三种类型的值都会被替换,只不过这里没有体现),如下:
而输出的result结果就是拼接完后的完整字符串。
我们这里是为了debug,才会把结果和对应的片段打印出来,
而在spark真正处理的时候,返回的是ExprCode类型的值,在真正需要代码生成的时候,才会调用的toString的方法生成对应的字符串
code代码块之间的连接
但是我们在Cast.scala的方法中我们看到的doGenCode是先调用child.genCode的方法的:
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = child.genCode(ctx) val nullSafeCast = nullSafeCastFunction(child.dataType, dataType, ctx) ev.copy(code = eval.code + castCode(ctx, eval.value, eval.isNull, ev.value, ev.isNull, dataType, nullSafeCast)) }
那子节点的ExprCode怎么和父节点的ExprCode连接起来的呢?
其实这个和写代码的思路是一样的,每个子节点返回的ExprCode类型的值,都会对应为该方法体的的实现代码,返回值(包括了类型),spark额外增加了一个是否为null,如下:
case class ExprCode(var code: Block, var isNull: ExprValue, var value: ExprValue)
其中code是对应的方法体的实现代码,
isNull 是对应的是否为null,
value 代表的返回值
至于为什么会额外增加一个是否为null,还是和写代码的逻辑是一样的,因为只有不为空的情况下,代码才会正常的往下运行:
protected[this] def castCode(ctx: CodegenContext, input: ExprValue, inputIsNull: ExprValue, result: ExprValue, resultIsNull: ExprValue, resultType: DataType, cast: CastFunction): Block = { val javaType = JavaCode.javaType(resultType) code""" boolean $resultIsNull = $inputIsNull; $javaType $result = ${CodeGenerator.defaultValue(resultType)}; if (!$inputIsNull) { ${cast(input, result, resultIsNull)} } """ }
这里的!$inputIsNull判断,只有不为空了才进行下一步的转换操作,要不然会抛出异常。
这样把子节点的结果作为父节点的入参传入给对应的方法,这样生成的代码完全符合编码的逻辑,这样这部分也就说完了,当然这部分也是代码生成的重中之重,理解了这部分,代码生成这块就差不多了,其他的就是各个部分的实现,用心去看即可。