StructuredStreaming自定义DataSourceV2源

本文涉及的产品
云数据库 RDS MySQL,集群系列 2核4GB
推荐场景:
搭建个人博客
RDS MySQL Serverless 基础系列,0.5-2RCU 50GB
云原生数据库 PolarDB MySQL 版,通用型 2核4GB 50GB
简介: StructuredStreaming自定义DataSourceV2源

自定义输入源

某些应用场景下我们可能需要自定义数据源,如业务中,需要在获取KafkaSource的同时,动态从缓存中或者http请求中加载业务数据,或者是其它的数据源等都可以参考规范自定义。自定义输入源需要以下步骤:

第一步:继承DataSourceRegister和StreamSourceProvider创建自定义Provider类

第二步:重写DataSourceRegister类中的shotName和StreamSourceProvider中的createSource以及sourceSchema方法

第三步:继承Source创建自定义Source类

第四步:重写Source中的schema方法指定输入源的schema

第五步:重写Source中的getOffest方法监听流数据

第六步:重写Source中的getBatch方法获取数据

第七步:重写Source中的stop方法用来关闭资源

创建CustomDataSourceProvider类

1:继承DataSourceRegister和StreamSourceProvider

2:重写DataSourceRegister的shotName方法

3:重写StreamSourceProvider中的sourceSchema方法

4:重写StreamSourceProvider中的createSource方法

import org.apache.spark.internal.Logging
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.execution.streaming.{Sink, Source}
import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider, StreamSourceProvider}
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.StructType
/**
  *       自定义Structured Streaming数据源
  *
  *       (1)继承DataSourceRegister类
  *       需要重写shortName方法,用来向Spark注册该组件
  *       (2)继承StreamSourceProvider类
  *       需要重写createSource以及sourceSchema方法,用来创建数据输入源
  *       (3)继承StreamSinkProvider类
  *       需要重写createSink方法,用来创建数据输出源
  *
  *
  */
## 要创建自定义的DataSourceProvider必须要继承位于org.apache.spark.sql.sources包下的DataSourceRegister以及该包下的StreamSourceProvider
class CustomDataSourceProvider extends DataSourceRegister
  with StreamSourceProvider
  with StreamSinkProvider
  with Logging {
  /**
    * 数据源的描述名字,如:kafka、socket
    * @return 字符串shotName
    * 该方法用来指定一个数据源的名字,用来想spark注册该数据源。如Spark内置的数据源的shotName:kafka、socket、rate等,该方法返回一个字符串
    */
  override def shortName(): String = "custom"
  /**
    * 定义数据源的Schema
    *
    * @param sqlContext   Spark SQL 上下文
    * @param schema       通过.schema()方法传入的schema
    * @param providerName Provider的名称,包名+类名
    * @param parameters   通过.option()方法传入的参数
    * @return 元组,(shotName,schema)
    *
    */
    ## 该方法是用来定义数据源的schema,可以使用用户传入的schema,也可以根据传入的参数进行动态创建。返回值是个二元组(shotName,scheam)
  override def sourceSchema(sqlContext: SQLContext,
                            schema: Option[StructType],
                            providerName: String,
                            parameters: Map[String, String]): (String, StructType) = (shortName(),schema.get)
  /**
    * 创建输入源
    *
    * @param sqlContext   Spark SQL 上下文
    * @param metadataPath 元数据Path
    * @param schema       通过.schema()方法传入的schema
    * @param providerName Provider的名称,包名+类名
    * @param parameters   通过.option()方法传入的参数
    * @return 自定义source,需要继承Source接口实现
    **/
    ## 通过传入的参数,来实例化我们自定义的DataSource,是我们自定义Source的重要入口的地方
  override def createSource(sqlContext: SQLContext,
                            metadataPath: String,
                            schema: Option[StructType],
                            providerName: String,
                            parameters: Map[String, String]): Source = new CustomDataSource(sqlContext,parameters,schema)
  /**
    * 创建输出源
    *
    * @param sqlContext       Spark SQL 上下文
    * @param parameters       通过.option()方法传入的参数
    * @param partitionColumns 分区列名?
    * @param outputMode       输出模式
    * @return
    */
    ## 重写StreamSinkProvider中的createSink方法
  override def createSink(sqlContext: SQLContext,
                          parameters: Map[String, String],
                          partitionColumns: Seq[String],
                          outputMode: OutputMode): Sink = new CustomDataSink(sqlContext,parameters,outputMode)
}

