Spark SQL custom DataSourceV2 sources

Summary: implementing custom DataSourceV2 sources in Spark SQL

Implementing an input source with DataSourceV2

Implementing an input source on top of Spark SQL's DataSourceV2 is much like writing a custom Structured Streaming data source: the idea is the same, but the concrete implementation differs. The main steps are:

Step 1: Create an XXXDataSource class that extends DataSourceV2 and ReadSupport, and override ReadSupport's createReader method to return your custom DataSourceReader, e.g. an XXXDataSourceReader instance.
Step 2: Create an XXXDataSourceReader class that extends DataSourceReader, override readSchema to return the data source's schema, and override createDataReaderFactories to return one or more custom DataReaderFactory instances.
Step 3: Create a DataReader factory class that extends DataReaderFactory, e.g. XXXDataReaderFactory, and override createDataReader to return a custom DataReader instance.
Step 4: Create a custom DataReader that extends DataReader, e.g. XXXDataReader. Override next() to tell Spark whether another record exists (which drives calls to get()), override get() to return the record, and override close() to release resources. A minimal end-to-end sketch of these four pieces is shown below.
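
A minimal, self-contained sketch of the four read-path pieces against the Spark 2.3-era DataSourceV2 API; the Demo* class names and the hard-coded schema/data are illustrative placeholders, not part of the examples that follow:

import java.util
import org.apache.spark.sql.Row
import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, DataSourceReader}
import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport}
import org.apache.spark.sql.types.StructType

// Step 1: entry point picked up by spark.read.format(...)
class DemoDataSource extends DataSourceV2 with ReadSupport {
  override def createReader(options: DataSourceOptions): DataSourceReader = new DemoDataSourceReader
}
// Step 2: declares the schema and the per-partition reader factories
class DemoDataSourceReader extends DataSourceReader {
  override def readSchema(): StructType = StructType.fromDDL("`id` INT,`name` STRING")
  override def createDataReaderFactories(): util.List[DataReaderFactory[Row]] =
    util.Arrays.asList[DataReaderFactory[Row]](new DemoDataReaderFactory)
}
// Step 3: one factory per partition; createDataReader runs on the executor
class DemoDataReaderFactory extends DataReaderFactory[Row] {
  override def createDataReader(): DataReader[Row] = new DemoDataReader
}
// Step 4: next() signals whether another record exists, get() returns it, close() releases resources
class DemoDataReader extends DataReader[Row] {
  private val data = Iterator(Row(1, "a"), Row(2, "b"))
  override def next(): Boolean = data.hasNext
  override def get(): Row = data.next()
  override def close(): Unit = ()
}

Such a source would be loaded with spark.read.format(classOf[DemoDataSource].getName).load().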

Implementing an output source with DataSourceV2

Implementing a custom output source with DataSourceV2 takes the following steps:

Step 1: Create an XXXDataSource that extends DataSourceV2 and WriteSupport, and override createWriter to return your custom DataSourceWriter.
Step 2: Create an XXXDataSourceWriter class that extends DataSourceWriter, override createWriterFactory to return your custom DataWriterFactory, override commit to commit the whole transaction, and override abort to roll it back.
Step 3: Create an XXXDataWriterFactory class that extends DataWriterFactory and override createDataWriter to return your custom DataWriter.
Step 4: Create an XXXDataWriter class that extends DataWriter, override write to write each record out, override commit to commit the partition's transaction, and override abort to roll it back. A matching write-path sketch follows.
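
And a matching minimal sketch of the four write-path pieces, again against the Spark 2.3-era API; the Demo* names are placeholders:

import java.util.Optional
import org.apache.spark.sql.{Row, SaveMode}
import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriter, DataWriterFactory, WriterCommitMessage}
import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, WriteSupport}
import org.apache.spark.sql.types.StructType

// Step 1: entry point picked up by df.write.format(...)
class DemoSink extends DataSourceV2 with WriteSupport {
  override def createWriter(jobId: String, schema: StructType, mode: SaveMode,
                            options: DataSourceOptions): Optional[DataSourceWriter] =
    Optional.of(new DemoDataSourceWriter)
}
// Step 2: driver-side coordinator; commit/abort receive the messages from every partition
class DemoDataSourceWriter extends DataSourceWriter {
  override def createWriterFactory(): DataWriterFactory[Row] = new DemoDataWriterFactory
  override def commit(messages: Array[WriterCommitMessage]): Unit = println(s"committed ${messages.length} partitions")
  override def abort(messages: Array[WriterCommitMessage]): Unit = println("job aborted, rolling back")
}
// Step 3: shipped to executors; one DataWriter per partition attempt
class DemoDataWriterFactory extends DataWriterFactory[Row] {
  override def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[Row] = new DemoDataWriter
}
// Step 4: per-partition writer: write each record, then commit or abort the partition
class DemoDataWriter extends DataWriter[Row] {
  override def write(record: Row): Unit = println(record)
  override def commit(): WriterCommitMessage = new WriterCommitMessage {}
  override def abort(): Unit = ()
}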

Spark SQL custom RestDataSource code

package com.hollysys.spark.sql.datasource.rest
import java.math.BigDecimal
import java.util
import java.util.Optional
import com.alibaba.fastjson.{JSONArray, JSONObject, JSONPath}
import org.apache.http.client.fluent.Request
import org.apache.http.entity.ContentType
import org.apache.spark.sql.{Row, SaveMode, SparkSession}
import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, DataSourceReader, SupportsPushDownRequiredColumns}
import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriter, DataWriterFactory, WriterCommitMessage}
import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport, WriteSupport}
import org.apache.spark.sql.types.StructType
/**
  * A REST-based Spark SQL DataSource
  */
class RestDataSource extends DataSourceV2 with ReadSupport with WriteSupport {
  override def createReader(options: DataSourceOptions): DataSourceReader =
    new RestDataSourceReader(
      options.get("url").get(),
      options.get("params").get(),
      options.get("xPath").get(),
      options.get("schema").get()
    )
  override def createWriter(jobId: String,
                            schema: StructType,
                            mode: SaveMode,
                            options: DataSourceOptions): Optional[DataSourceWriter] = Optional.of(new RestDataSourceWriter)
}
/**
  * Creates the RestDataSourceReader
  *
  * @param url          API endpoint of the REST service
  * @param params       parameters required by the request
  * @param xPath        JSONPath into the returned JSON data
  * @param schemaString schema string supplied by the user
  */
class RestDataSourceReader(url: String, params: String, xPath: String, schemaString: String)
  extends DataSourceReader {
  // Convert the schema string into a StructType via StructType.fromDDL
  var requiredSchema: StructType = StructType.fromDDL(schemaString)
  /**
    * Produce the schema
    *
    * @return schema
    */
  override def readSchema(): StructType = requiredSchema
  /**
    * Create the reader factories
    *
    * @return list of DataReaderFactory instances
    */
  override def createDataReaderFactories(): util.List[DataReaderFactory[Row]] = {
    import collection.JavaConverters._
    Seq(
      new RestDataReaderFactory(url, params, xPath).asInstanceOf[DataReaderFactory[Row]]
    ).asJava
  }
}
/**
  * RestDataReaderFactory
  *
  * @param url    API endpoint of the REST service
  * @param params parameters required by the request
  * @param xPath  JSONPath into the returned JSON data
  */
class RestDataReaderFactory(url: String, params: String, xPath: String) extends DataReaderFactory[Row] {
  override def createDataReader(): DataReader[Row] = new RestDataReader(url, params, xPath)
}
/**
  * RestDataReader
  *
  * @param url    API endpoint of the REST service
  * @param params parameters required by the request
  * @param xPath  JSONPath into the returned JSON data
  */
class RestDataReader(url: String, params: String, xPath: String) extends DataReader[Row] {
  // Back the reader with an Iterator over the fetched data
  val data: Iterator[Seq[AnyRef]] = getIterator
  override def next(): Boolean = {
    data.hasNext
  }
  override def get(): Row = {
    val seq = data.next().map {
      // Floating-point values are parsed as BigDecimal, which Spark cannot convert directly
      case decimal: BigDecimal =>
        decimal.doubleValue()
      case x => x
    }
    Row(seq: _*)
  }
  override def close(): Unit = {
    println("close source")
  }
  def getIterator: Iterator[Seq[AnyRef]] = {
    import scala.collection.JavaConverters._
    val res: List[AnyRef] = RestDataSource.requestData(url, params, xPath)
    res.map(r => {
      r.asInstanceOf[JSONObject].asScala.values.toList
    }).toIterator
  }
}
/**
  * RestDataSourceWriter
  */
class RestDataSourceWriter extends DataSourceWriter {
  /**
    * Create the RestDataWriter factory
    *
    * @return RestDataWriterFactory
    */
  override def createWriterFactory(): DataWriterFactory[Row] = new RestDataWriterFactory
  /**
    * commit: invoked once after every partition has reported its commit message
    *
    * @param messages commit messages from all partitions
    */
  override def commit(messages: Array[WriterCommitMessage]): Unit = ???
  /**
    * abort: invoked when a write fails
    *
    * @param messages commit messages collected before the failure
    */
  override def abort(messages: Array[WriterCommitMessage]): Unit = ???
}
/**
  * RestDataWriter factory class
  */
class RestDataWriterFactory extends DataWriterFactory[Row] {
  /**
    * Create a DataWriter; one RestDataWriter instance is created per partition
    *
    * @param partitionId   partition ID
    * @param attemptNumber attempt number
    * @return DataWriter
    */
  override def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[Row] = new RestDataWriter(partitionId, attemptNumber)
}
/**
  * RestDataWriter
  *
  * @param partitionId   partition ID
  * @param attemptNumber attempt number
  */
class RestDataWriter(partitionId: Int, attemptNumber: Int) extends DataWriter[Row] {
  /**
    * write: invoked once for every record
    *
    * @param record a single record
    */
  override def write(record: Row): Unit = {
    println(record)
  }
  /**
    * commit: invoked once per partition
    *
    * @return commit message
    */
  override def commit(): WriterCommitMessage = {
    RestWriterCommitMessage(partitionId, attemptNumber)
  }
  /**
    * Rollback: invoked when write throws an exception
    */
  override def abort(): Unit = {
    println("abort 方法被出发了")
  }
}
case class RestWriterCommitMessage(partitionId: Int, attemptNumber: Int) extends WriterCommitMessage
object RestDataSource {
  def requestData(url: String, params: String, xPath: String): List[AnyRef] = {
    import scala.collection.JavaConverters._
    val response = Request.Post(url).bodyString(params, ContentType.APPLICATION_JSON).execute()
    JSONPath.read(response.returnContent().asString(), xPath)
      .asInstanceOf[JSONArray].asScala.toList
  }
}
object RestDataSourceTest {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .master("local[2]")
      .appName(this.getClass.getSimpleName)
      .getOrCreate()
    val df = spark.read
      .format("com.hollysys.spark.sql.datasource.rest.RestDataSource")
      .option("url", "http://model-opcua-hollysysdigital-test.hiacloud.net.cn/aggquery/query/queryPointHistoryData")
      .option("params", "{\n    \"startTime\": \"1543887720000\",\n    \"endTime\": \"1543891320000\",\n    \"maxSizePerNode\": 1000,\n    \"nodes\": [\n        {\n            \"uri\": \"/SymLink-10000012030100000-device/5c174da007a54e0001035ddd\"\n        }\n    ]\n}")
      .option("xPath", "$.result.historyData")
      //`response` ARRAY<STRUCT<`historyData`:ARRAY<STRUCT<`s`:INT,`t`:LONG,`v`:FLOAT>>>>
      .option("schema", "`s` INT,`t` LONG,`v` DOUBLE")
      .load()
    df.printSchema()
    df.show(false)
  }
}
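
Note that the test above passes the fully qualified class name to format(). If you also want a short name such as "rest" (an assumed name here, not part of the original example), one option is to mix in DataSourceRegister and register the class through Java's ServiceLoader; a hedged sketch:

import org.apache.spark.sql.sources.DataSourceRegister
import org.apache.spark.sql.sources.v2.reader.DataSourceReader
import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport}

// Same reader wiring as above (reuses RestDataSourceReader from this package),
// plus DataSourceRegister so spark.read.format("rest") can resolve the source.
class RestDataSourceWithShortName extends DataSourceV2 with ReadSupport with DataSourceRegister {
  override def shortName(): String = "rest"
  override def createReader(options: DataSourceOptions): DataSourceReader =
    new RestDataSourceReader(
      options.get("url").get(),
      options.get("params").get(),
      options.get("xPath").get(),
      options.get("schema").get())
}
// Plus a ServiceLoader registration file on the classpath:
//   src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
// containing one line with the implementation class name:
//   com.hollysys.spark.sql.datasource.rest.RestDataSourceWithShortName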

Spark SQL custom HBaseSource

import java.util
import java.util.Optional
import com.travel.utils.HbaseTools
import org.apache.hadoop.hbase.TableName
import org.apache.hadoop.hbase.client._
import org.apache.hadoop.hbase.util.Bytes
import org.apache.spark.sql.sources.v2.reader._
import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriter, DataWriterFactory, WriterCommitMessage}
import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport, WriteSupport}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession}
object HBaseSourceAndSink {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .master("local[2]")
      .getOrCreate()
    val df = spark.read
      .format("com.travel.programApp.HBaseSource")
      .option("hbase.table.name", "spark_hbase_sql")
      .option("schema", "`name` STRING,`score` STRING")
      .option("cf.cc","cf:name,cf:score")
      .load()
    df.explain(true)
    df.createOrReplaceTempView("sparkHBaseSQL")
    df.printSchema()
    val frame: DataFrame = spark.sql("select * from sparkHBaseSQL where score > 60")
    frame.write.format("com.travel.programApp.HBaseSource")
      .mode(SaveMode.Overwrite)
      .option("hbase.table.name","spark_hbase_write")
      .save()
  }
}
class HBaseSource extends DataSourceV2 with ReadSupport with WriteSupport{
  override def createReader(options: DataSourceOptions): DataSourceReader = {
    new HBaseDataSourceReader(options.get("hbase.table.name").get(),options.get("schema").get(),options.get("cf.cc").get())
  }
  override def createWriter(jobId: String, schema: StructType, mode: SaveMode, options: DataSourceOptions): Optional[DataSourceWriter] = {
    Optional.of(new HBaseDataSourceWrite)
  }
}
class HBaseDataSourceWrite extends DataSourceWriter{
  override def createWriterFactory(): DataWriterFactory[Row] = {
    new HBaseDataWriterFactory
  }
  override def commit(messages: Array[WriterCommitMessage]): Unit = {
  }
  override def abort(messages: Array[WriterCommitMessage]): Unit = {
  }
}
class HBaseDataWriterFactory extends DataWriterFactory[Row]{
  override def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[Row] = {
    new HBaseDataWriter
  }
}
class HBaseDataWriter extends DataWriter[Row]{
  private val conn: Connection = HbaseTools.getHbaseConn
  private val table: Table = conn.getTable(TableName.valueOf("spark_hbase_write"))
  override def write(record: Row): Unit = {
    val name: String = record.getString(0)
    val score: String = record.getString(1)
    val put = new Put("0001".getBytes())
    put.addColumn("cf".getBytes(),"name".getBytes(),name.getBytes())
    put.addColumn("cf".getBytes(),"score".getBytes(),score.getBytes())
    table.put(put)
  }
  override def commit(): WriterCommitMessage = {
    table.close()
    conn.close()
    null
  }
  override def abort(): Unit = {
    null
  }
}
class HBaseDataSourceReader(tableName:String,schema:String,cfcc:String) extends DataSourceReader  {
  // Define the HBase schema
  private val structType: StructType = StructType.fromDDL(schema)
  override def readSchema(): StructType = {
    structType
  }
  // Return the DataReaderFactory list
  override def createDataReaderFactories(): util.List[DataReaderFactory[Row]] = {
    import collection.JavaConverters._
    Seq(
    new HBaseReaderFactory(tableName,cfcc).asInstanceOf[DataReaderFactory[Row]]
    ).asJava
  }
}
class HBaseReaderFactory(tableName:String,cfcc:String) extends  DataReaderFactory[Row] {
  override def createDataReader(): DataReader[Row] = {
    new HBaseReader(tableName,cfcc)
  }
}
class HBaseReader(tableName:String,cfcc:String) extends DataReader[Row] {
  private var hbaseConnection:Connection = null
  private var  resultScanner:ResultScanner = null
  private var nextResult:Result  = null
  // Fetch the data from HBase
  val data: Iterator[Seq[AnyRef]] = getIterator
  def getIterator: Iterator[Seq[AnyRef]] = {
    import scala.collection.JavaConverters._
    hbaseConnection = HbaseTools.getHbaseConn
    val table: Table = hbaseConnection.getTable(TableName.valueOf(tableName))
    resultScanner = table.getScanner(new Scan())
    val iterator: Iterator[Seq[AnyRef]] = resultScanner.iterator().asScala.map(eachResult => {
      val str: String = Bytes.toString(eachResult.getValue("cf".getBytes(), "name".getBytes()))
      val score: String = Bytes.toString(eachResult.getValue("cf".getBytes(), "score".getBytes()))
      Seq(str, score)
    })
    iterator
  }
  override def next(): Boolean = {
    data.hasNext
  }
  override def get(): Row = {
    val seq: Seq[Any] = data.next()
    Row.fromSeq(seq)
  }
  override def close(): Unit = {
    hbaseConnection.close()
  }
}
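
The HBase listing depends on com.travel.utils.HbaseTools, which is not shown. A minimal stand-in built on the standard HBase client API could look like the sketch below; the ZooKeeper quorum and port are placeholders for your own cluster:

import org.apache.hadoop.hbase.HBaseConfiguration
import org.apache.hadoop.hbase.client.{Connection, ConnectionFactory}

// Hypothetical replacement for the project's HbaseTools helper: builds an HBase
// Connection from the ZooKeeper quorum (values below are placeholders).
object HbaseTools {
  def getHbaseConn: Connection = {
    val conf = HBaseConfiguration.create()
    conf.set("hbase.zookeeper.quorum", "node01,node02,node03")
    conf.set("hbase.zookeeper.property.clientPort", "2181")
    ConnectionFactory.createConnection(conf)
  }
}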

Spark SQL: extending DataSourceV2 for ClickHouse

package com.mengyao.spark.datasourcev2.ext.example1
import java.io.Serializable
import java.sql.{Connection, Date, PreparedStatement, ResultSet, SQLException, Statement}
import java.text.SimpleDateFormat
import java.util
import java.util.Optional
import cn.itcast.logistics.etl.Configure
import org.apache.commons.lang3.{StringUtils, SystemUtils}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.{SaveMode, SparkSession}
import org.apache.spark.sql.sources.{DataSourceRegister, EqualTo, Filter}
import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, InputPartition, InputPartitionReader, SupportsPushDownFilters, SupportsPushDownRequiredColumns}
import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriter, DataWriterFactory, WriterCommitMessage}
import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport, StreamWriteSupport, WriteSupport}
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.{StructType, _}
import org.apache.spark.unsafe.types.UTF8String
import org.javatuples.Triplet
import ru.yandex.clickhouse.domain.ClickHouseDataType
import ru.yandex.clickhouse.response.{ClickHouseResultSet, ClickHouseResultSetMetaData}
import ru.yandex.clickhouse.settings.ClickHouseProperties
import ru.yandex.clickhouse.{ClickHouseConnection, ClickHouseDataSource, ClickHouseStatement}
import scala.collection.mutable.ArrayBuffer
/**
 * @ClassName CKTest
 * @Description Test of the ClickHouse DataSourceV2 implementation
 * @Created by MengYao
 * @Date 2020/5/17 16:34
 * @Version V1.0
 */
object CKTest {
  private val APP_NAME: String = CKTest.getClass.getSimpleName
  private val master: String = "local[2]"
  def main(args: Array[String]) {
    if (SystemUtils.IS_OS_WINDOWS) System.setProperty("hadoop.home.dir", Configure.LOCAL_HADOOP_HOME)
    val spark = SparkSession.builder()
      .master(master)
      .appName(APP_NAME).getOrCreate();
    val df = spark.read.format(Configure.SPARK_CLICKHOUSE_FORMAT)
      .option("driver", Configure.clickhouseDriver)
      .option("url", Configure.clickhouseUrl)
      .option("user", Configure.clickhouseUser)
      .option("password", Configure.clickhousePassword)
      .option("table", "tbl_address")
      .option("use_server_time_zone", "false")
      .option("use_time_zone", "Asia/Shanghai")
      .option("max_memory_usage", "2000000000")
      .option("max_bytes_before_external_group_by", "1000000000")
      .load().coalesce(1)
    df.show(1000, false)
    import spark.implicits._
    df.where($"id"===328).distinct().coalesce(1).write.format(Configure.SPARK_CLICKHOUSE_FORMAT)
      .option("driver", Configure.clickhouseDriver)
      .option("url", Configure.clickhouseUrl)
      .option("user", Configure.clickhouseUser)
      .option("password", Configure.clickhousePassword)
      .option("table", "tbl_address")
      .option("use_server_time_zone", "false")
      .option("use_time_zone", "Asia/Shanghai")
      .option("max_memory_usage", "2000000000")
      .option("max_bytes_before_external_group_by", "1000000000")
      .mode(SaveMode.Append)
      .save();
  }
}
/**
 * @ClassName ClickHouseDataSourceV2
 * @Description ClickHouse data source implementation extending Spark SQL DataSourceV2
 * @Created by MengYao
 * @Date 2020/5/17 16:34
 * @Version V1.0
 */
class ClickHouseDataSourceV2 extends DataSourceV2 with DataSourceRegister with ReadSupport with WriteSupport with StreamWriteSupport {
  /** Declares the short name of the ClickHouse data source, used as spark.read.format("clickhouse")... */
  override def shortName(): String = "clickhouse"
  /** Data reading in batch mode */
  override def createReader(options: DataSourceOptions): DataSourceReader = new CKReader(new CKOptions(options.asMap()))
  /** Data writing in batch mode */
  override def createWriter(writeUUID: String, schema: StructType, mode: SaveMode, options: DataSourceOptions): Optional[DataSourceWriter] = Optional.of(new CKWriter(writeUUID, schema, mode, null, new CKOptions(options.asMap())))
  /** Data writing in streaming mode */
  override def createStreamWriter(queryId: String, schema: StructType, mode: OutputMode, options: DataSourceOptions): StreamWriter = new CKWriter(queryId, schema, null, mode, new CKOptions(options.asMap()))
}
/**
 * @ClassName CKReader
 * @Description Batch-mode ClickHouse data reading (implemented with a single partition here)
 * @Created by MengYao
 * @Date 2020/5/17 16:34
 * @Version V1.0
 */
class CKReader(options: CKOptions) extends DataSourceReader {
  //with SupportsPushDownRequiredColumns with SupportsPushDownFilters {
  private val customSchema: java.lang.String = options.getCustomSchema
  private val helper = new CKHelper(options)
  import collection.JavaConversions._
  private val schema = if(StringUtils.isEmpty(customSchema)) {
    helper.getSparkTableSchema()
  } else {
    helper.getSparkTableSchema(new util.LinkedList[String](asJavaCollection(customSchema.split(","))))
  }
  override def readSchema(): StructType = schema
  override def planInputPartitions(): util.List[InputPartition[InternalRow]] = util.Arrays.asList(new CKInputPartition(schema, options))
}
/**
 * @ClassName CKInputPartition
 * @Description Batch-mode ClickHouse partition implementation
 * @Created by MengYao
 * @Date 2020/5/17 16:34
 * @Version V1.0
 */
class CKInputPartition(schema: StructType, options: CKOptions) extends InputPartition[InternalRow] {
  override def createPartitionReader(): InputPartitionReader[InternalRow] = new CKInputPartitionReader(schema, options)
}
/**
 * @ClassName CKInputPartitionReader
 * @Description Batch-mode reading of data from a ClickHouse partition
 * @Created by MengYao
 * @Date 2020/5/17 16:34
 * @Version V1.0
 */
class CKInputPartitionReader(schema: StructType, options: CKOptions) extends InputPartitionReader[InternalRow] with Logging with Serializable{
  val helper = new CKHelper(options)
  var connection: ClickHouseConnection = null
  var st: ClickHouseStatement = null
  var rs: ResultSet = null
  override def next(): Boolean = {
    // (Re)initialize the connection, statement and result set if any of them is missing or closed
    if (null == connection || connection.isClosed || null == st || st.isClosed || null == rs || rs.isClosed) {
      connection = helper.getConnection
      st = connection.createStatement()
      rs = st.executeQuery(helper.getSelectStatement(schema))
      println(/**logInfo**/s"初始化ClickHouse连接.")
    }
    if(null != rs && !rs.isClosed) rs.next() else false
  }
  override def get(): InternalRow = {
    val fields = schema.fields
    val length = fields.length
    val record = new Array[Any](length)
    for (i <- 0 until length) {
      val field = fields(i)
      val name = field.name
      val dataType = field.dataType
      try {
        dataType match {
          case DataTypes.BooleanType => record(i) = rs.getBoolean(name)
          case DataTypes.DateType => record(i) = DateTimeUtils.fromJavaDate(rs.getDate(name))
          case DataTypes.DoubleType => record(i) = rs.getDouble(name)
          case DataTypes.FloatType => record(i) = rs.getFloat(name)
          case DataTypes.IntegerType => record(i) = rs.getInt(name)
          case DataTypes.LongType => record(i) = rs.getLong(name)
          case DataTypes.ShortType => record(i) = rs.getShort(name)
          case DataTypes.StringType => record(i) = UTF8String.fromString(rs.getString(name))
          case DataTypes.TimestampType => record(i) = DateTimeUtils.fromJavaTimestamp(rs.getTimestamp(name))
          case DataTypes.BinaryType => record(i) = rs.getBytes(name)
          case DataTypes.NullType => record(i) = StringUtils.EMPTY
        }
      } catch {
        case e: SQLException => logError(e.getStackTrace.mkString("", scala.util.Properties.lineSeparator, scala.util.Properties.lineSeparator))
      }
    }
    new GenericInternalRow(record)
  }
  override def close(): Unit = {helper.closeAll(connection, st, null, rs)}
}
/**
 * @ClassName CKWriter
 * @Description Data write implementation supporting both batch and streaming
 * @Created by MengYao
 * @Date 2020/5/17 16:34
 * @Version V1.0
 */
class CKWriter(writeUuidOrQueryId: String, schema: StructType, batchMode: SaveMode, streamMode: OutputMode, options: CKOptions) extends StreamWriter {
  private val isStreamMode:Boolean = if (null!=streamMode&&null==batchMode) true else false
  override def useCommitCoordinator(): Boolean = true
  override def onDataWriterCommit(message: WriterCommitMessage): Unit = {}
  override def createWriterFactory(): DataWriterFactory[InternalRow] = new CKDataWriterFactory(writeUuidOrQueryId, schema, batchMode, streamMode, options)
  /** Batch writer commit */
  override def commit(messages: Array[WriterCommitMessage]): Unit = {}
  /** Batch writer abort */
  override def abort(messages: Array[WriterCommitMessage]): Unit = {}
  /** Streaming writer commit */
  override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
  /** Streaming writer abort */
  override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
}
/**
 * @ClassName CKDataWriterFactory
 * @Description Writer factory used to instantiate CKDataWriter
 * @Created by MengYao
 * @Date 2020/5/17 16:34
 * @Version V1.0
 */
class CKDataWriterFactory(writeUUID: String, schema: StructType, batchMode: SaveMode, streamMode: OutputMode, options: CKOptions) extends DataWriterFactory[InternalRow] {
  override def createDataWriter(partitionId: Int, taskId: Long, epochId: Long): DataWriter[InternalRow] = new CKDataWriter(writeUUID, schema, batchMode, streamMode, options)
}
/**
 * @ClassName CKDataWriter
 * @Description Data write implementation for ClickHouse
 * @Created by MengYao
 * @Date 2020/5/17 16:34
 * @Version V1.0
 */
class CKDataWriter(writeUUID: String, schema: StructType, batchMode: SaveMode, streamMode: OutputMode, options: CKOptions) extends DataWriter[InternalRow] with Logging with Serializable {
  val helper = new CKHelper(options)
  val opType = options.getOpTypeField
  private val sqls = ArrayBuffer[String]()
  private val autoCreateTable: Boolean = options.autoCreateTable
  private val init = if (autoCreateTable) {
    val createSQL = helper.createTable(options.getFullTable, schema)
    println(/**logInfo**/s"==== 初始化表SQL:$createSQL")
    helper.executeUpdate(createSQL)
  }
  val fields = schema.fields
  override def commit(): WriterCommitMessage = {
    helper.executeUpdateBatch(sqls)
    val batchSQL = sqls.mkString("\n")
    // logDebug(batchSQL)
    println(batchSQL)
    new WriterCommitMessage{override def toString: String = s"Batch insert SQL: $batchSQL"}
  }
  override def write(record: InternalRow): Unit = {
    if(StringUtils.isEmpty(opType)) {
      throw new RuntimeException("未传入opTypeField字段名称,无法确定数据持久化类型!")
    }
    var sqlStr: String = helper.getStatement(options.getFullTable, schema, record)
    logDebug(s"==== 拼接完成的INSERT SQL语句为:$sqlStr")
    try {
      if (StringUtils.isEmpty(sqlStr)) {
        val msg = "==== 拼接INSERT SQL语句失败,因为该语句为NULL或EMPTY!"
        logError(msg)
        throw new RuntimeException(msg)
      }
      Thread.sleep(options.getInterval())
      // Streaming mode
      if (null == batchMode) {
        if (streamMode == OutputMode.Append) {
          sqls += sqlStr
          // val state = helper.executeUpdate(sqlStr)
          // println(s"==== 在OutputMode.Append模式下执行:$sqlStr\n状态:$state")
        }
        else if(streamMode == OutputMode.Complete) {logError("==== 未实现OutputMode.Complete模式下的写入操作,请在CKDataWriter.write方法中添加相关实现!")}
        else if(streamMode == OutputMode.Update) {logError("==== 未实现OutputMode.Update模式下的写入操作,请在CKDataWriter.write方法中添加相关实现!")}
        else {logError(s"==== 未知模式下的写入操作,请在CKDataWriter.write方法中添加相关实现!")}
      // Batch mode
      } else {
        if (batchMode == SaveMode.Append) {
          sqls += sqlStr
          //val state = helper.executeUpdate(sqlStr)
          //println(s"==== 在SaveMode.Append模式下执行:$sqlStr\n状态:$state")
        }
        else if(batchMode == SaveMode.Overwrite) {logError("==== 未实现SaveMode.Overwrite模式下的写入操作,请在CKDataWriter.write方法中添加相关实现!")}
        else if(batchMode == SaveMode.ErrorIfExists) {logError("==== 未实现SaveMode.ErrorIfExists模式下的写入操作,请在CKDataWriter.write方法中添加相关实现!")}
        else if(batchMode == SaveMode.Ignore) {logError("==== 未实现SaveMode.Ignore模式下的写入操作,请在CKDataWriter.write方法中添加相关实现!")}
        else {logError(s"==== 未知模式下的写入操作,请在CKDataWriter.write方法中添加相关实现!")}
      }
    } catch {
      case e: Exception => logError(e.getMessage)
    }
  }
  override def abort(): Unit = {}
}
/**
 * @ClassName CKOptions
 * @Description Extracts ClickHouse-specific parameters from Spark SQL's DataSourceOptions (the spark.[read/write].options parameters)
 * @Created by MengYao
 * @Date 2020/5/17 16:34
 * @Version V1.0
 */
class CKOptions(var originalMap: util.Map[String, String]) extends Logging with Serializable {
  val DRIVER_KEY: String = "driver"
  val URL_KEY: String = "url"
  val USER_KEY: String = "user"
  val PASSWORD_KEY: String = "password"
  val DATABASE_KEY: String = "database"
  val TABLE_KEY: String = "table"
  val AUTO_CREATE_TABLE = "autoCreateTable".toLowerCase
  val PATH_KEY = "path"
  val INTERVAL = "interval"
  val CUSTOM_SCHEMA_KEY: String = "customSchema".toLowerCase
  val WHERE_KEY: String = "where"
  val OP_TYPE_FIELD = "opTypeField".toLowerCase
  val PRIMARY_KEY = "primaryKey".toLowerCase
  def getValue[T](key: String, `type`: T): T = (if (originalMap.containsKey(key)) originalMap.get(key) else null).asInstanceOf[T]
  def getDriver: String = getValue(DRIVER_KEY, new String)
  def getURL: String = getValue(URL_KEY, new String)
  def getUser: String = getValue(USER_KEY, new String)
  def getPassword: String = getValue(PASSWORD_KEY, new String)
  def getDatabase: String = getValue(DATABASE_KEY, new String)
  def getTable: String = getValue(TABLE_KEY, new String)
  def autoCreateTable: Boolean = {
    originalMap.getOrDefault(AUTO_CREATE_TABLE, "false").toLowerCase match {
        case "true" => true
        case "false" => false
        case _ => false
      }
    }
  def getInterval(): Long = {originalMap.getOrDefault(INTERVAL, "200").toLong}
  def getPath: String = if(StringUtils.isEmpty(getValue(PATH_KEY, new String))) getTable else getValue(PATH_KEY, new String)
  def getWhere: String = getValue(WHERE_KEY, new String)
  def getCustomSchema: String = getValue(CUSTOM_SCHEMA_KEY, new String)
  def getOpTypeField: String = getValue(OP_TYPE_FIELD, new String)
  def getPrimaryKey: String = getValue(PRIMARY_KEY, new String)
  def getFullTable: String = {
    val database = getDatabase
    val table = getTable
    if (StringUtils.isEmpty(database) && !StringUtils.isEmpty(table)) table else if (!StringUtils.isEmpty(database) && !StringUtils.isEmpty(table)) database+"."+table else table
  }
  def asMap(): util.Map[String, String] = this.originalMap
  override def toString: String = originalMap.toString
}
/**
 * @ClassName CKHelper
 * @Description JDBC helper implementation for ClickHouse
 * @Created by MengYao
 * @Date 2020/5/17 16:34
 * @Version V1.0
 */
class CKHelper(options: CKOptions) extends Logging with Serializable {
  private val opType: String = options.getOpTypeField
  private val id: String = options.getPrimaryKey
  private var connection: ClickHouseConnection = getConnection
  def getConnection: ClickHouseConnection = {
    val url = options.getURL
    val ds = new ClickHouseDataSource(url, new ClickHouseProperties())
    ds.getConnection(options.getUser, options.getPassword)
  }
  def createTable(table: String, schema: StructType): String = {
    val cols = ArrayBuffer[String]()
    for (field <- schema.fields) {
      val dataType = field.dataType
      val ckColName = field.name
      if (ckColName!=opType) {
        var ckColType = getClickhouseSqlType(dataType)
        if (!StringUtils.isEmpty(ckColType)) {
          if (ckColType.toLowerCase=="string") {ckColType="Nullable("+ckColType+")"}
        }
        cols += ckColName+" "+ ckColType
      }
    }
    s"CREATE TABLE IF NOT EXISTS $table(${cols.mkString(",")},sign Int8,version UInt64) ENGINE=VersionedCollapsingMergeTree(sign, version) ORDER BY $id"
  }
  def getSparkTableSchema(customFields: util.LinkedList[String] = null): StructType = {
    import collection.JavaConversions._
    val list: util.LinkedList[Triplet[String, String, String]] = getCKTableSchema(customFields)
    var fields = ArrayBuffer[StructField]()
    for(trp <- list) {
      fields += StructField(trp.getValue0, getSparkSqlType(trp.getValue1))
    }
    StructType(fields)
  }
  private def getFieldValue(fieldName: String, schema: StructType, data:InternalRow): Any = {
    var flag = true
    var fieldValue:String = null
    val fields = schema.fields
    for(i <- 0 until fields.length if flag) {
      val field = fields(i)
      if(fieldName==field.name) {
        fieldValue = field.dataType match {
          case DataTypes.BooleanType => if (data.isNullAt(i)) "NULL" else s"${data.getBoolean(i)}"
          case DataTypes.DoubleType => if (data.isNullAt(i)) "NULL" else s"${data.getDouble(i)}"
          case DataTypes.FloatType => if (data.isNullAt(i)) "NULL" else s"${data.getFloat(i)}"
          case DataTypes.IntegerType => if (data.isNullAt(i)) "NULL" else s"${data.getInt(i)}"
          case DataTypes.LongType => if (data.isNullAt(i)) "NULL" else s"${data.getLong(i)}"
          case DataTypes.ShortType => if (data.isNullAt(i)) "NULL" else s"${data.getShort(i)}"
          case DataTypes.StringType => if (data.isNullAt(i)) "NULL" else s"${data.getUTF8String(i).toString.trim}"
          case DataTypes.DateType => if (data.isNullAt(i)) "NULL" else s"'${new SimpleDateFormat("yyyy-MM-dd").format(new Date(data.get(i, DateType).asInstanceOf[Date].getTime / 1000))}'"
          case DataTypes.TimestampType => if (data.isNullAt(i)) "NULL" else s"${new SimpleDateFormat("yyyy-MM-dd HH:mm:ss").format(new Date(data.getLong(i) / 1000))}"
          case DataTypes.BinaryType => if (data.isNullAt(i)) "NULL" else s"${data.getBinary(i)}"
          case DataTypes.NullType => "NULL"
        }
        flag = false
      }
    }
    fieldValue
  }
  def getStatement(table: String, schema: StructType, record: InternalRow): String = {
    val opTypeValue: String = getFieldValue(opType, schema, record).toString
    if (opTypeValue.toLowerCase()=="insert") {getInsertStatement(table, schema, record)}
    else if (opTypeValue.toLowerCase()=="delete") {getUpdateStatement(table, schema, record)}
    else if (opTypeValue.toLowerCase()=="update") {getDeleteStatement(table, schema, record)}
    else {""}
  }
  def getSelectStatement(schema: StructType):String = {
    s"SELECT ${schema.fieldNames.mkString(",")} FROM ${options.getFullTable}"
  }
  def getInsertStatement(table:String, schema: StructType, data:InternalRow):String = {
    val fields = schema.fields
    val names = ArrayBuffer[String]()
    val values = ArrayBuffer[String]()
    // If the DataFrame's fields match the table's fields, build the SQL with the full field list
    if (data.numFields==fields.length) {
    } else {// Otherwise only the fields present in the DataFrame should be appended to the SQL (not implemented here)
    }
    for(i <- 0 until fields.length) {
      val field = fields(i)
      val fieldType = field.dataType
      val fieldName = field.name
      if (fieldName!=opType) {
        val fieldValue = fieldType match {
          case DataTypes.BooleanType => if(data.isNullAt(i)) "NULL" else s"${data.getBoolean(i)}"
          case DataTypes.DoubleType => if(data.isNullAt(i)) "NULL" else s"${data.getDouble(i)}"
          case DataTypes.FloatType => if(data.isNullAt(i)) "NULL" else s"${data.getFloat(i)}"
          case DataTypes.IntegerType => if(data.isNullAt(i)) "NULL" else s"${data.getInt(i)}"
          case DataTypes.LongType => if(data.isNullAt(i)) "NULL" else s"${data.getLong(i)}"
          case DataTypes.ShortType => if(data.isNullAt(i)) "NULL" else s"${data.getShort(i)}"
          case DataTypes.StringType => if(data.isNullAt(i)) "NULL" else s"'${data.getUTF8String(i).toString.trim}'"
          case DataTypes.DateType => if(data.isNullAt(i)) "NULL" else s"'${new SimpleDateFormat("yyyy-MM-dd").format(new Date(data.get(i, DateType).asInstanceOf[Date].getTime/1000))}'"
          case DataTypes.TimestampType => if(data.isNullAt(i)) "NULL" else s"'${new SimpleDateFormat("yyyy-MM-dd HH:mm:ss").format(new Date(data.getLong(i)/1000))}'"
          case DataTypes.BinaryType => if(data.isNullAt(i)) "NULL" else s"${data.getBinary(i)}"
          case DataTypes.NullType => "NULL"
        }
        names += fieldName
        values += fieldValue
      }
    }
    if (names.length > 0 && values.length > 0) {
      names += ("sign","version")
      values += ("1", System.currentTimeMillis().toString)
    }
    s"INSERT INTO $table(${names.mkString(",")}) VALUES(${values.mkString(",")})"
  }
  def getDeleteStatement(table:String, schema: StructType, data:InternalRow):String = {
    val fields = schema.fields
    val primaryKeyFields = if(options.getPrimaryKey.isEmpty) {fields.filter(field => field.name=="id")} else {fields.filter(field => field.name==options.getPrimaryKey)}
    if (primaryKeyFields.length>0) {
      val primaryKeyField = primaryKeyFields(0)
      val primaryKeyValue = getFieldValue(primaryKeyField.name, schema, data)
      s"ALTER TABLE $table DELETE WHERE ${primaryKeyField.name} = $primaryKeyValue"
    } else {
      logError("==== 找不到主键,无法生成删除SQL!")
      ""
    }
  }
  def getUpdateStatement(table:String, schema: StructType, data:InternalRow):String = {
    val fields = schema.fields
    val primaryKeyFields = if(options.getPrimaryKey.isEmpty) {fields.filter(field => field.name=="id")} else {fields.filter(field => field.name==options.getPrimaryKey)}
    if (primaryKeyFields.length>0) {
      val primaryKeyField = primaryKeyFields(0)
      val primaryKeyValue = getFieldValue(primaryKeyField.name, schema, data)
      val noPrimaryKeyFields = fields.filter(field=>field.name!=primaryKeyField.name)
      var sets = ArrayBuffer[String]()
      for(i <- 0 until noPrimaryKeyFields.length) {
        val noPrimaryKeyField = noPrimaryKeyFields(i)
        val set = noPrimaryKeyField.name+"="+getFieldValue(noPrimaryKeyField.name, schema, data).toString
        sets += set
      }
      sets.remove(sets.length-1)
      s"ALTER TABLE $table UPDATE ${sets.mkString(" AND ")} WHERE ${primaryKeyField.name}=$primaryKeyValue"
    } else {
      logError("==== 找不到主键,无法生成修改SQL!")
      ""
    }
  }
  def getCKTableSchema(customFields: util.LinkedList[String] = null): util.LinkedList[Triplet[String, String, String]] = {
    val fields = new util.LinkedList[Triplet[String, String, String]]
    var connection: ClickHouseConnection = null
    var st: ClickHouseStatement = null
    var rs: ClickHouseResultSet = null
    var metaData: ClickHouseResultSetMetaData = null
    try {
      connection = getConnection
      st = connection.createStatement
      val sql = s"SELECT * FROM ${options.getFullTable} WHERE 1=0"
      rs = st.executeQuery(sql).asInstanceOf[ClickHouseResultSet]
      metaData = rs.getMetaData.asInstanceOf[ClickHouseResultSetMetaData]
      val columnCount = metaData.getColumnCount
      for (i <- 1 to columnCount) {
        val columnName = metaData.getColumnName(i)
        val sqlTypeName = metaData.getColumnTypeName(i)
        val javaTypeName = ClickHouseDataType.fromTypeString(sqlTypeName).getJavaClass.getSimpleName
        if (null != customFields && customFields.size > 0) {
          if (customFields.contains(columnName)) fields.add(new Triplet(columnName, sqlTypeName, javaTypeName))
        } else {
          fields.add(new Triplet(columnName, sqlTypeName, javaTypeName))
        }
      }
    } catch {
      case e: Exception => e.printStackTrace()
    } finally {
      closeAll(connection, st, null, rs)
    }
    fields
  }
  def executeUpdateBatch(sqls: ArrayBuffer[String]): Unit = {
    // Assemble the batch SQL: INSERT ... VALUES ()()...
    val batchSQL = new StringBuilder()
    for(i <- 0 until sqls.length) {
      val line = sqls(i)
      var offset: Int = 0
      if (!StringUtils.isEmpty(line) && line.contains("VALUES")) {
        val offset = line.indexOf("VALUES")
        if(i==0) {
          val prefix = line.substring(0, offset+6)
          batchSQL.append(prefix)
        }
        val suffix = line.substring(offset+6)
        batchSQL.append(suffix)
      }
    }
    var st: ClickHouseStatement = null;
    try {
      if(null==connection||connection.isClosed) {connection = getConnection}
      st = connection.createStatement()
      st.executeUpdate(batchSQL.toString())
    } catch {
      case e: Exception => logError(s"执行异常:$sqls\n${e.getMessage}")
    } finally {
      //closeAll(connection, st)
    }
  }
  def executeUpdate(sql: String): Int = {
    var state = 0;
    var st: ClickHouseStatement = null;
    try {
      if(null==connection||connection.isClosed) {connection = getConnection}
      st = connection.createStatement()
      state = st.executeUpdate(sql)
    } catch {
      case e: Exception => logError(s"执行异常:$sql\n${e.getMessage}")
    } finally {
      //closeAll(connection, st)
    }
    state
  }
  def close(connection: Connection): Unit = closeAll(connection)
  def close(st: Statement): Unit = closeAll(null, st, null, null)
  def close(ps: PreparedStatement): Unit = closeAll(null, null, ps, null)
  def close(rs: ResultSet): Unit = closeAll(null, null, null, rs)
  def closeAll(connection: Connection=null, st: Statement=null, ps: PreparedStatement=null, rs: ResultSet=null): Unit = {
    try {
      if (rs != null && !rs.isClosed) rs.close()
      if (ps != null && !ps.isClosed) ps.close()
      if (st != null && !st.isClosed) st.close()
      if (connection != null && !connection.isClosed) connection.close()
    } catch {
      case e: Exception => e.printStackTrace()
    }
  }
  /**
   * IntervalYear      (Types.INTEGER,   Integer.class,    true,  19,  0),
   * IntervalQuarter   (Types.INTEGER,   Integer.class,    true,  19,  0),
   * IntervalMonth     (Types.INTEGER,   Integer.class,    true,  19,  0),
   * IntervalWeek      (Types.INTEGER,   Integer.class,    true,  19,  0),
   * IntervalDay       (Types.INTEGER,   Integer.class,    true,  19,  0),
   * IntervalHour      (Types.INTEGER,   Integer.class,    true,  19,  0),
   * IntervalMinute    (Types.INTEGER,   Integer.class,    true,  19,  0),
   * IntervalSecond    (Types.INTEGER,   Integer.class,    true,  19,  0),
   * UInt64            (Types.BIGINT,    BigInteger.class, false, 19,  0),
   * UInt32            (Types.INTEGER,   Long.class,       false, 10,  0),
   * UInt16            (Types.SMALLINT,  Integer.class,    false,  5,  0),
   * UInt8             (Types.TINYINT,   Integer.class,    false,  3,  0),
   * Int64             (Types.BIGINT,    Long.class,       true,  20,  0, "BIGINT"),
   * Int32             (Types.INTEGER,   Integer.class,    true,  11,  0, "INTEGER", "INT"),
   * Int16             (Types.SMALLINT,  Integer.class,    true,   6,  0, "SMALLINT"),
   * Int8              (Types.TINYINT,   Integer.class,    true,   4,  0, "TINYINT"),
   * Date              (Types.DATE,      Date.class,       false, 10,  0),
   * DateTime          (Types.TIMESTAMP, Timestamp.class,  false, 19,  0, "TIMESTAMP"),
   * Enum8             (Types.VARCHAR,   String.class,     false,  0,  0),
   * Enum16            (Types.VARCHAR,   String.class,     false,  0,  0),
   * Float32           (Types.FLOAT,     Float.class,      true,   8,  8, "FLOAT"),
   * Float64           (Types.DOUBLE,    Double.class,     true,  17, 17, "DOUBLE"),
   * Decimal32         (Types.DECIMAL,   BigDecimal.class, true,   9,  9),
   * Decimal64         (Types.DECIMAL,   BigDecimal.class, true,  18, 18),
   * Decimal128        (Types.DECIMAL,   BigDecimal.class, true,  38, 38),
   * Decimal           (Types.DECIMAL,   BigDecimal.class, true,   0,  0, "DEC"),
   * UUID              (Types.OTHER,     UUID.class,       false, 36,  0),
   * String            (Types.VARCHAR,   String.class,     false,  0,  0, "LONGBLOB", "MEDIUMBLOB", "TINYBLOB", "MEDIUMTEXT", "CHAR", "VARCHAR", "TEXT", "TINYTEXT", "LONGTEXT", "BLOB"),
   * FixedString       (Types.CHAR,      String.class,     false, -1,  0, "BINARY"),
   * Nothing           (Types.NULL,      Object.class,     false,  0,  0),
   * Nested            (Types.STRUCT,    String.class,     false,  0,  0),
   * Tuple             (Types.OTHER,     String.class,     false,  0,  0),
   * Array             (Types.ARRAY,     Array.class,      false,  0,  0),
   * AggregateFunction (Types.OTHER,     String.class,     false,  0,  0),
   * Unknown           (Types.OTHER,     String.class,     false,  0,  0);
   *
   * @param clickhouseDataType
   * @return
   */
  private def getSparkSqlType(clickhouseDataType: String) = clickhouseDataType match {
    case "IntervalYear" => DataTypes.IntegerType
    case "IntervalQuarter" => DataTypes.IntegerType
    case "IntervalMonth" => DataTypes.IntegerType
    case "IntervalWeek" => DataTypes.IntegerType
    case "IntervalDay" => DataTypes.IntegerType
    case "IntervalHour" => DataTypes.IntegerType
    case "IntervalMinute" => DataTypes.IntegerType
    case "IntervalSecond" => DataTypes.IntegerType
    case "UInt64" => DataTypes.LongType //DataTypes.IntegerType;
    case "UInt32" => DataTypes.LongType
    case "UInt16" => DataTypes.IntegerType
    case "UInt8" => DataTypes.IntegerType
    case "Int64" => DataTypes.LongType
    case "Int32" => DataTypes.IntegerType
    case "Int16" => DataTypes.IntegerType
    case "Int8" => DataTypes.IntegerType
    case "Date" => DataTypes.DateType
    case "DateTime" => DataTypes.TimestampType
    case "Enum8" => DataTypes.StringType
    case "Enum16" => DataTypes.StringType
    case "Float32" => DataTypes.FloatType
    case "Float64" => DataTypes.DoubleType
    case "Decimal32" => DataTypes.createDecimalType
    case "Decimal64" => DataTypes.createDecimalType
    case "Decimal128" => DataTypes.createDecimalType
    case "Decimal" => DataTypes.createDecimalType
    case "UUID" => DataTypes.StringType
    case "String" => DataTypes.StringType
    case "FixedString" => DataTypes.StringType
    case "Nothing" => DataTypes.NullType
    case "Nested" => DataTypes.StringType
    case "Tuple" => DataTypes.StringType
    case "Array" => DataTypes.StringType
    case "AggregateFunction" => DataTypes.StringType
    case "Unknown" => DataTypes.StringType
    case _ => DataTypes.NullType
  }
  private def getClickhouseSqlType(sparkDataType: DataType) = sparkDataType match {
    case  DataTypes.ByteType => "Int8"
    case  DataTypes.ShortType => "Int16"
    case  DataTypes.IntegerType => "Int32"
    case  DataTypes.FloatType => "Float32"
    case  DataTypes.DoubleType => "Float64"
    case  DataTypes.LongType => "Int64"
    case  DataTypes.DateType => "DateTime"
    case  DataTypes.TimestampType => "DateTime"
    case  DataTypes.StringType => "String"
    case  DataTypes.NullType => "String"
  }
}
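
The ClickHouse listing references cn.itcast.logistics.etl.Configure, which is not shown. A hypothetical stand-in is sketched below; every value is a placeholder for your own environment, and the format string could equally be the short name "clickhouse" once the class is registered via DataSourceRegister:

// Hypothetical stand-in for the Configure object used by CKTest; all values are placeholders.
object Configure {
  val LOCAL_HADOOP_HOME = "D:\\soft\\hadoop-2.7.7"
  // Fully qualified class name of the data source (or "clickhouse" if registered via ServiceLoader)
  val SPARK_CLICKHOUSE_FORMAT = "com.mengyao.spark.datasourcev2.ext.example1.ClickHouseDataSourceV2"
  val clickhouseDriver = "ru.yandex.clickhouse.ClickHouseDriver"
  val clickhouseUrl = "jdbc:clickhouse://node01:8123/default"
  val clickhouseUser = "default"
  val clickhousePassword = ""
}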

Spark SQL: extending a Nebula source

Create the NebulaDataSource class

import java.util.Map.Entry
import java.util.Optional
import com.vesoft.nebula.connector.exception.IllegalOptionException
import com.vesoft.nebula.connector.reader.{NebulaDataSourceEdgeReader, NebulaDataSourceVertexReader}
import com.vesoft.nebula.connector.writer.{NebulaDataSourceEdgeWriter, NebulaDataSourceVertexWriter}
import org.apache.spark.sql.SaveMode
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.sources.DataSourceRegister
import org.apache.spark.sql.sources.v2.reader.DataSourceReader
import org.apache.spark.sql.sources.v2.writer.DataSourceWriter
import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport, WriteSupport}
import org.apache.spark.sql.types.StructType
import org.slf4j.LoggerFactory
import scala.collection.JavaConversions.iterableAsScalaIterable
class NebulaDataSource
    extends DataSourceV2
    with ReadSupport
    with WriteSupport
    with DataSourceRegister {
  private val LOG = LoggerFactory.getLogger(this.getClass)
  /**
    * The string that represents the format that nebula data source provider uses.
    */
  override def shortName(): String = "nebula"
  /**
    * Creates a {@link DataSourceReader} to scan the data from Nebula Graph.
    */
  override def createReader(options: DataSourceOptions): DataSourceReader = {
    val nebulaOptions = getNebulaOptions(options, OperaType.READ)
    val dataType      = nebulaOptions.dataType
    LOG.info("create reader")
    LOG.info(s"options ${options.asMap()}")
    if (DataTypeEnum.VERTEX == DataTypeEnum.withName(dataType)) {
      new NebulaDataSourceVertexReader(nebulaOptions)
    } else {
      new NebulaDataSourceEdgeReader(nebulaOptions)
    }
  }
  /**
    * Creates an optional {@link DataSourceWriter} to save the data to Nebula Graph.
    */
  override def createWriter(writeUUID: String,
                            schema: StructType,
                            mode: SaveMode,
                            options: DataSourceOptions): Optional[DataSourceWriter] = {
    val nebulaOptions = getNebulaOptions(options, OperaType.WRITE)
    val dataType      = nebulaOptions.dataType
    if (mode == SaveMode.Ignore || mode == SaveMode.ErrorIfExists) {
      LOG.warn(s"Currently do not support mode")
    }
    LOG.info("create writer")
    LOG.info(s"options ${options.asMap()}")
    if (DataTypeEnum.VERTEX == DataTypeEnum.withName(dataType)) {
      val vertexFiled = nebulaOptions.vertexField
      val vertexIndex: Int = {
        var index: Int = -1
        for (i <- schema.fields.indices) {
          if (schema.fields(i).name.equals(vertexFiled)) {
            index = i
          }
        }
        if (index < 0) {
          throw new IllegalOptionException(
            s" vertex field ${vertexFiled} does not exist in dataframe")
        }
        index
      }
      Optional.of(new NebulaDataSourceVertexWriter(nebulaOptions, vertexIndex, schema))
    } else {
      val srcVertexFiled = nebulaOptions.srcVertexField
      val dstVertexField = nebulaOptions.dstVertexField
      val rankExist      = !nebulaOptions.rankField.isEmpty
      val edgeFieldsIndex = {
        var srcIndex: Int  = -1
        var dstIndex: Int  = -1
        var rankIndex: Int = -1
        for (i <- schema.fields.indices) {
          if (schema.fields(i).name.equals(srcVertexFiled)) {
            srcIndex = i
          }
          if (schema.fields(i).name.equals(dstVertexField)) {
            dstIndex = i
          }
          if (rankExist) {
            if (schema.fields(i).name.equals(nebulaOptions.rankField)) {
              rankIndex = i
            }
          }
        }
        // check src filed and dst field
        if (srcIndex < 0 || dstIndex < 0) {
          throw new IllegalOptionException(
            s" srcVertex field ${srcVertexFiled} or dstVertex field ${dstVertexField} do not exist in dataframe")
        }
        // check rank field
        if (rankExist && rankIndex < 0) {
          throw new IllegalOptionException(s"rank field does not exist in dataframe")
        }
        if (!rankExist) {
          (srcIndex, dstIndex, Option.empty)
        } else {
          (srcIndex, dstIndex, Option(rankIndex))
        }
      }
      Optional.of(
        new NebulaDataSourceEdgeWriter(nebulaOptions,
                                       edgeFieldsIndex._1,
                                       edgeFieldsIndex._2,
                                       edgeFieldsIndex._3,
                                       schema))
    }
  }
  /**
    * construct nebula options with DataSourceOptions
    */
  def getNebulaOptions(options: DataSourceOptions, operateType: OperaType.Value): NebulaOptions = {
    var parameters: Map[String, String] = Map()
    for (entry: Entry[String, String] <- options.asMap().entrySet) {
      parameters += (entry.getKey -> entry.getValue)
    }
    val nebulaOptions = new NebulaOptions(CaseInsensitiveMap(parameters))(operateType)
    nebulaOptions
  }
}
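
NebulaDataSource uses OperaType and DataTypeEnum from the connector package, which are not reproduced in this excerpt. Conceptually they are simple enumerations, roughly as sketched below; the real definitions in com.vesoft.nebula.connector may differ:

// Rough stand-ins for the connector's enums referenced above.
object OperaType extends Enumeration {
  val READ, WRITE = Value
}
object DataTypeEnum extends Enumeration {
  val VERTEX = Value("vertex")
  val EDGE   = Value("edge")
}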

Create the NebulaSourceReader class

import java.util
import com.vesoft.nebula.connector.{DataTypeEnum, NebulaOptions, NebulaUtils}
import com.vesoft.nebula.connector.nebula.MetaProvider
import com.vesoft.nebula.meta.ColumnDef
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, InputPartition}
import org.apache.spark.sql.types.{DataTypes, StructField, StructType}
import org.slf4j.LoggerFactory
import scala.collection.mutable.ListBuffer
import scala.collection.JavaConverters._
/**
  * Base class of Nebula Source Reader
  */
abstract class NebulaSourceReader(nebulaOptions: NebulaOptions) extends DataSourceReader {
  private val LOG = LoggerFactory.getLogger(this.getClass)
  private var datasetSchema: StructType = _
  override def readSchema(): StructType = {
    datasetSchema = getSchema(nebulaOptions)
    LOG.info(s"dataset's schema: $datasetSchema")
    datasetSchema
  }
  protected def getSchema: StructType = getSchema(nebulaOptions)
  /**
    * Return the dataset's schema. The schema includes the columns configured in returnCols, or all properties in Nebula.
    */
  def getSchema(nebulaOptions: NebulaOptions): StructType = {
    val returnCols                      = nebulaOptions.getReturnCols
    val noColumn                        = nebulaOptions.noColumn
    val fields: ListBuffer[StructField] = new ListBuffer[StructField]
    val metaProvider                    = new MetaProvider(nebulaOptions.getMetaAddress)
    import scala.collection.JavaConverters._
    var schemaCols: Seq[ColumnDef] = Seq()
    val isVertex                   = DataTypeEnum.VERTEX.toString.equalsIgnoreCase(nebulaOptions.dataType)
    // construct vertex or edge default prop
    if (isVertex) {
      fields.append(DataTypes.createStructField("_vertexId", DataTypes.StringType, false))
    } else {
      fields.append(DataTypes.createStructField("_srcId", DataTypes.StringType, false))
      fields.append(DataTypes.createStructField("_dstId", DataTypes.StringType, false))
      fields.append(DataTypes.createStructField("_rank", DataTypes.LongType, false))
    }
    var dataSchema: StructType = null
    // read no column
    if (noColumn) {
      dataSchema = new StructType(fields.toArray)
      return dataSchema
    }
    // get tag schema or edge schema
    val schema = if (isVertex) {
      metaProvider.getTag(nebulaOptions.spaceName, nebulaOptions.label)
    } else {
      metaProvider.getEdge(nebulaOptions.spaceName, nebulaOptions.label)
    }
    schemaCols = schema.columns.asScala
    // read all columns
    if (returnCols.isEmpty) {
      schemaCols.foreach(columnDef => {
        LOG.info(s"prop name ${new String(columnDef.getName)}, type ${columnDef.getType.getType} ")
        fields.append(
          DataTypes.createStructField(new String(columnDef.getName),
                                      NebulaUtils.convertDataType(columnDef.getType),
                                      true))
      })
    } else {
      for (col: String <- returnCols) {
        fields.append(
          DataTypes
            .createStructField(col, NebulaUtils.getColDataType(schemaCols.toList, col), true))
      }
    }
    dataSchema = new StructType(fields.toArray)
    dataSchema
  }
}
/**
  * DataSourceReader for Nebula Vertex
  */
class NebulaDataSourceVertexReader(nebulaOptions: NebulaOptions)
    extends NebulaSourceReader(nebulaOptions) {
  override def planInputPartitions(): util.List[InputPartition[InternalRow]] = {
    val partitionNum = nebulaOptions.partitionNums.toInt
    val partitions = for (index <- 1 to partitionNum)
      yield {
        new NebulaVertexPartition(index, nebulaOptions, getSchema)
      }
    partitions.map(_.asInstanceOf[InputPartition[InternalRow]]).asJava
  }
}
/**
  * DataSourceReader for Nebula Edge
  */
class NebulaDataSourceEdgeReader(nebulaOptions: NebulaOptions)
    extends NebulaSourceReader(nebulaOptions) {
  override def planInputPartitions(): util.List[InputPartition[InternalRow]] = {
    val partitionNum = nebulaOptions.partitionNums.toInt
    val partitions = for (index <- 1 to partitionNum)
      yield new NebulaEdgePartition(index, nebulaOptions, getSchema)
    partitions.map(_.asInstanceOf[InputPartition[InternalRow]]).asJava
  }
}
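
NebulaVertexPartition and NebulaEdgePartition also live elsewhere in the connector. Their general shape under this API version is an InputPartition that builds an InputPartitionReader on the executor; the generic sketch below shows that shape only (names, data and scanning logic are placeholders, not the connector's real implementation):

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader}
import org.apache.spark.sql.types.StructType
import org.apache.spark.unsafe.types.UTF8String

// Placeholder partition: index identifies the slice of the scan this partition owns.
class ExamplePartition(index: Int, schema: StructType) extends InputPartition[InternalRow] {
  override def createPartitionReader(): InputPartitionReader[InternalRow] =
    new ExamplePartitionReader(index, schema)
}
// Placeholder reader: the real connector pulls rows from Nebula's storage service here.
class ExamplePartitionReader(index: Int, schema: StructType)
    extends InputPartitionReader[InternalRow] {
  private val rows = Iterator.tabulate(3)(i => s"vertex-$index-$i")
  override def next(): Boolean = rows.hasNext
  override def get(): InternalRow =
    new GenericInternalRow(Array[Any](UTF8String.fromString(rows.next())))
  override def close(): Unit = ()
}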

Create the NebulaSourceWriter class

import com.vesoft.nebula.connector.NebulaOptions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.v2.writer.{
  DataSourceWriter,
  DataWriter,
  DataWriterFactory,
  WriterCommitMessage
}
import org.apache.spark.sql.types.StructType
import org.slf4j.LoggerFactory
/**
  * creating and initializing the actual Nebula vertex writer at executor side
  */
class NebulaVertexWriterFactory(nebulaOptions: NebulaOptions, vertexIndex: Int, schema: StructType)
    extends DataWriterFactory[InternalRow] {
  override def createDataWriter(partitionId: Int,
                                taskId: Long,
                                epochId: Long): DataWriter[InternalRow] = {
    new NebulaVertexWriter(nebulaOptions, vertexIndex, schema)
  }
}
/**
  * creating and initializing the actual Nebula edge writer at executor side
  */
class NebulaEdgeWriterFactory(nebulaOptions: NebulaOptions,
                              srcIndex: Int,
                              dstIndex: Int,
                              rankIndex: Option[Int],
                              schema: StructType)
    extends DataWriterFactory[InternalRow] {
  override def createDataWriter(partitionId: Int,
                                taskId: Long,
                                epochId: Long): DataWriter[InternalRow] = {
    new NebulaEdgeWriter(nebulaOptions, srcIndex, dstIndex, rankIndex, schema)
  }
}
/**
  * nebula vertex writer to create factory
  */
class NebulaDataSourceVertexWriter(nebulaOptions: NebulaOptions,
                                   vertexIndex: Int,
                                   schema: StructType)
    extends DataSourceWriter {
  private val LOG = LoggerFactory.getLogger(this.getClass)
  override def createWriterFactory(): DataWriterFactory[InternalRow] = {
    new NebulaVertexWriterFactory(nebulaOptions, vertexIndex, schema)
  }
  override def commit(messages: Array[WriterCommitMessage]): Unit = {
    LOG.debug(s"${messages.length}")
    for (msg <- messages) {
      val nebulaMsg = msg.asInstanceOf[NebulaCommitMessage]
      LOG.info(s"failed execs:\n ${nebulaMsg.executeStatements.toString()}")
    }
  }
  override def abort(messages: Array[WriterCommitMessage]): Unit = {
    LOG.error("NebulaDataSourceVertexWriter abort")
  }
}
/**
  * nebula edge writer to create factory
  */
class NebulaDataSourceEdgeWriter(nebulaOptions: NebulaOptions,
                                 srcIndex: Int,
                                 dstIndex: Int,
                                 rankIndex: Option[Int],
                                 schema: StructType)
    extends DataSourceWriter {
  private val LOG = LoggerFactory.getLogger(this.getClass)
  override def createWriterFactory(): DataWriterFactory[InternalRow] = {
    new NebulaEdgeWriterFactory(nebulaOptions, srcIndex, dstIndex, rankIndex, schema)
  }
  override def commit(messages: Array[WriterCommitMessage]): Unit = {
    LOG.debug(s"${messages.length}")
    for (msg <- messages) {
      val nebulaMsg = msg.asInstanceOf[NebulaCommitMessage]
      LOG.info(s"failed execs:\n ${nebulaMsg.executeStatements.toString()}")
    }
  }
  override def abort(messages: Array[WriterCommitMessage]): Unit = {
    LOG.error("NebulaDataSourceEdgeWriter abort")
  }
}
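
The commit handlers above cast each message to NebulaCommitMessage, which is defined elsewhere in the connector. It is essentially a WriterCommitMessage carrying the statements that failed on the executor, roughly:

import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage

// Rough sketch of the commit message produced by the Nebula writers; the real
// class in the connector may differ in field name or collection type.
case class NebulaCommitMessage(executeStatements: List[String]) extends WriterCommitMessage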

Full code reference:

https://github.com/vesoft-inc/nebula-spark-utils/tree/master/nebula-spark-connector/