Spark SQL 用户自定义函数UDF、用户自定义聚合函数UDAF 教程(Java踩坑教学版)

简介:

在Spark中,也支持Hive中的自定义函数。自定义函数大致可以分为三种:

  • UDF(User-Defined-Function),即最基本的自定义函数,类似to_char,to_date等
  • UDAF(User- Defined Aggregation Funcation),用户自定义聚合函数,类似在group by之后使用的sum,avg等
  • UDTF(User-Defined Table-Generating Functions),用户自定义生成函数,有点像stream里面的flatMap

本篇就手把手教你如何编写UDF和UDAF

先来个简单的UDF

场景:
我们有这样一个文本文件:

1^^d
2^b^d
3^c^d
4^^d

在读取数据的时候,第二列的数据如果为空,需要显示'null',不为空就直接输出它的值。定义完成后,就可以直接在SparkSQL中使用了。

代码为:

package test;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

import java.util.ArrayList;
import java.util.List;

/**
 * Created by xinghailong on 2017/2/23.
 */
public class test3 {
    public static void main(String[] args) {
        //创建spark的运行环境
        SparkConf sparkConf = new SparkConf();
        sparkConf.setMaster("local[2]");
        sparkConf.setAppName("test-udf");
        JavaSparkContext sc = new JavaSparkContext(sparkConf);
        SQLContext sqlContext = new SQLContext(sc);
        //注册自定义方法
        sqlContext.udf().register("isNull", (String field,String defaultValue)->field==null?defaultValue:field, DataTypes.StringType);
        //读取文件
        JavaRDD<String> lines = sc.textFile( "C:\\test-udf.txt" );
        JavaRDD<Row> rows = lines.map(line-> RowFactory.create(line.split("\\^")));

        List<StructField> structFields = new ArrayList<StructField>();
        structFields.add(DataTypes.createStructField( "a", DataTypes.StringType, true ));
        structFields.add(DataTypes.createStructField( "b", DataTypes.StringType, true ));
        structFields.add(DataTypes.createStructField( "c", DataTypes.StringType, true ));
        StructType structType = DataTypes.createStructType( structFields );

        DataFrame test = sqlContext.createDataFrame( rows, structType);
        test.registerTempTable("test");
        
        sqlContext.sql("SELECT con_join(c,b) FROM test GROUP BY a").show();
        sc.stop();
    }
}

输出内容为:

+---+----+---+
|  a| _c1|  c|
+---+----+---+
|  1|null|  d|
|  2|   b|  d|
|  3|   c|  d|
|  4|null|  d|
+---+----+---+

其中比较关键的就是这句:

sqlContext.udf().register("isNull", (String field,String defaultValue)->field==null?defaultValue:field, DataTypes.StringType);

449064-20170223234609788-1881484163.png

这里我直接用的java8的语法写的,如果是java8之前的版本,需要使用Function2创建匿名函数。

再来个自定义的UDAF—求平均数

先来个最简单的UDAF,求平均数。类似这种的操作有很多,比如最大值,最小值,累加,拼接等等,都可以采用相同的思路来做。

首先是需要定义UDAF函数

package test;

import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

import java.util.ArrayList;
import java.util.List;

/**
 * Created by xinghailong on 2017/2/23.
 */
public class MyAvg extends UserDefinedAggregateFunction {

    @Override
    public StructType inputSchema() {
        List<StructField> structFields = new ArrayList<>();
        structFields.add(DataTypes.createStructField( "field1", DataTypes.StringType, true ));
        return DataTypes.createStructType( structFields );
    }

    @Override
    public StructType bufferSchema() {
        List<StructField> structFields = new ArrayList<>();
        structFields.add(DataTypes.createStructField( "field1", DataTypes.IntegerType, true ));
        structFields.add(DataTypes.createStructField( "field2", DataTypes.IntegerType, true ));
        return DataTypes.createStructType( structFields );
    }

    @Override
    public DataType dataType() {
        return DataTypes.IntegerType;
    }

    @Override
    public boolean deterministic() {
        return false;
    }

    @Override
    public void initialize(MutableAggregationBuffer buffer) {
        buffer.update(0,0);
        buffer.update(1,0);
    }