创建CustomDataSource类

1:继承Source创建CustomDataSource类 2:重写Source的schema方法 3:重写Source的getOffset方法 4:重写Source的getBatch方法 5:重写Source的stop方法

package org.apache.spark.sql.structured.datasource.custom
import org.apache.spark.internal.Logging
import org.apache.spark.sql.execution.streaming.{Offset, Source}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, SQLContext}
/**
  *       自定义数据输入源:需要继承Source接口
  *       实现思路:
  *       (1)通过重写schema方法来指定数据输入源的schema,这个schema需要与Provider中指定的schema保持一致
  *       (2)通过重写getOffset方法来获取数据的偏移量,这个方法会一直被轮询调用,不断的获取偏移量
  *       (3) 通过重写getBatch方法,来获取数据,这个方法是在偏移量发生改变后被触发
  *       (4)通过stop方法,来进行一下关闭资源的操作
  *
  */
class CustomDataSource(sqlContext: SQLContext,
                       parameters: Map[String, String],
                       schemaOption: Option[StructType]) extends Source
  with Logging {
  /**
    * 指定数据源的schema,需要与Provider中sourceSchema中指定的schema保持一直,否则会报异常
    * 触发机制:当创建数据源的时候被触发执行
    *
    * @return schema
    */
  override def schema: StructType = schemaOption.get
  /**
    * 获取offset,用来监控数据的变化情况
    * 触发机制:不断轮询调用
    * 实现要点:
    * (1)Offset的实现:
    * 由函数返回值可以看出,我们需要提供一个标准的返回值Option[Offset]
    * 我们可以通过继承 org.apache.spark.sql.sources.v2.reader.streaming.Offset实现,这里面其实就是保存了个json字符串
    *
    * (2) JSON转化
    * 因为Offset里实现的是一个json字符串,所以我们需要将我们存放offset的集合或者case class转化重json字符串
    * spark里是通过org.json4s.jackson这个包来实现case class 集合类(Map、List、Seq、Set等)与json字符串的相互转化
    *
    * @return Offset
    */
  override def getOffset: Option[Offset] = ???
  /**
    * 获取数据
    *
    * @param start 上一个批次的end offset
    * @param end   通过getOffset获取的新的offset
    *              触发机制:当不断轮询的getOffset方法,获取的offset发生改变时,会触发该方法
    *
    *              实现要点:
    *              (1)DataFrame的创建:
    *              可以通过生成RDD,然后使用RDD创建DataFrame
    *              RDD创建:sqlContext.sparkContext.parallelize(rows.toSeq)
    *              DataFrame创建:sqlContext.internalCreateDataFrame(rdd, schema, isStreaming = true)
    * @return DataFrame
    */
  override def getBatch(start: Option[Offset], end: Offset): DataFrame = ???
  /**
    * 关闭资源
    * 将一些需要关闭的资源放到这里来关闭,如MySQL的数据库连接等
    */
  override def stop(): Unit = ???
}

自定义输出源

相比较输入源的自定义性,输出源自定义的应用场景貌似更为常用。比如:数据写入关系型数据库、数据写入HBase、数据写入Redis等等。其实Structured提供的foreach以及2.4版本的foreachBatch方法已经可以实现绝大数的应用场景的,几乎是数据想写到什么地方都能实现。但是想要更优雅的实现,我们可以参考Spark SQL Sink规范,通过自定义的Sink的方式来实现。实现自定义Sink需要以下四个个步骤:

第一步:继承DataSourceRegister和StreamSinkProvider创建自定义SinkProvider类

第二步:重写DataSourceRegister类中的shotName和StreamSinkProvider中的createSink方法
第三步:继承Sink创建自定义Sink类
第四步:重写Sink中的addBatch方法

创建CustomDataSink类

import org.apache.spark.internal.Logging
import org.apache.spark.sql.execution.streaming.Sink
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.{DataFrame, SQLContext}
/**
  *       自定义数据输出源
  */
