自定义输入源
某些应用场景下我们可能需要自定义数据源,如业务中,需要在获取KafkaSource的同时,动态从缓存中或者http请求中加载业务数据,或者是其它的数据源等都可以参考规范自定义。自定义输入源需要以下步骤:
第一步:继承DataSourceRegister和StreamSourceProvider创建自定义Provider类
第二步:重写DataSourceRegister类中的shotName和StreamSourceProvider中的createSource以及sourceSchema方法
第三步:继承Source创建自定义Source类
第四步:重写Source中的schema方法指定输入源的schema
第五步:重写Source中的getOffest方法监听流数据
第六步:重写Source中的getBatch方法获取数据
第七步:重写Source中的stop方法用来关闭资源
创建CustomDataSourceProvider类
1:继承DataSourceRegister和StreamSourceProvider
2:重写DataSourceRegister的shotName方法
3:重写StreamSourceProvider中的sourceSchema方法
4:重写StreamSourceProvider中的createSource方法
import org.apache.spark.internal.Logging import org.apache.spark.sql.SQLContext import org.apache.spark.sql.execution.streaming.{Sink, Source} import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider, StreamSourceProvider} import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType /** * 自定义Structured Streaming数据源 * * (1)继承DataSourceRegister类 * 需要重写shortName方法,用来向Spark注册该组件 * (2)继承StreamSourceProvider类 * 需要重写createSource以及sourceSchema方法,用来创建数据输入源 * (3)继承StreamSinkProvider类 * 需要重写createSink方法,用来创建数据输出源 * * */ ## 要创建自定义的DataSourceProvider必须要继承位于org.apache.spark.sql.sources包下的DataSourceRegister以及该包下的StreamSourceProvider class CustomDataSourceProvider extends DataSourceRegister with StreamSourceProvider with StreamSinkProvider with Logging { /** * 数据源的描述名字,如:kafka、socket * @return 字符串shotName * 该方法用来指定一个数据源的名字,用来想spark注册该数据源。如Spark内置的数据源的shotName:kafka、socket、rate等,该方法返回一个字符串 */ override def shortName(): String = "custom" /** * 定义数据源的Schema * * @param sqlContext Spark SQL 上下文 * @param schema 通过.schema()方法传入的schema * @param providerName Provider的名称,包名+类名 * @param parameters 通过.option()方法传入的参数 * @return 元组,(shotName,schema) * */ ## 该方法是用来定义数据源的schema,可以使用用户传入的schema,也可以根据传入的参数进行动态创建。返回值是个二元组(shotName,scheam) override def sourceSchema(sqlContext: SQLContext, schema: Option[StructType], providerName: String, parameters: Map[String, String]): (String, StructType) = (shortName(),schema.get) /** * 创建输入源 * * @param sqlContext Spark SQL 上下文 * @param metadataPath 元数据Path * @param schema 通过.schema()方法传入的schema * @param providerName Provider的名称,包名+类名 * @param parameters 通过.option()方法传入的参数 * @return 自定义source,需要继承Source接口实现 **/ ## 通过传入的参数,来实例化我们自定义的DataSource,是我们自定义Source的重要入口的地方 override def createSource(sqlContext: SQLContext, metadataPath: String, schema: Option[StructType], providerName: String, parameters: Map[String, String]): Source = new CustomDataSource(sqlContext,parameters,schema) /** * 创建输出源 * * @param sqlContext Spark SQL 上下文 * @param parameters 通过.option()方法传入的参数 * @param partitionColumns 分区列名? * @param outputMode 输出模式 * @return */ ## 重写StreamSinkProvider中的createSink方法 override def createSink(sqlContext: SQLContext, parameters: Map[String, String], partitionColumns: Seq[String], outputMode: OutputMode): Sink = new CustomDataSink(sqlContext,parameters,outputMode) }
创建CustomDataSource类
1:继承Source创建CustomDataSource类 2:重写Source的schema方法 3:重写Source的getOffset方法 4:重写Source的getBatch方法 5:重写Source的stop方法
package org.apache.spark.sql.structured.datasource.custom import org.apache.spark.internal.Logging import org.apache.spark.sql.execution.streaming.{Offset, Source} import org.apache.spark.sql.types.StructType import org.apache.spark.sql.{DataFrame, SQLContext} /** * 自定义数据输入源:需要继承Source接口 * 实现思路: * (1)通过重写schema方法来指定数据输入源的schema,这个schema需要与Provider中指定的schema保持一致 * (2)通过重写getOffset方法来获取数据的偏移量,这个方法会一直被轮询调用,不断的获取偏移量 * (3) 通过重写getBatch方法,来获取数据,这个方法是在偏移量发生改变后被触发 * (4)通过stop方法,来进行一下关闭资源的操作 * */ class CustomDataSource(sqlContext: SQLContext, parameters: Map[String, String], schemaOption: Option[StructType]) extends Source with Logging { /** * 指定数据源的schema,需要与Provider中sourceSchema中指定的schema保持一直,否则会报异常 * 触发机制:当创建数据源的时候被触发执行 * * @return schema */ override def schema: StructType = schemaOption.get /** * 获取offset,用来监控数据的变化情况 * 触发机制:不断轮询调用 * 实现要点: * (1)Offset的实现: * 由函数返回值可以看出,我们需要提供一个标准的返回值Option[Offset] * 我们可以通过继承 org.apache.spark.sql.sources.v2.reader.streaming.Offset实现,这里面其实就是保存了个json字符串 * * (2) JSON转化 * 因为Offset里实现的是一个json字符串,所以我们需要将我们存放offset的集合或者case class转化重json字符串 * spark里是通过org.json4s.jackson这个包来实现case class 集合类(Map、List、Seq、Set等)与json字符串的相互转化 * * @return Offset */ override def getOffset: Option[Offset] = ??? /** * 获取数据 * * @param start 上一个批次的end offset * @param end 通过getOffset获取的新的offset * 触发机制:当不断轮询的getOffset方法,获取的offset发生改变时,会触发该方法 * * 实现要点: * (1)DataFrame的创建: * 可以通过生成RDD,然后使用RDD创建DataFrame * RDD创建:sqlContext.sparkContext.parallelize(rows.toSeq) * DataFrame创建:sqlContext.internalCreateDataFrame(rdd, schema, isStreaming = true) * @return DataFrame */ override def getBatch(start: Option[Offset], end: Offset): DataFrame = ??? /** * 关闭资源 * 将一些需要关闭的资源放到这里来关闭,如MySQL的数据库连接等 */ override def stop(): Unit = ??? }
自定义输出源
相比较输入源的自定义性,输出源自定义的应用场景貌似更为常用。比如:数据写入关系型数据库、数据写入HBase、数据写入Redis等等。其实Structured提供的foreach以及2.4版本的foreachBatch方法已经可以实现绝大数的应用场景的,几乎是数据想写到什么地方都能实现。但是想要更优雅的实现,我们可以参考Spark SQL Sink规范,通过自定义的Sink的方式来实现。实现自定义Sink需要以下四个个步骤:
第一步:继承DataSourceRegister和StreamSinkProvider创建自定义SinkProvider类
第二步:重写DataSourceRegister类中的shotName和StreamSinkProvider中的createSink方法
第三步:继承Sink创建自定义Sink类
第四步:重写Sink中的addBatch方法
创建CustomDataSink类
import org.apache.spark.internal.Logging import org.apache.spark.sql.execution.streaming.Sink import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.{DataFrame, SQLContext} /** * 自定义数据输出源 */ class CustomDataSink(sqlContext: SQLContext, parameters: Map[String, String], outputMode: OutputMode) extends Sink with Logging { /** * 添加Batch,即数据写出 * * @param batchId batchId * @param data DataFrame * 触发机制:当发生计算时,会触发该方法,并且得到要输出的DataFrame * 实现摘要: * 1. 数据写入方式: * (1)通过SparkSQL内置的数据源写出 * 我们拿到DataFrame之后可以通过SparkSQL内置的数据源将数据写出,如: * JSON数据源、CSV数据源、Text数据源、Parquet数据源、JDBC数据源等。 * (2)通过自定义SparkSQL的数据源进行写出 * (3)通过foreachPartition 将数据写出 */ override def addBatch(batchId: Long, data: DataFrame): Unit = ??? }
实现mysql/pg自定义输入源输出源完整demo
1: 创建MySQLSourceProvider类
package org.apache.spark.sql.structured.datasource import org.apache.spark.internal.Logging import org.apache.spark.sql.SQLContext import org.apache.spark.sql.execution.streaming.{Sink, Source} import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider, StreamSourceProvider} import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType class MySQLSourceProvider extends DataSourceRegister with StreamSourceProvider with StreamSinkProvider with Logging { /** * 数据源的描述名字,如:kafka、socket * * @return 字符串shotName */ override def shortName(): String = "mysql" /** * 定义数据源的Schema * * @param sqlContext Spark SQL 上下文 * @param schema 通过.schema()方法传入的schema * @param providerName Provider的名称,包名+类名 * @param parameters 通过.option()方法传入的参数 * @return 元组,(shotName,schema) */ override def sourceSchema( sqlContext: SQLContext, schema: Option[StructType], providerName: String, parameters: Map[String, String]): (String, StructType) = { (providerName, schema.get) } /** * 创建输入源 * * @param sqlContext Spark SQL 上下文 * @param metadataPath 元数据Path * @param schema 通过.schema()方法传入的schema * @param providerName Provider的名称,包名+类名 * @param parameters 通过.option()方法传入的参数 * @return 自定义source,需要继承Source接口实现 */ override def createSource( sqlContext: SQLContext, metadataPath: String, schema: Option[StructType], providerName: String, parameters: Map[String, String]): Source = new MySQLSource(sqlContext, parameters, schema) /** * 创建输出源 * * @param sqlContext Spark SQL 上下文 * @param parameters 通过.option()方法传入的参数 * @param partitionColumns 分区列名? * @param outputMode 输出模式 * @return */ override def createSink( sqlContext: SQLContext, parameters: Map[String, String], partitionColumns: Seq[String], outputMode: OutputMode): Sink = new MySQLSink(sqlContext: SQLContext,parameters, outputMode) }
2:创建MySQLSource类
package org.apache.spark.sql.structured.datasource import java.sql.{Connection, ResultSet} import org.apache.spark.executor.InputMetrics import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.streaming.{Offset, Source} import org.apache.spark.sql.types.{DoubleType, IntegerType, LongType, StringType, StructField, StructType, TimestampType} import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.unsafe.types.UTF8String import org.json4s.jackson.Serialization import org.json4s.{Formats, NoTypeHints} class MySQLSource(sqlContext: SQLContext, options: Map[String, String], schemaOption: Option[StructType]) extends Source with Logging { lazy val conn: Connection = C3p0Utils.getDataSource(options).getConnection val tableName: String = options("tableName") var currentOffset: Map[String, Long] = Map[String, Long](tableName -> 0) val maxOffsetPerBatch: Option[Long] = Option(100) val inputMetrics = new InputMetrics() override def schema: StructType = schemaOption.get /** * 获取Offset * 这里监控MySQL数据库表中条数变化情况 * @return Option[Offset] */ override def getOffset: Option[Offset] = { val latest = getLatestOffset //获取表中数据条数 val offsets = maxOffsetPerBatch match { case None => MySQLSourceOffset(latest) case Some(limit) => MySQLSourceOffset(rateLimit(limit, currentOffset, latest)) } Option(offsets) } /** * 获取数据 * @param start 上一次的offset * @param end 最新的offset * @return df */ override def getBatch(start: Option[Offset], end: Offset): DataFrame = { var offset: Long = 0 if (start.isDefined) { offset = offset2Map(start.get)(tableName) } val limit = offset2Map(end)(tableName) - offset val sql = s"SELECT * FROM $tableName limit $limit offset $offset" // val st = conn.prepareStatement(sql) val rs = st.executeQuery() val rows = getInternalRow1(rs) val rdd = sqlContext.sparkContext.parallelize(rows.toSeq) currentOffset = offset2Map(end) sqlContext.internalCreateDataFrame(rdd, schema, isStreaming = true) } def getInternalRow1(rs: ResultSet): Iterator[InternalRow] = { //获取StructType val schema: StructType = schemaOption.get val fields: Array[StructField] = schema.fields var rows: List[InternalRow] = List[InternalRow]() while (rs.next()) { var s = Seq[Any]() //遍历字段 根据字段类型 去获取对应的数据 //注意类型 基本数值类型要转成包装器类型 string转成utf8 for (field <- fields) { val value = field.dataType match { case IntegerType => rs.getInt(field.name) case StringType => UTF8String.fromString(rs.getString(field.name)) case LongType => rs.getLong(field.name) case TimestampType => { val t: java.lang.Long = rs.getTimestamp(field.name) match { case null => null //时间戳在IntervalRow里面必须是微秒单位的 case value => value.getTime * 1000 } t } case DoubleType => rs.getDouble(field.name) } s = s.:+(value) } // println(s) //将Seq转成InternalRow val internalRow = InternalRow.fromSeq(s) rows = rows.:+(internalRow) } rows.toIterator } override def stop(): Unit = { conn.close() } def rateLimit(limit: Long, currentOffset: Map[String, Long], latestOffset: Map[String, Long]): Map[String, Long] = { val co = currentOffset(tableName) val lo = latestOffset(tableName) if (co + limit > lo) { Map[String, Long](tableName -> lo) } else { Map[String, Long](tableName -> (co + limit)) } } // 获取最新条数 def getLatestOffset: Map[String, Long] = { var offset: Long = 0 val sql = s"SELECT COUNT(1) FROM $tableName" val st = conn.prepareStatement(sql) val rs = st.executeQuery() while (rs.next()) { offset = rs.getLong(1) } Map[String, Long](tableName -> offset) } def offset2Map(offset: Offset): Map[String, Long] = { implicit val formats: AnyRef with Formats = Serialization.formats(NoTypeHints) Serialization.read[Map[String, Long]](offset.json()) } } case class MySQLSourceOffset(offset: Map[String, Long]) extends Offset { implicit val formats: AnyRef with Formats = Serialization.formats(NoTypeHints) override def json(): String = Serialization.write(offset) }
3:创建MySQLSink类
package org.apache.spark.sql.structured.datasource import org.apache.spark.internal.Logging import org.apache.spark.sql.execution.streaming.Sink import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode} class MySQLSink(sqlContext: SQLContext,parameters: Map[String, String], outputMode: OutputMode) extends Sink with Logging { override def addBatch(batchId: Long, data: DataFrame): Unit = { val query = data.queryExecution val rdd = query.toRdd val df = sqlContext.internalCreateDataFrame(rdd, data.schema) df.show(false) df.write.format("jdbc").options(parameters).mode(SaveMode.Append).save() } }
4: 运行主类
package org.apache.spark.sql.structured.datasource.onlineIndex import java.sql.{Connection, DriverManager, PreparedStatement} import org.apache.spark.sql.types._ import org.apache.spark.sql.{DataFrame, ForeachWriter, SparkSession} object TotalNumOfIncoming { def main(args: Array[String]): Unit = { val spark = SparkSession .builder() .appName(this.getClass.getSimpleName) .master("local[*]") .getOrCreate() val schema = StructType(List( StructField("tiff_id", IntegerType), StructField("order_id", IntegerType), StructField("user_name", StringType), StructField("tiff", StringType), StructField("tar_gz", StringType), StructField("directory", StringType), StructField("tiff_size", LongType), StructField("xml", StringType), StructField("rpb", StringType), StructField("product_level", StringType), StructField("satellite", StringType), StructField("sensor", StringType), StructField("receive_time", TimestampType), StructField("upload_time", TimestampType), StructField("bands", StringType), //将ArrayType(StringType) => StringType 后面读取数据时如果是Array会报pg序列化异常 StructField("top_left_longitude", DoubleType), StructField("top_left_latitude", DoubleType), StructField("bottom_right_longitude", DoubleType), StructField("bottom_right_latitude", DoubleType), StructField("geom", StringType), StructField("pixel_spacing", IntegerType), StructField("adcode", StringType), StructField("cloud_cover", DoubleType), StructField("province", StringType), StructField("city", StringType), StructField("district", StringType), StructField("status", IntegerType), StructField("downNum", IntegerType) ) ) val options = Map[String, String]( "driverClass" -> "org.postgresql.Driver", "jdbcUrl" -> "jdbc:postgresql://192.168.0.214:25433/test1", "user" -> "postgres", "password" -> "bjsh", "tableName" -> "tiff_entry_test") val source: DataFrame = spark .readStream .format("org.apache.spark.sql.structured.datasource.MySQLSourceProvider") .options(options) .schema(schema) .load() import org.apache.spark.sql.functions._ source.createOrReplaceTempView("tiff_entry_test") //总入库 val res1: DataFrame = spark.sql("select count(1) ct from tiff_entry_test") res1.toJSON.withColumn("index",lit("总入库数")) .select("index","ct") .writeStream .outputMode("update") .format("org.apache.spark.sql.structured.datasource.MySQLSourceProvider") .option("truncate", "true") .option("checkpointLocation", "./tmp/MySQLSourceProvider11") .option("user", "postgres") .option("password", "bjsh") .option("dbtable", "result") .option("url", "jdbc:postgresql://192.168.0.214:25433/test1") .start() .awaitTermination() } }