    @Override
    public void update(MutableAggregationBuffer buffer, Row input) {
        buffer.update(0,buffer.getInt(0)+1);
        buffer.update(1,buffer.getInt(1)+Integer.valueOf(input.getString(0)));
    }

    @Override
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        buffer1.update(0,buffer1.getInt(0)+buffer2.getInt(0));
        buffer1.update(1,buffer1.getInt(1)+buffer2.getInt(1));
    }

    @Override
    public Object evaluate(Row buffer) {
        return buffer.getInt(1)/buffer.getInt(0);
    }
}

使用的时候,需要先注册,然后在spark sql里面就可以直接使用了:

package test;

import com.tgou.standford.misdw.udf.MyAvg;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

import java.util.ArrayList;
import java.util.List;

/**
 * Created by xinghailong on 2017/2/23.
 */
public class test4 {
    public static void main(String[] args) {
        SparkConf sparkConf = new SparkConf();
        sparkConf.setMaster("local[2]");
        sparkConf.setAppName("test");
        JavaSparkContext sc = new JavaSparkContext(sparkConf);
        SQLContext sqlContext = new SQLContext(sc);

        sqlContext.udf().register("my_avg",new MyAvg());

        JavaRDD<String> lines = sc.textFile( "C:\\test4.txt" );
        JavaRDD<Row> rows = lines.map(line-> RowFactory.create(line.split("\\^")));

        List<StructField> structFields = new ArrayList<StructField>();
        structFields.add(DataTypes.createStructField( "a", DataTypes.StringType, true ));
        structFields.add(DataTypes.createStructField( "b", DataTypes.StringType, true ));
        StructType structType = DataTypes.createStructType( structFields );

        DataFrame test = sqlContext.createDataFrame( rows, structType);
        test.registerTempTable("test");

        sqlContext.sql("SELECT my_avg(b) FROM test GROUP BY a").show();

        sc.stop();
    }
}

计算的文本内容为:

a^3
a^6
b^2
b^4
b^6

449064-20170223234630038-1164529334.png

再来个无所不能的UDAF

真正的业务场景里面,总会有千奇百怪的需求,比如:

  • 想要按照某个字段分组,取其中的一个最大值
  • 想要按照某个字段分组,对分组内容的数据按照特定字段统计累加
  • 想要按照某个字段分组,针对特定的条件,拼接字符串

再比如一个场景,需要按照某个字段分组,然后分组内的数据,又需要按照某一列进行去重,最后再计算值

  • 1 按照某个字段分组
  • 2 分组校验条件
  • 3 然后处理字段

如果不用UDAF,你要是写spark可能需要这样做:

rdd.groupBy(r->r.xxx)
    .map(t2->{
        HashSet<String> set = new HashSet<>();
        for(Object p : t2._2){
            if(p.getBs() > 0 ){
                map.put(xx,yyy)
            }
        }
        return StringUtils.join(set.toArray(),",");
    });

上面是一段伪码,不保证正常运行哈。

这样写,其实也能应付需求了,但是代码显得略有点丑陋。还是不如SparkSQL看的清晰明了...

所以我们再尝试用SparkSql中的UDAF来一版!

首先需要创建UDAF类

import org.apache.commons.lang.StringUtils;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.*;

import java.util.*;

/**
 *
 * Created by xinghailong on 2017/2/23.
 */
public class ConditionJoinUDAF extends UserDefinedAggregateFunction {
    @Override
    public StructType inputSchema() {
        List<StructField> structFields = new ArrayList<>();
        structFields.add(DataTypes.createStructField( "field1", DataTypes.IntegerType, true ));
        structFields.add(DataTypes.createStructField( "field2", DataTypes.StringType, true ));
        return DataTypes.createStructType( structFields );
    }

    @Override
    public StructType bufferSchema() {
        List<StructField> structFields = new ArrayList<>();
        structFields.add(DataTypes.createStructField( "field", DataTypes.StringType, true ));
        return DataTypes.createStructType( structFields );
    }

    @Override
    public DataType dataType() {
        return DataTypes.StringType;
    }