class CustomDataSink(sqlContext: SQLContext,
                     parameters: Map[String, String],
                     outputMode: OutputMode) extends Sink with Logging {
  /**
    * 添加Batch,即数据写出
    *
    * @param batchId batchId
    * @param data    DataFrame
    *                触发机制:当发生计算时,会触发该方法,并且得到要输出的DataFrame
    *                实现摘要:
    *                1. 数据写入方式:
    *                (1)通过SparkSQL内置的数据源写出
    *                我们拿到DataFrame之后可以通过SparkSQL内置的数据源将数据写出,如:
    *                JSON数据源、CSV数据源、Text数据源、Parquet数据源、JDBC数据源等。
    *                (2)通过自定义SparkSQL的数据源进行写出
    *                (3)通过foreachPartition 将数据写出
    */
  override def addBatch(batchId: Long, data: DataFrame): Unit = ???
}

实现mysql/pg自定义输入源输出源完整demo

1: 创建MySQLSourceProvider类

package org.apache.spark.sql.structured.datasource
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.execution.streaming.{Sink, Source}
import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider, StreamSourceProvider}
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.StructType
class MySQLSourceProvider extends DataSourceRegister
  with StreamSourceProvider
  with StreamSinkProvider
  with Logging {
  /**
    * 数据源的描述名字,如:kafka、socket
    *
    * @return 字符串shotName
    */
  override def shortName(): String = "mysql"
  /**
    * 定义数据源的Schema
    *
    * @param sqlContext   Spark SQL 上下文
    * @param schema       通过.schema()方法传入的schema
    * @param providerName Provider的名称,包名+类名
    * @param parameters   通过.option()方法传入的参数
    * @return 元组,(shotName,schema)
    */
  override def sourceSchema(
                             sqlContext: SQLContext,
                             schema: Option[StructType],
                             providerName: String,
                             parameters: Map[String, String]): (String, StructType) = {
    (providerName, schema.get)
  }
  /**
    * 创建输入源
    *
    * @param sqlContext   Spark SQL 上下文
    * @param metadataPath 元数据Path
    * @param schema       通过.schema()方法传入的schema
    * @param providerName Provider的名称,包名+类名
    * @param parameters   通过.option()方法传入的参数
    * @return 自定义source,需要继承Source接口实现
    */
  override def createSource(
                             sqlContext: SQLContext,
                             metadataPath: String, schema: Option[StructType],
                             providerName: String, parameters: Map[String, String]): Source = new MySQLSource(sqlContext, parameters, schema)
  /**
    * 创建输出源
    *
    * @param sqlContext       Spark SQL 上下文
    * @param parameters       通过.option()方法传入的参数
    * @param partitionColumns 分区列名?
    * @param outputMode       输出模式
    * @return
    */
  override def createSink(
                           sqlContext: SQLContext,
                           parameters: Map[String, String],
                           partitionColumns: Seq[String], outputMode: OutputMode): Sink = new MySQLSink(sqlContext: SQLContext,parameters, outputMode)
}

2:创建MySQLSource类

