题目说明
编写一个udf,输入这个数组之后按多列输出
题解
udtf其实是udf里面比较少自己去写的东西,所以反而是盲区,这种题目就是摸过的就觉得简单,所谓难者不会,会者不难
代码会放在最后,我说一下精髓部分!
凡是UDF编写,关键点是了解计算特征,udtf其实关键点就是输入一行,可以输出多行,不管里头怎么折腾,反正只要是Java代码写的,输入一行的话我们无非都是一个输入参数,输出多行要么是数组,要么是集合。
输入输出参数说明
输入部分,其实就是select 函数的时候给的那一排参数
比如select udf(1,2,3) 那么args(0)=1,args(1)=2,args(3)=3
题目给的是数组,数组在输入参数的时候我们当成复杂参数来对待,这里不能混淆
比如select udf([1,2,3],[4,5,6])的时候,对应的就是args(0)=[1,2,3],args(1)=[4,5,6]
题目中的是select udf([1,2,3])的形式
输出部分,正如我们说的,不是类型就是数组类型,udtf中输出的其实是一个数组的类型,有一个元素就
forward(array)会调用一次,在process里面调用多次的时候就是多个结果输出
初始化部分
初始化部分其实是给一个表头,也就是结果中的默认列,因为可以输出多行嘛,需要有列名,一个是列名,一个是列的类型:
对于结果的话:
源码
最后,附上代码源码
package org.apache.spark.udf; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import com.google.common.collect.Lists; import org.apache.hadoop.hive.ql.exec.UDFArgumentException; import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.udf.generic.GenericUDTF; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; public class MultiplyRow extends GenericUDTF { @Override public void close() throws HiveException { } @Override public StructObjectInspector initialize(ObjectInspector[] args) throws UDFArgumentException { if (args.length != 1) { throw new UDFArgumentLengthException("ExplodeMap takes only one argument"); } if (args[0].getCategory() != ObjectInspector.Category.LIST) { throw new UDFArgumentException("ExplodeMap takes array as a parameter"); } ArrayList<String> fieldNames = new ArrayList<String>(); ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(); fieldNames.add("row"); fieldOIs.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector); fieldNames.add("value"); fieldOIs.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector); return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs); } @Override public void process(Object[] args) throws HiveException { List<String> ls= (ArrayList) args[0]; for (int i = 0; i < ls.size(); i++) { try { String[] result=new String[2]; result[0]="row"+String.valueOf(i); result[1]=ls.get(i); forward(result); } catch (Exception e) { continue; } } } }
这部分是测试代码
package org.apache.spark; import org.apache.spark.sql.SparkSession; import java.io.File; public class MultiplyRowTest { public static void main(String[] args) { String warehouseLocation = new File("spark-warehouse").getAbsolutePath(); SparkSession spark = SparkSession .builder() .appName("MultiplyRowTest") .config("spark.sql.warehouse.dir", warehouseLocation) .master("local[*]") .enableHiveSupport() .getOrCreate(); // spark.sql("CREATE TEMPORARY FUNCTION myudf as 'org.apache.spark.udf.UserDefinedUDTF'"); spark.sql("create temporary function udtf as 'org.apache.spark.udf.MultiplyRow'"); spark.sql("select udtf(split('a,b,c',',')) ").show(); // spark.sql("select 'a' as c1, myudf('a,b,c') as array ").show(); } }