    @Override
    public boolean deterministic() {//是否强制每次执行的结果相同
        return false;
    }

    @Override
    public void initialize(MutableAggregationBuffer buffer) {//初始化
        buffer.update(0,"");
    }

    @Override
    public void update(MutableAggregationBuffer buffer, Row input) {//相同的executor间的数据合并
        Integer bs = input.getInt(0);
        String field = buffer.getString(0);
        String in = input.getString(1);
        if(bs > 0 && !"".equals(in) && !field.contains(in)){
            field += ","+in;
        }
        buffer.update(0,field);
    }

    @Override
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {//不同excutor间的数据合并
        String field1 = buffer1.getString(0);
        String field2 = buffer2.getString(0);
        if(!"".equals(field2)){
            field1 += ","+field2;
        }
        buffer1.update(0,field1);
    }

    @Override
    public Object evaluate(Row buffer) {//根据Buffer计算结果
        return StringUtils.join(Arrays.stream(buffer.getString(0).split(",")).filter(line->!line.equals("")).toArray(),",");
    }
}

拿一个例子坐下实验:

a^1111^2
a^1111^2
a^1111^2
a^1111^2
a^1111^2
a^2222^0
a^3333^1
b^4444^0
b^5555^3
c^6666^0

按照第一列进行分组,不同的第三列值,进行拼接。

package test;

import test.ConditionJoinUDAF;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

import java.util.ArrayList;
import java.util.List;

/**
 * Created by xinghailong on 2017/2/23.
 */
public class test2 {
    public static void main(String[] args) {
        SparkConf sparkConf = new SparkConf();
        sparkConf.setMaster("local[2]");
        sparkConf.setAppName("test");
        JavaSparkContext sc = new JavaSparkContext(sparkConf);
        SQLContext sqlContext = new SQLContext(sc);

        sqlContext.udf().register("con_join",new ConditionJoinUDAF());

        JavaRDD<String> lines = sc.textFile( "C:\\test2.txt" );
        JavaRDD<Row> rows = lines.map(line-> RowFactory.create(line.split("\\^")));

        List<StructField> structFields = new ArrayList<StructField>();
        structFields.add(DataTypes.createStructField( "a", DataTypes.StringType, true ));
        structFields.add(DataTypes.createStructField( "b", DataTypes.StringType, true ));
        structFields.add(DataTypes.createStructField( "c", DataTypes.StringType, true ));
        StructType structType = DataTypes.createStructType( structFields );

        DataFrame test = sqlContext.createDataFrame( rows, structType);
        test.registerTempTable("test");

        sqlContext.sql("SELECT con_join(c,b) FROM test GROUP BY a").show();

        sc.stop();
    }

}

这样SQL简洁明了,就能表达意思了。

449064-20170223234649132-2040239169.png

参考