package org.apache.spark.sql.structured.datasource
import java.sql.{Connection, ResultSet}
import org.apache.spark.executor.InputMetrics
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.streaming.{Offset, Source}
import org.apache.spark.sql.types.{DoubleType, IntegerType, LongType, StringType, StructField, StructType, TimestampType}
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.unsafe.types.UTF8String
import org.json4s.jackson.Serialization
import org.json4s.{Formats, NoTypeHints}
class MySQLSource(sqlContext: SQLContext,
                  options: Map[String, String],
                  schemaOption: Option[StructType]) extends Source with Logging {
  lazy val conn: Connection = C3p0Utils.getDataSource(options).getConnection
  val tableName: String = options("tableName")
  var currentOffset: Map[String, Long] = Map[String, Long](tableName -> 0)
  val maxOffsetPerBatch: Option[Long] = Option(100)
  val inputMetrics = new InputMetrics()
  override def schema: StructType = schemaOption.get
  /**
   * 获取Offset
   * 这里监控MySQL数据库表中条数变化情况
   * @return Option[Offset]
   */
  override def getOffset: Option[Offset] = {
    val latest = getLatestOffset //获取表中数据条数
    val offsets = maxOffsetPerBatch match {
      case None => MySQLSourceOffset(latest)
      case Some(limit) =>
        MySQLSourceOffset(rateLimit(limit, currentOffset, latest))
    }
    Option(offsets)
  }
  /**
   * 获取数据
   * @param start 上一次的offset
   * @param end 最新的offset
   * @return df
   */
  override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
    var offset: Long = 0
    if (start.isDefined) {
      offset = offset2Map(start.get)(tableName)
    }
    val limit = offset2Map(end)(tableName) - offset
    val sql = s"SELECT * FROM $tableName  limit $limit offset $offset" //
    val st = conn.prepareStatement(sql)
    val rs = st.executeQuery()
    val rows = getInternalRow1(rs)
    val rdd = sqlContext.sparkContext.parallelize(rows.toSeq)
    currentOffset = offset2Map(end)
    sqlContext.internalCreateDataFrame(rdd, schema, isStreaming = true)
  }
  def getInternalRow1(rs: ResultSet): Iterator[InternalRow] = {
    //获取StructType
    val schema: StructType = schemaOption.get
    val fields: Array[StructField] = schema.fields
    var rows: List[InternalRow] = List[InternalRow]()
    while (rs.next()) {
      var s = Seq[Any]()
      //遍历字段 根据字段类型 去获取对应的数据
      //注意类型  基本数值类型要转成包装器类型     string转成utf8
      for (field <- fields) {
        val value = field.dataType match {
          case IntegerType => rs.getInt(field.name)
          case StringType => UTF8String.fromString(rs.getString(field.name))
          case LongType => rs.getLong(field.name)
          case TimestampType => {
            val t: java.lang.Long = rs.getTimestamp(field.name) match {
              case null => null
              //时间戳在IntervalRow里面必须是微秒单位的
              case value => value.getTime * 1000
            }
            t
          }
          case DoubleType => rs.getDouble(field.name)
        }
        s = s.:+(value)
      }
      //      println(s)
      //将Seq转成InternalRow
      val internalRow = InternalRow.fromSeq(s)
      rows = rows.:+(internalRow)
    }
    rows.toIterator
  }
  override def stop(): Unit = {
    conn.close()
  }
  def rateLimit(limit: Long, currentOffset: Map[String, Long], latestOffset: Map[String, Long]): Map[String, Long] = {
    val co = currentOffset(tableName)
    val lo = latestOffset(tableName)
    if (co + limit > lo) {
      Map[String, Long](tableName -> lo)
    } else {
      Map[String, Long](tableName -> (co + limit))
    }
  }
  // 获取最新条数
  def getLatestOffset: Map[String, Long] = {
    var offset: Long = 0
    val sql = s"SELECT COUNT(1) FROM $tableName"
    val st = conn.prepareStatement(sql)
    val rs = st.executeQuery()
    while (rs.next()) {
      offset = rs.getLong(1)
    }
    Map[String, Long](tableName -> offset)
  }
  def offset2Map(offset: Offset): Map[String, Long] = {
    implicit val formats: AnyRef with Formats = Serialization.formats(NoTypeHints)
    Serialization.read[Map[String, Long]](offset.json())
  }
}
case class MySQLSourceOffset(offset: Map[String, Long]) extends Offset {
  implicit val formats: AnyRef with Formats = Serialization.formats(NoTypeHints)
  override def json(): String = Serialization.write(offset)
}

3:创建MySQLSink类

package org.apache.spark.sql.structured.datasource
import org.apache.spark.internal.Logging
import org.apache.spark.sql.execution.streaming.Sink
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}
class MySQLSink(sqlContext: SQLContext,parameters: Map[String, String], outputMode: OutputMode) extends Sink with Logging {
  override def addBatch(batchId: Long, data: DataFrame): Unit = {
    val query = data.queryExecution
    val rdd = query.toRdd
    val df = sqlContext.internalCreateDataFrame(rdd, data.schema)
    df.show(false)
    df.write.format("jdbc").options(parameters).mode(SaveMode.Append).save()
  }
}

4: 运行主类

package org.apache.spark.sql.structured.datasource.onlineIndex
import java.sql.{Connection, DriverManager, PreparedStatement}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, ForeachWriter, SparkSession}
object TotalNumOfIncoming {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .appName(this.getClass.getSimpleName)
      .master("local[*]")
      .getOrCreate()
    val schema = StructType(List(
      StructField("tiff_id", IntegerType),
      StructField("order_id", IntegerType),
      StructField("user_name", StringType),
      StructField("tiff", StringType),
      StructField("tar_gz", StringType),
      StructField("directory", StringType),
      StructField("tiff_size", LongType),
      StructField("xml", StringType),
      StructField("rpb", StringType),
      StructField("product_level", StringType),
      StructField("satellite", StringType),
      StructField("sensor", StringType),
      StructField("receive_time", TimestampType),
      StructField("upload_time", TimestampType),
      StructField("bands", StringType), //将ArrayType(StringType)  => StringType  后面读取数据时如果是Array会报pg序列化异常
      StructField("top_left_longitude", DoubleType),
      StructField("top_left_latitude", DoubleType),
      StructField("bottom_right_longitude", DoubleType),
      StructField("bottom_right_latitude", DoubleType),
      StructField("geom", StringType),
      StructField("pixel_spacing", IntegerType),
      StructField("adcode", StringType),
      StructField("cloud_cover", DoubleType),
      StructField("province", StringType),
      StructField("city", StringType),
      StructField("district", StringType),
      StructField("status", IntegerType),
      StructField("downNum", IntegerType)
    )
    )
    val options = Map[String, String](
      "driverClass" -> "org.postgresql.Driver",
      "jdbcUrl" -> "jdbc:postgresql://192.168.0.214:25433/test1",
      "user" -> "postgres",
      "password" -> "bjsh",
      "tableName" -> "tiff_entry_test")
    val source: DataFrame = spark
      .readStream
      .format("org.apache.spark.sql.structured.datasource.MySQLSourceProvider")
      .options(options)
      .schema(schema)
      .load()
    import org.apache.spark.sql.functions._
    source.createOrReplaceTempView("tiff_entry_test")
    //总入库
    val res1: DataFrame = spark.sql("select count(1) ct from tiff_entry_test")
    res1.toJSON.withColumn("index",lit("总入库数"))
      .select("index","ct")
      .writeStream
      .outputMode("update")
      .format("org.apache.spark.sql.structured.datasource.MySQLSourceProvider")
      .option("truncate", "true")
      .option("checkpointLocation", "./tmp/MySQLSourceProvider11")
      .option("user", "postgres")
      .option("password", "bjsh")
      .option("dbtable", "result")
      .option("url", "jdbc:postgresql://192.168.0.214:25433/test1")
      .start()
     .awaitTermination()
  }
}
相关文章
|
8月前
|
边缘计算 安全 网络安全
隐藏服务器源IP怎么操作?
一篇文章看懂隐藏源IP!
162 0
|
Java Python Windows
Python pip 源设置成国内源,阿里云源,清华大学源,最方便的方式,都在这里了
Python pip 源设置成国内源,阿里云源,清华大学源,最方便的方式,都在这里了
35520 0
|
网络协议 应用服务中间件 网络安全
限定源端口访问目标
限定源端口访问目标 1.1. 起因 在渗透测试时,客户需要对我们的测试IP进行加白,但是此次客户要求精确到固定端口或者小范围端口(不能1-65535),根据以前的经验,默认是加白IP和全端口,因为代理建立连接使用的端口是随机的,所以这次算是从头查找资料总结一下各种指定源端口的方式。 这里的端口是指与目标建立连接时使用的源端口,而不是代理监听的端口。 1.2. 注意 最
6591 1
|
缓存 Linux
配置网络源仓库
配置网络源仓库
365 0
|
Linux Windows
修改默认的pip安装源
修改默认的pip安装源
569 0
日志服务数据加工:源与目标访问秘钥配置
日志服务数据加工上线,介绍了配置数据数据加工时从源logstore到目标logstore进行分发的权限配置细节与样例
5216 0
|
Windows
批量创建IP方法
以下教程,将告诉大家如何在Windows系统中通过命令行,批量添加IP。目标,在本机的的网卡名称为“本地连接”的网卡中,批量添加192.168.1段的ip地址,起开始IP为10,每次增加1,知道22为止,即括号的三个参数。
1012 0
WordPress发布文章/页面时自动添加默认的自定义字段
如果你每篇文章或页面都需要插入同一个自定义字段和值,可以考虑在WordPress发布文章/页面时,自动添加默认的自定义字段
1487 0