本文转自博客园xingoo的博客,原文链接:Spark SQL 用户自定义函数UDF、用户自定义聚合函数UDAF 教程(Java踩坑教学版),如需转载请自行联系原博主。
相关文章
|
4天前
|
Web App开发 JavaScript 前端开发
《手把手教你》系列技巧篇(三十九)-java+ selenium自动化测试-JavaScript的调用执行-上篇(详解教程)
【5月更文挑战第3天】本文介绍了如何在Web自动化测试中使用JavaScript执行器(JavascriptExecutor)来完成Selenium API无法处理的任务。首先,需要将WebDriver转换为JavascriptExecutor对象,然后通过executeScript方法执行JavaScript代码。示例用法包括设置JS代码字符串并调用executeScript。文章提供了两个实战场景:一是当时间插件限制输入时,用JS去除元素的readonly属性;二是处理需滚动才能显示的元素,利用JS滚动页面。还给出了一个滚动到底部的代码示例,并提供了详细步骤和解释。
28 10
|
12天前
|
Java 测试技术 Python
《手把手教你》系列技巧篇(三十六)-java+ selenium自动化测试-单选和多选按钮操作-番外篇(详解教程)
【4月更文挑战第28天】本文简要介绍了自动化测试的实战应用,通过一个在线问卷调查(&lt;https://www.sojump.com/m/2792226.aspx/&gt;)为例,展示了如何遍历并点击问卷中的选项。测试思路包括找到单选和多选按钮的共性以定位元素,然后使用for循环进行点击操作。代码设计方面,提供了Java+Selenium的示例代码,通过WebDriver实现自动答题。运行代码后,可以看到控制台输出和浏览器的相应动作。文章最后做了简单的小结,强调了本次实践是对之前单选多选操作的巩固。
24 0
|
1天前
|
前端开发 Java 测试技术
《手把手教你》系列技巧篇(四十二)-java+ selenium自动化测试 - 处理iframe -下篇(详解教程)
【5月更文挑战第6天】本文介绍了如何使用Selenium处理含有iframe的网页。作者首先解释了iframe是什么,即HTML中的一个框架,用于在一个页面中嵌入另一个页面。接着,通过一个实战例子展示了在QQ邮箱登录页面中,由于输入框存在于iframe内,导致直接定位元素失败。作者提供了三种方法来处理这种情况:1)通过id或name属性切换到iframe;2)使用webElement对象切换;3)通过索引切换。最后,给出了相应的Java代码示例,并提醒读者根据iframe的实际情况选择合适的方法进行切换和元素定位。
7 0
|
2天前
|
前端开发 测试技术 Python
《手把手教你》系列技巧篇(四十一)-java+ selenium自动化测试 - 处理iframe -上篇(详解教程)
【5月更文挑战第5天】本文介绍了HTML中的`iframe`标签,它用于在网页中嵌套其他网页。`iframe`常用于加载外部内容或网站的某个部分,以实现页面美观。文章还讲述了使用Selenium自动化测试时如何处理`iframe`,通过`switchTo().frame()`方法进入`iframe`,完成相应操作,然后使用`switchTo().defaultContent()`返回主窗口。此外,文章提供了一个包含`iframe`的HTML代码示例,并给出了一个简单的自动化测试代码实战,演示了如何在`iframe`中输入文本。
14 3
|
3天前
|
JavaScript 前端开发 Java
《手把手教你》系列技巧篇(四十)-java+ selenium自动化测试-JavaScript的调用执行-下篇(详解教程)
【5月更文挑战第4天】本文介绍了如何使用JavaScriptExecutor在自动化测试中实现元素高亮显示。通过创建并执行JS代码,可以改变元素的样式,例如设置背景色和边框,以突出显示被操作的元素。文中提供了一个Java示例,展示了如何在Selenium中使用此方法,并附有代码截图和运行效果展示。该技术有助于跟踪和理解测试过程中的元素交互。
8 0
|
5天前
|
JavaScript 前端开发 测试技术
《手把手教你》系列技巧篇(三十八)-java+ selenium自动化测试-日历时间控件-下篇(详解教程)
【5月更文挑战第2天】在自动化测试过程中,经常会遇到处理日期控件的点击问题。宏哥之前分享过一种方法,但如果输入框是`readonly`属性,这种方法就无法奏效了。不过,通过修改元素属性,依然可以实现自动化填写日期。首先,定位到日期输入框并移除`readonly`属性,然后使用`sendKeys`方法输入日期。这样,即使输入框设置了`readonly`,也能成功处理日期控件。
24 1
|
6天前
|
Java 测试技术 Python
《手把手教你》系列技巧篇(三十七)-java+ selenium自动化测试-日历时间控件-上篇(详解教程)
【5月更文挑战第1天】该文介绍了使用Selenium自动化测试网页日历控件的方法。首先,文章提到在某些Web应用中,日历控件常用于选择日期并筛选数据。接着,它提供了两个实现思路:一是将日历视为文本输入框,直接输入日期;二是模拟用户交互,逐步选择日期。文中给出了JQueryUI网站的一个示例,并展示了对应的Java代码实现,包括点击日历、选择日期等操作。
26 0
|
12天前
|
SQL 分布式计算 Hadoop
【Spark】Spark基础教程知识点
【Spark】Spark基础教程知识点
|
4月前
|
机器学习/深度学习 SQL 分布式计算
Apache Spark 的基本概念和在大数据分析中的应用
介绍 Apache Spark 的基本概念和在大数据分析中的应用
162 0

热门文章

最新文章