Implementing an input source with Spark SQL DataSourceV2
Implementing a data source with Spark SQL's DataSourceV2 is very much like writing a custom Structured Streaming source: the idea is the same, but the concrete APIs differ. The main steps are:
Step 1: Create an XXXDataSource class that extends DataSourceV2 with ReadSupport, and override ReadSupport's createReader method so that it returns your custom DataSourceReader (for example, an XXXDataSourceReader instance).
Step 2: Create an XXXDataSourceReader class that extends DataSourceReader. Override readSchema to return the schema of the data source, and override createDataReaderFactories to return one or more custom DataReaderFactory instances.
Step 3: Create a DataReader factory class such as XXXDataReaderFactory that extends DataReaderFactory, and override its createDataReader method to return a custom DataReader instance.
Step 4: Create a custom DataReader such as XXXDataReader that extends DataReader. Override next() to tell Spark whether another record is available (this is what drives calls to get()), override get() to return the current record, and override close() to release resources.
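Before the full REST example, here is a minimal read-path sketch wiring these four classes together. The SimpleXxx names, the hard-coded schema and the in-memory rows are illustrative assumptions, not a real source. Note that this sketch and the Rest/HBase examples below use the Spark 2.3.x DataSourceV2 API; the ClickHouse and Nebula examples further down use the Spark 2.4 variant, where createDataReaderFactories becomes planInputPartitions/InputPartition and the writer factory receives (partitionId, taskId, epochId).

import java.util

import org.apache.spark.sql.Row
import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, DataSourceReader}
import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport}
import org.apache.spark.sql.types.StructType

// Step 1: the entry point Spark instantiates for spark.read.format("...SimpleSource")
class SimpleSource extends DataSourceV2 with ReadSupport {
  override def createReader(options: DataSourceOptions): DataSourceReader = new SimpleSourceReader
}

// Step 2: exposes the schema and the per-partition reader factories
class SimpleSourceReader extends DataSourceReader {
  override def readSchema(): StructType = StructType.fromDDL("`id` INT, `name` STRING")

  override def createDataReaderFactories(): util.List[DataReaderFactory[Row]] = {
    import scala.collection.JavaConverters._
    Seq((new SimpleReaderFactory).asInstanceOf[DataReaderFactory[Row]]).asJava
  }
}

// Step 3: one factory per partition; created on the driver, serialized to executors
class SimpleReaderFactory extends DataReaderFactory[Row] {
  override def createDataReader(): DataReader[Row] = new SimpleReader
}

// Step 4: next() drives get(); close() releases resources
class SimpleReader extends DataReader[Row] {
  private val rows = Iterator(Row(1, "a"), Row(2, "b"))
  override def next(): Boolean = rows.hasNext
  override def get(): Row = rows.next()
  override def close(): Unit = ()
}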
Implementing an output source with DataSourceV2
Implementing a custom output source with DataSourceV2 takes the following steps:
Step 1: Create an XXXDataSource class that extends DataSourceV2 with WriteSupport, and override createWriter so that it returns your custom DataSourceWriter.
Step 2: Create an XXXDataSourceWriter class that extends DataSourceWriter. Override createWriterFactory to return your custom DataWriterFactory, override commit to commit the whole job, and override abort to roll it back.
Step 3: Create an XXXDataWriterFactory class that extends DataWriterFactory, and override createDataWriter to return your custom DataWriter.
Step 4: Create an XXXDataWriter class that extends DataWriter. Override write to write the data out, override commit to commit the partition's writes, and override abort to roll them back.
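A matching minimal write-path sketch under the same assumptions (the SimpleXxx names are illustrative; commit and abort only log here, whereas a real sink would use them to make the job atomic):

import java.util.Optional

import org.apache.spark.sql.{Row, SaveMode}
import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriter, DataWriterFactory, WriterCommitMessage}
import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, WriteSupport}
import org.apache.spark.sql.types.StructType

// Step 1: the entry point used by df.write.format("...SimpleSink")
class SimpleSink extends DataSourceV2 with WriteSupport {
  override def createWriter(jobId: String, schema: StructType, mode: SaveMode,
                            options: DataSourceOptions): Optional[DataSourceWriter] =
    Optional.of(new SimpleSinkWriter)
}

// Step 2: job-level writer; commit/abort see the messages from every partition
class SimpleSinkWriter extends DataSourceWriter {
  override def createWriterFactory(): DataWriterFactory[Row] = new SimpleWriterFactory
  override def commit(messages: Array[WriterCommitMessage]): Unit =
    println(s"job committed, ${messages.length} partition(s)")
  override def abort(messages: Array[WriterCommitMessage]): Unit =
    println("job aborted, rolling back")
}

// Step 3: one DataWriter per partition (and per attempt)
class SimpleWriterFactory extends DataWriterFactory[Row] {
  override def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[Row] =
    new SimpleWriter(partitionId)
}

// Step 4: write each record, then commit or abort the partition
class SimpleWriter(partitionId: Int) extends DataWriter[Row] {
  override def write(record: Row): Unit = println(s"partition $partitionId -> $record")
  override def commit(): WriterCommitMessage = new WriterCommitMessage {}
  override def abort(): Unit = println(s"partition $partitionId aborted")
}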
Spark SQL custom RestDataSource code
package com.hollysys.spark.sql.datasource.rest

import java.math.BigDecimal
import java.util
import java.util.Optional

import com.alibaba.fastjson.{JSONArray, JSONObject, JSONPath}
import org.apache.http.client.fluent.Request
import org.apache.http.entity.ContentType
import org.apache.spark.sql.{Row, SaveMode, SparkSession}
import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, DataSourceReader, SupportsPushDownRequiredColumns}
import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriter, DataWriterFactory, WriterCommitMessage}
import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport, WriteSupport}
import org.apache.spark.sql.types.StructType

/**
 * REST-based Spark SQL DataSource
 */
class RestDataSource extends DataSourceV2 with ReadSupport with WriteSupport {

  override def createReader(options: DataSourceOptions): DataSourceReader =
    new RestDataSourceReader(
      options.get("url").get(),
      options.get("params").get(),
      options.get("xPath").get(),
      options.get("schema").get()
    )

  override def createWriter(jobId: String,
                            schema: StructType,
                            mode: SaveMode,
                            options: DataSourceOptions): Optional[DataSourceWriter] =
    Optional.of(new RestDataSourceWriter)
}

/**
 * RestDataSourceReader
 *
 * @param url          the REST service API
 * @param params       the request parameters
 * @param xPath        JSONPath expression into the response JSON
 * @param schemaString schema string supplied by the user
 */
class RestDataSourceReader(url: String, params: String, xPath: String, schemaString: String)
  extends DataSourceReader {

  // Convert the schema string into a StructType with StructType.fromDDL
  var requiredSchema: StructType = StructType.fromDDL(schemaString)

  /**
   * Return the schema of the data source
   */
  override def readSchema(): StructType = requiredSchema

  /**
   * Create the reader factories
   *
   * @return list of factory instances
   */
  override def createDataReaderFactories(): util.List[DataReaderFactory[Row]] = {
    import collection.JavaConverters._
    Seq(
      new RestDataReaderFactory(url, params, xPath).asInstanceOf[DataReaderFactory[Row]]
    ).asJava
  }
}

/**
 * RestDataReaderFactory
 *
 * @param url    the REST service API
 * @param params the request parameters
 * @param xPath  JSONPath expression into the response JSON
 */
class RestDataReaderFactory(url: String, params: String, xPath: String)
  extends DataReaderFactory[Row] {
  override def createDataReader(): DataReader[Row] = new RestDataReader(url, params, xPath)
}

/**
 * RestDataReader
 *
 * @param url    the REST service API
 * @param params the request parameters
 * @param xPath  JSONPath expression into the response JSON
 */
class RestDataReader(url: String, params: String, xPath: String) extends DataReader[Row] {

  // Use an Iterator to hold the data
  val data: Iterator[Seq[AnyRef]] = getIterator

  override def next(): Boolean = {
    data.hasNext
  }

  override def get(): Row = {
    val seq = data.next().map {
      // Floating-point values come back as BigDecimal, which Spark cannot convert directly
      case decimal: BigDecimal => decimal.doubleValue()
      case x => x
    }
    Row(seq: _*)
  }

  override def close(): Unit = {
    println("close source")
  }

  def getIterator: Iterator[Seq[AnyRef]] = {
    import scala.collection.JavaConverters._
    val res: List[AnyRef] = RestDataSource.requestData(url, params, xPath)
    res.map(r => {
      r.asInstanceOf[JSONObject].asScala.values.toList
    }).toIterator
  }
}

/**
 * RestDataSourceWriter
 */
class RestDataSourceWriter extends DataSourceWriter {

  /**
   * Create the RestDataWriter factory
   *
   * @return RestDataWriterFactory
   */
  override def createWriterFactory(): DataWriterFactory[Row] = new RestDataWriterFactory

  /**
   * commit: receives the commit messages of all partitions; called once per job
   * (left unimplemented in this example)
   */
  override def commit(messages: Array[WriterCommitMessage]): Unit = ???

  /**
   * abort: called when a write fails, to roll the job back
   * (left unimplemented in this example)
   */
  override def abort(messages: Array[WriterCommitMessage]): Unit = ???
}

/**
 * RestDataWriterFactory
 */
class RestDataWriterFactory extends DataWriterFactory[Row] {

  /**
   * Create a DataWriter; one RestDataWriter instance is created per partition
   *
   * @param partitionId   partition ID
   * @param attemptNumber attempt number
   * @return DataWriter
   */
  override def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[Row] =
    new RestDataWriter(partitionId, attemptNumber)
}

/**
 * RestDataWriter
 *
 * @param partitionId   partition ID
 * @param attemptNumber attempt number
 */
class RestDataWriter(partitionId: Int, attemptNumber: Int) extends DataWriter[Row] {

  /**
   * write: called once for every record
   *
   * @param record a single record
   */
  override def write(record: Row): Unit = {
    println(record)
  }

  /**
   * commit: called once per partition
   *
   * @return commit message
   */
  override def commit(): WriterCommitMessage = {
    RestWriterCommitMessage(partitionId, attemptNumber)
  }

  /**
   * Rollback: called when write throws an exception
   */
  override def abort(): Unit = {
    println("the abort method was triggered")
  }
}

case class RestWriterCommitMessage(partitionId: Int, attemptNumber: Int) extends WriterCommitMessage

object RestDataSource {

  def requestData(url: String, params: String, xPath: String): List[AnyRef] = {
    import scala.collection.JavaConverters._
    val response = Request.Post(url).bodyString(params, ContentType.APPLICATION_JSON).execute()
    JSONPath.read(response.returnContent().asString(), xPath)
      .asInstanceOf[JSONArray].asScala.toList
  }
}

object RestDataSourceTest {

  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .master("local[2]")
      .appName(this.getClass.getSimpleName)
      .getOrCreate()

    val df = spark.read
      .format("com.hollysys.spark.sql.datasource.rest.RestDataSource")
      .option("url", "http://model-opcua-hollysysdigital-test.hiacloud.net.cn/aggquery/query/queryPointHistoryData")
      .option("params", "{\n \"startTime\": \"1543887720000\",\n \"endTime\": \"1543891320000\",\n \"maxSizePerNode\": 1000,\n \"nodes\": [\n {\n \"uri\": \"/SymLink-10000012030100000-device/5c174da007a54e0001035ddd\"\n }\n ]\n}")
      .option("xPath", "$.result.historyData")
      // `response` ARRAY<STRUCT<`historyData`:ARRAY<STRUCT<`s`:INT,`t`:LONG,`v`:FLOAT>>>>
      .option("schema", "`s` INT,`t` LONG,`v` DOUBLE")
      .load()

    df.printSchema()
    df.show(false)
  }
}
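RestDataSourceReader imports SupportsPushDownRequiredColumns but never mixes it in. A hedged sketch of what column pruning could look like for this reader is shown below, reusing the imports from the listing above; the hypothetical class name is an assumption, and the DataReader would additionally have to return only the pruned columns, which is not shown.

class RestDataSourceReaderWithPruning(url: String, params: String, xPath: String, schemaString: String)
  extends DataSourceReader with SupportsPushDownRequiredColumns {

  var requiredSchema: StructType = StructType.fromDDL(schemaString)

  // Spark calls this with only the columns the query actually references
  override def pruneColumns(requiredSchema: StructType): Unit = {
    this.requiredSchema = requiredSchema
  }

  override def readSchema(): StructType = requiredSchema

  override def createDataReaderFactories(): util.List[DataReaderFactory[Row]] = {
    import scala.collection.JavaConverters._
    Seq(new RestDataReaderFactory(url, params, xPath).asInstanceOf[DataReaderFactory[Row]]).asJava
  }
}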
Spark SQL custom HBaseSource
import java.util
import java.util.Optional

import com.travel.utils.HbaseTools
import org.apache.hadoop.hbase.TableName
import org.apache.hadoop.hbase.client._
import org.apache.hadoop.hbase.util.Bytes
import org.apache.spark.sql.sources.v2.reader._
import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriter, DataWriterFactory, WriterCommitMessage}
import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport, WriteSupport}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession}

object HBaseSourceAndSink {

  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .master("local[2]")
      .getOrCreate()

    val df = spark.read
      .format("com.travel.programApp.HBaseSource")
      .option("hbase.table.name", "spark_hbase_sql")
      .option("schema", "`name` STRING,`score` STRING")
      .option("cf.cc", "cf:name,cf:score")
      .load()

    df.explain(true)
    df.createOrReplaceTempView("sparkHBaseSQL")
    df.printSchema()

    val frame: DataFrame = spark.sql("select * from sparkHBaseSQL where score > 60")
    frame.write.format("com.travel.programApp.HBaseSource")
      .mode(SaveMode.Overwrite)
      .option("hbase.table.name", "spark_hbase_write")
      .save()
  }
}

class HBaseSource extends DataSourceV2 with ReadSupport with WriteSupport {

  override def createReader(options: DataSourceOptions): DataSourceReader = {
    new HBaseDataSourceReader(options.get("hbase.table.name").get(), options.get("schema").get(), options.get("cf.cc").get())
  }

  override def createWriter(jobId: String, schema: StructType, mode: SaveMode, options: DataSourceOptions): Optional[DataSourceWriter] = {
    Optional.of(new HBaseDataSourceWrite)
  }
}

class HBaseDataSourceWrite extends DataSourceWriter {

  override def createWriterFactory(): DataWriterFactory[Row] = {
    new HBaseDataWriterFactory
  }

  override def commit(messages: Array[WriterCommitMessage]): Unit = {
  }

  override def abort(messages: Array[WriterCommitMessage]): Unit = {
  }
}

class HBaseDataWriterFactory extends DataWriterFactory[Row] {

  override def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[Row] = {
    new HBaseDataWriter
  }
}

class HBaseDataWriter extends DataWriter[Row] {

  private val conn: Connection = HbaseTools.getHbaseConn
  private val table: Table = conn.getTable(TableName.valueOf("spark_hbase_write"))

  override def write(record: Row): Unit = {
    val name: String = record.getString(0)
    val score: String = record.getString(1)
    val put = new Put("0001".getBytes())
    put.addColumn("cf".getBytes(), "name".getBytes(), name.getBytes())
    put.addColumn("cf".getBytes(), "score".getBytes(), score.getBytes())
    table.put(put)
  }

  override def commit(): WriterCommitMessage = {
    table.close()
    conn.close()
    null
  }

  override def abort(): Unit = {
  }
}

class HBaseDataSourceReader(tableName: String, schema: String, cfcc: String) extends DataSourceReader {

  // Define the HBase schema
  private val structType: StructType = StructType.fromDDL(schema)

  override def readSchema(): StructType = {
    structType
  }

  // Return the DataReaderFactory instances
  override def createDataReaderFactories(): util.List[DataReaderFactory[Row]] = {
    import collection.JavaConverters._
    Seq(
      new HBaseReaderFactory(tableName, cfcc).asInstanceOf[DataReaderFactory[Row]]
    ).asJava
  }
}

class HBaseReaderFactory(tableName: String, cfcc: String) extends DataReaderFactory[Row] {

  override def createDataReader(): DataReader[Row] = {
    new HBaseReader(tableName, cfcc)
  }
}

class HBaseReader(tableName: String, cfcc: String) extends DataReader[Row] {

  private var hbaseConnection: Connection = null
  private var resultScanner: ResultScanner = null
  private var nextResult: Result = null

  // Fetch the data from HBase
  val data: Iterator[Seq[AnyRef]] = getIterator

  def getIterator: Iterator[Seq[AnyRef]] = {
    import scala.collection.JavaConverters._
    hbaseConnection = HbaseTools.getHbaseConn
    val table: Table = hbaseConnection.getTable(TableName.valueOf(tableName))
    resultScanner = table.getScanner(new Scan())
    val iterator: Iterator[Seq[AnyRef]] = resultScanner.iterator().asScala.map(eachResult => {
      val name: String = Bytes.toString(eachResult.getValue("cf".getBytes(), "name".getBytes()))
      val score: String = Bytes.toString(eachResult.getValue("cf".getBytes(), "score".getBytes()))
      Seq(name, score)
    })
    iterator
  }

  override def next(): Boolean = {
    data.hasNext
  }

  override def get(): Row = {
    val seq: Seq[Any] = data.next()
    Row.fromSeq(seq)
  }

  override def close(): Unit = {
    hbaseConnection.close()
  }
}
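Two details worth noting in this example: the writer hard-codes the rowkey to "0001", so every record overwrites the previous one, and the reader ignores the cf.cc option and hard-codes cf:name and cf:score. A small, hypothetical helper like the following (names and structure are assumptions, not part of the original code) could parse cf.cc so the reader works for arbitrary columns:

import org.apache.hadoop.hbase.util.Bytes

object HBaseColumnSpec {
  // Parse "cf:name,cf:score" into (columnFamily, qualifier) byte pairs
  def parse(cfcc: String): Seq[(Array[Byte], Array[Byte])] =
    cfcc.split(",").toSeq.map { pair =>
      val Array(cf, cq) = pair.trim.split(":")
      (Bytes.toBytes(cf), Bytes.toBytes(cq))
    }
}

// Inside HBaseReader.getIterator the hard-coded columns could then become:
//   val columns = HBaseColumnSpec.parse(cfcc)
//   columns.map { case (cf, cq) => Bytes.toString(eachResult.getValue(cf, cq)) }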
Spark SQL: extending DataSourceV2 with a ClickHouse source
package com.mengyao.spark.datasourcev2.ext.example1 import java.io.Serializable import java.sql.{Connection, Date, PreparedStatement, ResultSet, SQLException, Statement} import java.text.SimpleDateFormat import java.util import java.util.Optional import cn.itcast.logistics.etl.Configure import org.apache.commons.lang3.{StringUtils, SystemUtils} import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.{SaveMode, SparkSession} import org.apache.spark.sql.sources.{DataSourceRegister, EqualTo, Filter} import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, InputPartition, InputPartitionReader, SupportsPushDownFilters, SupportsPushDownRequiredColumns} import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriter, DataWriterFactory, WriterCommitMessage} import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport, StreamWriteSupport, WriteSupport} import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.{StructType, _} import org.apache.spark.unsafe.types.UTF8String import org.javatuples.Triplet import ru.yandex.clickhouse.domain.ClickHouseDataType import ru.yandex.clickhouse.response.{ClickHouseResultSet, ClickHouseResultSetMetaData} import ru.yandex.clickhouse.settings.ClickHouseProperties import ru.yandex.clickhouse.{ClickHouseConnection, ClickHouseDataSource, ClickHouseStatement} import scala.collection.mutable.ArrayBuffer /** * @ClassName CKTest * @Description 测试ClickHouse的DataSourceV2实现 * @Created by MengYao * @Date 2020/5/17 16:34 * @Version V1.0 */ object CKTest { private val APP_NAME: String = CKTest.getClass.getSimpleName private val master: String = "local[2]" def main(args: Array[String]) { if (SystemUtils.IS_OS_WINDOWS) System.setProperty("hadoop.home.dir", Configure.LOCAL_HADOOP_HOME) val spark = SparkSession.builder() .master(master) .appName(APP_NAME).getOrCreate(); val df = spark.read.format(Configure.SPARK_CLICKHOUSE_FORMAT) .option("driver", Configure.clickhouseDriver) .option("url", Configure.clickhouseUrl) .option("user", Configure.clickhouseUser) .option("password", Configure.clickhousePassword) .option("table", "tbl_address") .option("use_server_time_zone", "false") .option("use_time_zone", "Asia/Shanghai") .option("max_memory_usage", "2000000000") .option("max_bytes_before_external_group_by", "1000000000") .load().coalesce(1) df.show(1000, false) import spark.implicits._ df.where($"id"===328).distinct().coalesce(1).write.format(Configure.SPARK_CLICKHOUSE_FORMAT) .option("driver", Configure.clickhouseDriver) .option("url", Configure.clickhouseUrl) .option("user", Configure.clickhouseUser) .option("password", Configure.clickhousePassword) .option("table", "tbl_address") .option("use_server_time_zone", "false") .option("use_time_zone", "Asia/Shanghai") .option("max_memory_usage", "2000000000") .option("max_bytes_before_external_group_by", "1000000000") .mode(SaveMode.Append) .save(); } } /** * @ClassName ClickHouseDataSourceV2 * @Description 扩展SparkSQL DataSourceV2的ClickHouse数据源实现 * @Created by MengYao * @Date 2020/5/17 16:34 * @Version V1.0 */ class ClickHouseDataSourceV2 extends DataSourceV2 with DataSourceRegister with ReadSupport with WriteSupport with StreamWriteSupport { /** 声明ClickHouse数据源的简称,使用方式为spark.read.format("clickhouse")... 
*/ override def shortName(): String = "clickhouse" /** 批处理方式下的数据读取 */ override def createReader(options: DataSourceOptions): DataSourceReader = new CKReader(new CKOptions(options.asMap())) /** 批处理方式下的数据写入 */ override def createWriter(writeUUID: String, schema: StructType, mode: SaveMode, options: DataSourceOptions): Optional[DataSourceWriter] = Optional.of(new CKWriter(writeUUID, schema, mode, null, new CKOptions(options.asMap()))) /** 流处理方式下的数据写入 */ override def createStreamWriter(queryId: String, schema: StructType, mode: OutputMode, options: DataSourceOptions): StreamWriter = new CKWriter(queryId, schema, null, mode, new CKOptions(options.asMap())) } /** * @ClassName CKReader * @Description 基于批处理方式的ClickHouse数据读取(此处只使用1个分区实现) * @Created by MengYao * @Date 2020/5/17 16:34 * @Version V1.0 */ class CKReader(options: CKOptions) extends DataSourceReader { //with SupportsPushDownRequiredColumns with SupportsPushDownFilters { private val customSchema: java.lang.String = options.getCustomSchema private val helper = new CKHelper(options) import collection.JavaConversions._ private val schema = if(StringUtils.isEmpty(customSchema)) { helper.getSparkTableSchema() } else { helper.getSparkTableSchema(new util.LinkedList[String](asJavaCollection(customSchema.split(",")))) } override def readSchema(): StructType = schema override def planInputPartitions(): util.List[InputPartition[InternalRow]] = util.Arrays.asList(new CKInputPartition(schema, options)) } /** * @ClassName CKInputPartition * @Description 基于批处理方式的ClickHouse分区实现 * @Created by MengYao * @Date 2020/5/17 16:34 * @Version V1.0 */ class CKInputPartition(schema: StructType, options: CKOptions) extends InputPartition[InternalRow] { override def createPartitionReader(): InputPartitionReader[InternalRow] = new CKInputPartitionReader(schema, options) } /** * @ClassName CKInputPartitionReader * @Description 基于批处理方式的ClickHouse分区读取数据实现 * @Created by MengYao * @Date 2020/5/17 16:34 * @Version V1.0 */ class CKInputPartitionReader(schema: StructType, options: CKOptions) extends InputPartitionReader[InternalRow] with Logging with Serializable{ val helper = new CKHelper(options) var connection: ClickHouseConnection = null var st: ClickHouseStatement = null var rs: ResultSet = null override def next(): Boolean = { if (null == connection || connection.isClosed && null == st || st.isClosed && null == rs || rs.isClosed){ connection = helper.getConnection st = connection.createStatement() rs = st.executeQuery(helper.getSelectStatement(schema)) println(/**logInfo**/s"初始化ClickHouse连接.") } if(null != rs && !rs.isClosed) rs.next() else false } override def get(): InternalRow = { val fields = schema.fields val length = fields.length val record = new Array[Any](length) for (i <- 0 until length) { val field = fields(i) val name = field.name val dataType = field.dataType try { dataType match { case DataTypes.BooleanType => record(i) = rs.getBoolean(name) case DataTypes.DateType => record(i) = DateTimeUtils.fromJavaDate(rs.getDate(name)) case DataTypes.DoubleType => record(i) = rs.getDouble(name) case DataTypes.FloatType => record(i) = rs.getFloat(name) case DataTypes.IntegerType => record(i) = rs.getInt(name) case DataTypes.LongType => record(i) = rs.getLong(name) case DataTypes.ShortType => record(i) = rs.getShort(name) case DataTypes.StringType => record(i) = UTF8String.fromString(rs.getString(name)) case DataTypes.TimestampType => record(i) = DateTimeUtils.fromJavaTimestamp(rs.getTimestamp(name)) case DataTypes.BinaryType => record(i) = rs.getBytes(name) case 
DataTypes.NullType => record(i) = StringUtils.EMPTY } } catch { case e: SQLException => logError(e.getStackTrace.mkString("", scala.util.Properties.lineSeparator, scala.util.Properties.lineSeparator)) } } new GenericInternalRow(record) } override def close(): Unit = {helper.closeAll(connection, st, null, rs)} } /** * @ClassName CKWriter * @Description 支持Batch和Stream的数据写实现 * @Created by MengYao * @Date 2020/5/17 16:34 * @Version V1.0 */ class CKWriter(writeUuidOrQueryId: String, schema: StructType, batchMode: SaveMode, streamMode: OutputMode, options: CKOptions) extends StreamWriter { private val isStreamMode:Boolean = if (null!=streamMode&&null==batchMode) true else false override def useCommitCoordinator(): Boolean = true override def onDataWriterCommit(message: WriterCommitMessage): Unit = {} override def createWriterFactory(): DataWriterFactory[InternalRow] = new CKDataWriterFactory(writeUuidOrQueryId, schema, batchMode, streamMode, options) /** Batch writer commit */ override def commit(messages: Array[WriterCommitMessage]): Unit = {} /** Batch writer abort */ override def abort(messages: Array[WriterCommitMessage]): Unit = {} /** Streaming writer commit */ override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} /** Streaming writer abort */ override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} } /** * @ClassName CKDataWriterFactory * @Description 写数据工厂,用来实例化CKDataWriter * @Created by MengYao * @Date 2020/5/17 16:34 * @Version V1.0 */ class CKDataWriterFactory(writeUUID: String, schema: StructType, batchMode: SaveMode, streamMode: OutputMode, options: CKOptions) extends DataWriterFactory[InternalRow] { override def createDataWriter(partitionId: Int, taskId: Long, epochId: Long): DataWriter[InternalRow] = new CKDataWriter(writeUUID, schema, batchMode, streamMode, options) } /** * @ClassName CKDataWriter * @Description ClickHouse的数据写实现 * @Created by MengYao * @Date 2020/5/17 16:34 * @Version V1.0 */ class CKDataWriter(writeUUID: String, schema: StructType, batchMode: SaveMode, streamMode: OutputMode, options: CKOptions) extends DataWriter[InternalRow] with Logging with Serializable { val helper = new CKHelper(options) val opType = options.getOpTypeField private val sqls = ArrayBuffer[String]() private val autoCreateTable: Boolean = options.autoCreateTable private val init = if (autoCreateTable) { val createSQL = helper.createTable(options.getFullTable, schema) println(/**logInfo**/s"==== 初始化表SQL:$createSQL") helper.executeUpdate(createSQL) } val fields = schema.fields override def commit(): WriterCommitMessage = { helper.executeUpdateBatch(sqls) val batchSQL = sqls.mkString("\n") // logDebug(batchSQL) println(batchSQL) new WriterCommitMessage{override def toString: String = s"批量插入SQL: $batchSQL"} } override def write(record: InternalRow): Unit = { if(StringUtils.isEmpty(opType)) { throw new RuntimeException("未传入opTypeField字段名称,无法确定数据持久化类型!") } var sqlStr: String = helper.getStatement(options.getFullTable, schema, record) logDebug(s"==== 拼接完成的INSERT SQL语句为:$sqlStr") try { if (StringUtils.isEmpty(sqlStr)) { val msg = "==== 拼接INSERT SQL语句失败,因为该语句为NULL或EMPTY!" 
logError(msg) throw new RuntimeException(msg) } Thread.sleep(options.getInterval()) // 在流处理模式下操作 if (null == batchMode) { if (streamMode == OutputMode.Append) { sqls += sqlStr // val state = helper.executeUpdate(sqlStr) // println(s"==== 在OutputMode.Append模式下执行:$sqlStr\n状态:$state") } else if(streamMode == OutputMode.Complete) {logError("==== 未实现OutputMode.Complete模式下的写入操作,请在CKDataWriter.write方法中添加相关实现!")} else if(streamMode == OutputMode.Update) {logError("==== 未实现OutputMode.Update模式下的写入操作,请在CKDataWriter.write方法中添加相关实现!")} else {logError(s"==== 未知模式下的写入操作,请在CKDataWriter.write方法中添加相关实现!")} // 在批处理模式下操作 } else { if (batchMode == SaveMode.Append) { sqls += sqlStr //val state = helper.executeUpdate(sqlStr) //println(s"==== 在SaveMode.Append模式下执行:$sqlStr\n状态:$state") } else if(batchMode == SaveMode.Overwrite) {logError("==== 未实现SaveMode.Overwrite模式下的写入操作,请在CKDataWriter.write方法中添加相关实现!")} else if(batchMode == SaveMode.ErrorIfExists) {logError("==== 未实现SaveMode.ErrorIfExists模式下的写入操作,请在CKDataWriter.write方法中添加相关实现!")} else if(batchMode == SaveMode.Ignore) {logError("==== 未实现SaveMode.Ignore模式下的写入操作,请在CKDataWriter.write方法中添加相关实现!")} else {logError(s"==== 未知模式下的写入操作,请在CKDataWriter.write方法中添加相关实现!")} } } catch { case e: Exception => logError(e.getMessage) } } override def abort(): Unit = {} } /** * @ClassName CKOptions * @Description 从SparkSQL中DataSourceOptions中提取适用于ClickHouse的参数(spark.[read/write].options参数) * @Created by MengYao * @Date 2020/5/17 16:34 * @Version V1.0 */ class CKOptions(var originalMap: util.Map[String, String]) extends Logging with Serializable { val DRIVER_KEY: String = "driver" val URL_KEY: String = "url" val USER_KEY: String = "user" val PASSWORD_KEY: String = "password" val DATABASE_KEY: String = "database" val TABLE_KEY: String = "table" val AUTO_CREATE_TABLE = "autoCreateTable".toLowerCase val PATH_KEY = "path" val INTERVAL = "interval" val CUSTOM_SCHEMA_KEY: String = "customSchema".toLowerCase val WHERE_KEY: String = "where" val OP_TYPE_FIELD = "opTypeField".toLowerCase val PRIMARY_KEY = "primaryKey".toLowerCase def getValue[T](key: String, `type`: T): T = (if (originalMap.containsKey(key)) originalMap.get(key) else null).asInstanceOf[T] def getDriver: String = getValue(DRIVER_KEY, new String) def getURL: String = getValue(URL_KEY, new String) def getUser: String = getValue(USER_KEY, new String) def getPassword: String = getValue(PASSWORD_KEY, new String) def getDatabase: String = getValue(DATABASE_KEY, new String) def getTable: String = getValue(TABLE_KEY, new String) def autoCreateTable: Boolean = { originalMap.getOrDefault(AUTO_CREATE_TABLE, "false").toLowerCase match { case "true" => true case "false" => false case _ => false } } def getInterval(): Long = {originalMap.getOrDefault(INTERVAL, "200").toLong} def getPath: String = if(StringUtils.isEmpty(getValue(PATH_KEY, new String))) getTable else getValue(PATH_KEY, new String) def getWhere: String = getValue(WHERE_KEY, new String) def getCustomSchema: String = getValue(CUSTOM_SCHEMA_KEY, new String) def getOpTypeField: String = getValue(OP_TYPE_FIELD, new String) def getPrimaryKey: String = getValue(PRIMARY_KEY, new String) def getFullTable: String = { val database = getDatabase val table = getTable if (StringUtils.isEmpty(database) && !StringUtils.isEmpty(table)) table else if (!StringUtils.isEmpty(database) && !StringUtils.isEmpty(table)) database+"."+table else table } def asMap(): util.Map[String, String] = this.originalMap override def toString: String = originalMap.toString } /** * @ClassName CKHelper * @Description 
ClickHouse的JDBCHelper实现 * @Created by MengYao * @Date 2020/5/17 16:34 * @Version V1.0 */ class CKHelper(options: CKOptions) extends Logging with Serializable { private val opType: String = options.getOpTypeField private val id: String = options.getPrimaryKey private var connection: ClickHouseConnection = getConnection def getConnection: ClickHouseConnection = { val url = options.getURL val ds = new ClickHouseDataSource(url, new ClickHouseProperties()) ds.getConnection(options.getUser, options.getPassword) } def createTable(table: String, schema: StructType): String = { val cols = ArrayBuffer[String]() for (field <- schema.fields) { val dataType = field.dataType val ckColName = field.name if (ckColName!=opType) { var ckColType = getClickhouseSqlType(dataType) if (!StringUtils.isEmpty(ckColType)) { if (ckColType.toLowerCase=="string") {ckColType="Nullable("+ckColType+")"} } cols += ckColName+" "+ ckColType } } s"CREATE TABLE IF NOT EXISTS $table(${cols.mkString(",")},sign Int8,version UInt64) ENGINE=VersionedCollapsingMergeTree(sign, version) ORDER BY $id" } def getSparkTableSchema(customFields: util.LinkedList[String] = null): StructType = { import collection.JavaConversions._ val list: util.LinkedList[Triplet[String, String, String]] = getCKTableSchema(customFields) var fields = ArrayBuffer[StructField]() for(trp <- list) { fields += StructField(trp.getValue0, getSparkSqlType(trp.getValue1)) } StructType(fields) } private def getFieldValue(fieldName: String, schema: StructType, data:InternalRow): Any = { var flag = true var fieldValue:String = null val fields = schema.fields for(i <- 0 until fields.length if flag) { val field = fields(i) if(fieldName==field.name) { fieldValue = field.dataType match { case DataTypes.BooleanType => if (data.isNullAt(i)) "NULL" else s"${data.getBoolean(i)}" case DataTypes.DoubleType => if (data.isNullAt(i)) "NULL" else s"${data.getDouble(i)}" case DataTypes.FloatType => if (data.isNullAt(i)) "NULL" else s"${data.getFloat(i)}" case DataTypes.IntegerType => if (data.isNullAt(i)) "NULL" else s"${data.getInt(i)}" case DataTypes.LongType => if (data.isNullAt(i)) "NULL" else s"${data.getLong(i)}" case DataTypes.ShortType => if (data.isNullAt(i)) "NULL" else s"${data.getShort(i)}" case DataTypes.StringType => if (data.isNullAt(i)) "NULL" else s"${data.getUTF8String(i).toString.trim}" case DataTypes.DateType => if (data.isNullAt(i)) "NULL" else s"'${new SimpleDateFormat("yyyy-MM-dd").format(new Date(data.get(i, DateType).asInstanceOf[Date].getTime / 1000))}'" case DataTypes.TimestampType => if (data.isNullAt(i)) "NULL" else s"${new SimpleDateFormat("yyyy-MM-dd HH:mm:ss").format(new Date(data.getLong(i) / 1000))}" case DataTypes.BinaryType => if (data.isNullAt(i)) "NULL" else s"${data.getBinary(i)}" case DataTypes.NullType => "NULL" } flag = false } } fieldValue } def getStatement(table: String, schema: StructType, record: InternalRow): String = { val opTypeValue: String = getFieldValue(opType, schema, record).toString if (opTypeValue.toLowerCase()=="insert") {getInsertStatement(table, schema, record)} else if (opTypeValue.toLowerCase()=="delete") {getUpdateStatement(table, schema, record)} else if (opTypeValue.toLowerCase()=="update") {getDeleteStatement(table, schema, record)} else {""} } def getSelectStatement(schema: StructType):String = { s"SELECT ${schema.fieldNames.mkString(",")} FROM ${options.getFullTable}" } def getInsertStatement(table:String, schema: StructType, data:InternalRow):String = { val fields = schema.fields val names = ArrayBuffer[String]() val 
values = ArrayBuffer[String]() // 表示DataFrame中的字段与数据库中的字段相同,拼接SQL语句时使用全量字段拼接 if (data.numFields==fields.length) { } else {// 表示DataFrame中的字段与数据库中的字段不同,拼接SQL时需要仅拼接DataFrame中有的字段到SQL中 } for(i <- 0 until fields.length) { val field = fields(i) val fieldType = field.dataType val fieldName = field.name if (fieldName!=opType) { val fieldValue = fieldType match { case DataTypes.BooleanType => if(data.isNullAt(i)) "NULL" else s"${data.getBoolean(i)}" case DataTypes.DoubleType => if(data.isNullAt(i)) "NULL" else s"${data.getDouble(i)}" case DataTypes.FloatType => if(data.isNullAt(i)) "NULL" else s"${data.getFloat(i)}" case DataTypes.IntegerType => if(data.isNullAt(i)) "NULL" else s"${data.getInt(i)}" case DataTypes.LongType => if(data.isNullAt(i)) "NULL" else s"${data.getLong(i)}" case DataTypes.ShortType => if(data.isNullAt(i)) "NULL" else s"${data.getShort(i)}" case DataTypes.StringType => if(data.isNullAt(i)) "NULL" else s"'${data.getUTF8String(i).toString.trim}'" case DataTypes.DateType => if(data.isNullAt(i)) "NULL" else s"'${new SimpleDateFormat("yyyy-MM-dd").format(new Date(data.get(i, DateType).asInstanceOf[Date].getTime/1000))}'" case DataTypes.TimestampType => if(data.isNullAt(i)) "NULL" else s"'${new SimpleDateFormat("yyyy-MM-dd HH:mm:ss").format(new Date(data.getLong(i)/1000))}'" case DataTypes.BinaryType => if(data.isNullAt(i)) "NULL" else s"${data.getBinary(i)}" case DataTypes.NullType => "NULL" } names += fieldName values += fieldValue } } if (names.length > 0 && values.length > 0) { names += ("sign","version") values += ("1", System.currentTimeMillis().toString) } s"INSERT INTO $table(${names.mkString(",")}) VALUES(${values.mkString(",")})" } def getDeleteStatement(table:String, schema: StructType, data:InternalRow):String = { val fields = schema.fields val primaryKeyFields = if(options.getPrimaryKey.isEmpty) {fields.filter(field => field.name=="id")} else {fields.filter(field => field.name==options.getPrimaryKey)} if (primaryKeyFields.length>0) { val primaryKeyField = primaryKeyFields(0) val primaryKeyValue = getFieldValue(primaryKeyField.name, schema, data) s"ALTER TABLE $table DELETE WHERE ${primaryKeyField.name} = $primaryKeyValue" } else { logError("==== 找不到主键,无法生成删除SQL!") "" } } def getUpdateStatement(table:String, schema: StructType, data:InternalRow):String = { val fields = schema.fields val primaryKeyFields = if(options.getPrimaryKey.isEmpty) {fields.filter(field => field.name=="id")} else {fields.filter(field => field.name==options.getPrimaryKey)} if (primaryKeyFields.length>0) { val primaryKeyField = primaryKeyFields(0) val primaryKeyValue = getFieldValue(primaryKeyField.name, schema, data) val noPrimaryKeyFields = fields.filter(field=>field.name!=primaryKeyField.name) var sets = ArrayBuffer[String]() for(i <- 0 until noPrimaryKeyFields.length) { val noPrimaryKeyField = noPrimaryKeyFields(i) val set = noPrimaryKeyField.name+"="+getFieldValue(noPrimaryKeyField.name, schema, data).toString sets += set } sets.remove(sets.length-1) s"ALTER TABLE $table UPDATE ${sets.mkString(" AND ")} WHERE ${primaryKeyField.name}=$primaryKeyValue" } else { logError("==== 找不到主键,无法生成修改SQL!") "" } } def getCKTableSchema(customFields: util.LinkedList[String] = null): util.LinkedList[Triplet[String, String, String]] = { val fields = new util.LinkedList[Triplet[String, String, String]] var connection: ClickHouseConnection = null var st: ClickHouseStatement = null var rs: ClickHouseResultSet = null var metaData: ClickHouseResultSetMetaData = null try { connection = getConnection st = 
connection.createStatement val sql = s"SELECT * FROM ${options.getFullTable} WHERE 1=0" rs = st.executeQuery(sql).asInstanceOf[ClickHouseResultSet] metaData = rs.getMetaData.asInstanceOf[ClickHouseResultSetMetaData] val columnCount = metaData.getColumnCount for (i <- 1 to columnCount) { val columnName = metaData.getColumnName(i) val sqlTypeName = metaData.getColumnTypeName(i) val javaTypeName = ClickHouseDataType.fromTypeString(sqlTypeName).getJavaClass.getSimpleName if (null != customFields && customFields.size > 0) { if(fields.contains(columnName)) fields.add(new Triplet(columnName, sqlTypeName, javaTypeName)) } else { fields.add(new Triplet(columnName, sqlTypeName, javaTypeName)) } } } catch { case e: Exception => e.printStackTrace() } finally { closeAll(connection, st, null, rs) } fields } def executeUpdateBatch(sqls: ArrayBuffer[String]): Unit = { // 拼接Batch SQL:VALUES()()... val batchSQL = new StringBuilder() for(i <- 0 until sqls.length) { val line = sqls(i) var offset: Int = 0 if (!StringUtils.isEmpty(line) && line.contains("VALUES")) { val offset = line.indexOf("VALUES") if(i==0) { val prefix = line.substring(0, offset+6) batchSQL.append(prefix) } val suffix = line.substring(offset+6) batchSQL.append(suffix) } } var st: ClickHouseStatement = null; try { if(null==connection||connection.isClosed) {connection = getConnection} st = connection createStatement() st.executeUpdate(batchSQL.toString()) } catch { case e: Exception => logError(s"执行异常:$sqls\n${e.getMessage}") } finally { //closeAll(connection, st) } } def executeUpdate(sql: String): Int = { var state = 0; var st: ClickHouseStatement = null; try { if(null==connection||connection.isClosed) {connection = getConnection} st = connection createStatement() state = st.executeUpdate(sql) } catch { case e: Exception => logError(s"执行异常:$sql\n${e.getMessage}") } finally { //closeAll(connection, st) } state } def close(connection: Connection): Unit = closeAll(connection) def close(st: Statement): Unit = closeAll(null, st, null, null) def close(ps: PreparedStatement): Unit = closeAll(null, null, ps, null) def close(rs: ResultSet): Unit = closeAll(null, null, null, rs) def closeAll(connection: Connection=null, st: Statement=null, ps: PreparedStatement=null, rs: ResultSet=null): Unit = { try { if (rs != null && !rs.isClosed) rs.close() if (ps != null && !ps.isClosed) ps.close() if (st != null && !st.isClosed) st.close() if (connection != null && !connection.isClosed) connection.close() } catch { case e: Exception => e.printStackTrace() } } /** * IntervalYear (Types.INTEGER, Integer.class, true, 19, 0), * IntervalQuarter (Types.INTEGER, Integer.class, true, 19, 0), * IntervalMonth (Types.INTEGER, Integer.class, true, 19, 0), * IntervalWeek (Types.INTEGER, Integer.class, true, 19, 0), * IntervalDay (Types.INTEGER, Integer.class, true, 19, 0), * IntervalHour (Types.INTEGER, Integer.class, true, 19, 0), * IntervalMinute (Types.INTEGER, Integer.class, true, 19, 0), * IntervalSecond (Types.INTEGER, Integer.class, true, 19, 0), * UInt64 (Types.BIGINT, BigInteger.class, false, 19, 0), * UInt32 (Types.INTEGER, Long.class, false, 10, 0), * UInt16 (Types.SMALLINT, Integer.class, false, 5, 0), * UInt8 (Types.TINYINT, Integer.class, false, 3, 0), * Int64 (Types.BIGINT, Long.class, true, 20, 0, "BIGINT"), * Int32 (Types.INTEGER, Integer.class, true, 11, 0, "INTEGER", "INT"), * Int16 (Types.SMALLINT, Integer.class, true, 6, 0, "SMALLINT"), * Int8 (Types.TINYINT, Integer.class, true, 4, 0, "TINYINT"), * Date (Types.DATE, Date.class, false, 10, 0), * 
DateTime (Types.TIMESTAMP, Timestamp.class, false, 19, 0, "TIMESTAMP"), * Enum8 (Types.VARCHAR, String.class, false, 0, 0), * Enum16 (Types.VARCHAR, String.class, false, 0, 0), * Float32 (Types.FLOAT, Float.class, true, 8, 8, "FLOAT"), * Float64 (Types.DOUBLE, Double.class, true, 17, 17, "DOUBLE"), * Decimal32 (Types.DECIMAL, BigDecimal.class, true, 9, 9), * Decimal64 (Types.DECIMAL, BigDecimal.class, true, 18, 18), * Decimal128 (Types.DECIMAL, BigDecimal.class, true, 38, 38), * Decimal (Types.DECIMAL, BigDecimal.class, true, 0, 0, "DEC"), * UUID (Types.OTHER, UUID.class, false, 36, 0), * String (Types.VARCHAR, String.class, false, 0, 0, "LONGBLOB", "MEDIUMBLOB", "TINYBLOB", "MEDIUMTEXT", "CHAR", "VARCHAR", "TEXT", "TINYTEXT", "LONGTEXT", "BLOB"), * FixedString (Types.CHAR, String.class, false, -1, 0, "BINARY"), * Nothing (Types.NULL, Object.class, false, 0, 0), * Nested (Types.STRUCT, String.class, false, 0, 0), * Tuple (Types.OTHER, String.class, false, 0, 0), * Array (Types.ARRAY, Array.class, false, 0, 0), * AggregateFunction (Types.OTHER, String.class, false, 0, 0), * Unknown (Types.OTHER, String.class, false, 0, 0); * * @param clickhouseDataType * @return */ private def getSparkSqlType(clickhouseDataType: String) = clickhouseDataType match { case "IntervalYear" => DataTypes.IntegerType case "IntervalQuarter" => DataTypes.IntegerType case "IntervalMonth" => DataTypes.IntegerType case "IntervalWeek" => DataTypes.IntegerType case "IntervalDay" => DataTypes.IntegerType case "IntervalHour" => DataTypes.IntegerType case "IntervalMinute" => DataTypes.IntegerType case "IntervalSecond" => DataTypes.IntegerType case "UInt64" => DataTypes.LongType //DataTypes.IntegerType; case "UInt32" => DataTypes.LongType case "UInt16" => DataTypes.IntegerType case "UInt8" => DataTypes.IntegerType case "Int64" => DataTypes.LongType case "Int32" => DataTypes.IntegerType case "Int16" => DataTypes.IntegerType case "Int8" => DataTypes.IntegerType case "Date" => DataTypes.DateType case "DateTime" => DataTypes.TimestampType case "Enum8" => DataTypes.StringType case "Enum16" => DataTypes.StringType case "Float32" => DataTypes.FloatType case "Float64" => DataTypes.DoubleType case "Decimal32" => DataTypes.createDecimalType case "Decimal64" => DataTypes.createDecimalType case "Decimal128" => DataTypes.createDecimalType case "Decimal" => DataTypes.createDecimalType case "UUID" => DataTypes.StringType case "String" => DataTypes.StringType case "FixedString" => DataTypes.StringType case "Nothing" => DataTypes.NullType case "Nested" => DataTypes.StringType case "Tuple" => DataTypes.StringType case "Array" => DataTypes.StringType case "AggregateFunction" => DataTypes.StringType case "Unknown" => DataTypes.StringType case _ => DataTypes.NullType } private def getClickhouseSqlType(sparkDataType: DataType) = sparkDataType match { case DataTypes.ByteType => "Int8" case DataTypes.ShortType => "Int16" case DataTypes.IntegerType => "Int32" case DataTypes.FloatType => "Float32" case DataTypes.DoubleType => "Float64" case DataTypes.LongType => "Int64" case DataTypes.DateType => "DateTime" case DataTypes.TimestampType => "DateTime" case DataTypes.StringType => "String" case DataTypes.NullType => "String" } }
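The ClickHouse source mixes in DataSourceRegister so it can be addressed as format("clickhouse"). For Spark to resolve that short name, the implementation class is normally registered through Java's ServiceLoader, i.e. a resource file on the classpath; a sketch for this example (adjust the package if ClickHouseDataSourceV2 lives elsewhere):

# src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
com.mengyao.spark.datasourcev2.ext.example1.ClickHouseDataSourceV2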
Spark SQL: extending DataSourceV2 with a Nebula Graph source
Create the NebulaDataSource class
import java.util.Map.Entry import java.util.Optional import com.vesoft.nebula.connector.exception.IllegalOptionException import com.vesoft.nebula.connector.reader.{NebulaDataSourceEdgeReader, NebulaDataSourceVertexReader} import com.vesoft.nebula.connector.writer.{NebulaDataSourceEdgeWriter, NebulaDataSourceVertexWriter} import org.apache.spark.sql.SaveMode import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.sources.v2.reader.DataSourceReader import org.apache.spark.sql.sources.v2.writer.DataSourceWriter import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport, WriteSupport} import org.apache.spark.sql.types.StructType import org.slf4j.LoggerFactory import scala.collection.JavaConversions.iterableAsScalaIterable class NebulaDataSource extends DataSourceV2 with ReadSupport with WriteSupport with DataSourceRegister { private val LOG = LoggerFactory.getLogger(this.getClass) /** * The string that represents the format that nebula data source provider uses. */ override def shortName(): String = "nebula" /** * Creates a {@link DataSourceReader} to scan the data from Nebula Graph. */ override def createReader(options: DataSourceOptions): DataSourceReader = { val nebulaOptions = getNebulaOptions(options, OperaType.READ) val dataType = nebulaOptions.dataType LOG.info("create reader") LOG.info(s"options ${options.asMap()}") if (DataTypeEnum.VERTEX == DataTypeEnum.withName(dataType)) { new NebulaDataSourceVertexReader(nebulaOptions) } else { new NebulaDataSourceEdgeReader(nebulaOptions) } } /** * Creates an optional {@link DataSourceWriter} to save the data to Nebula Graph. */ override def createWriter(writeUUID: String, schema: StructType, mode: SaveMode, options: DataSourceOptions): Optional[DataSourceWriter] = { val nebulaOptions = getNebulaOptions(options, OperaType.WRITE) val dataType = nebulaOptions.dataType if (mode == SaveMode.Ignore || mode == SaveMode.ErrorIfExists) { LOG.warn(s"Currently do not support mode") } LOG.info("create writer") LOG.info(s"options ${options.asMap()}") if (DataTypeEnum.VERTEX == DataTypeEnum.withName(dataType)) { val vertexFiled = nebulaOptions.vertexField val vertexIndex: Int = { var index: Int = -1 for (i <- schema.fields.indices) { if (schema.fields(i).name.equals(vertexFiled)) { index = i } } if (index < 0) { throw new IllegalOptionException( s" vertex field ${vertexFiled} does not exist in dataframe") } index } Optional.of(new NebulaDataSourceVertexWriter(nebulaOptions, vertexIndex, schema)) } else { val srcVertexFiled = nebulaOptions.srcVertexField val dstVertexField = nebulaOptions.dstVertexField val rankExist = !nebulaOptions.rankField.isEmpty val edgeFieldsIndex = { var srcIndex: Int = -1 var dstIndex: Int = -1 var rankIndex: Int = -1 for (i <- schema.fields.indices) { if (schema.fields(i).name.equals(srcVertexFiled)) { srcIndex = i } if (schema.fields(i).name.equals(dstVertexField)) { dstIndex = i } if (rankExist) { if (schema.fields(i).name.equals(nebulaOptions.rankField)) { rankIndex = i } } } // check src filed and dst field if (srcIndex < 0 || dstIndex < 0) { throw new IllegalOptionException( s" srcVertex field ${srcVertexFiled} or dstVertex field ${dstVertexField} do not exist in dataframe") } // check rank field if (rankExist && rankIndex < 0) { throw new IllegalOptionException(s"rank field does not exist in dataframe") } if (!rankExist) { (srcIndex, dstIndex, Option.empty) } else { (srcIndex, dstIndex, Option(rankIndex)) } } 
Optional.of( new NebulaDataSourceEdgeWriter(nebulaOptions, edgeFieldsIndex._1, edgeFieldsIndex._2, edgeFieldsIndex._3, schema)) } } /** * construct nebula options with DataSourceOptions */ def getNebulaOptions(options: DataSourceOptions, operateType: OperaType.Value): NebulaOptions = { var parameters: Map[String, String] = Map() for (entry: Entry[String, String] <- options.asMap().entrySet) { parameters += (entry.getKey -> entry.getValue) } val nebulaOptions = new NebulaOptions(CaseInsensitiveMap(parameters))(operateType) nebulaOptions } }
Create the NebulaSourceReader class
import java.util import com.vesoft.nebula.connector.{DataTypeEnum, NebulaOptions, NebulaUtils} import com.vesoft.nebula.connector.nebula.MetaProvider import com.vesoft.nebula.meta.ColumnDef import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, InputPartition} import org.apache.spark.sql.types.{DataTypes, StructField, StructType} import org.slf4j.LoggerFactory import scala.collection.mutable.ListBuffer import scala.collection.JavaConverters._ /** * Base class of Nebula Source Reader */ abstract class NebulaSourceReader(nebulaOptions: NebulaOptions) extends DataSourceReader { private val LOG = LoggerFactory.getLogger(this.getClass) private var datasetSchema: StructType = _ override def readSchema(): StructType = { datasetSchema = getSchema(nebulaOptions) LOG.info(s"dataset's schema: $datasetSchema") datasetSchema } protected def getSchema: StructType = getSchema(nebulaOptions) ## return the dataset's schema. Schema includes configured cols in returnCols or includes all properties in nebula. def getSchema(nebulaOptions: NebulaOptions): StructType = { val returnCols = nebulaOptions.getReturnCols val noColumn = nebulaOptions.noColumn val fields: ListBuffer[StructField] = new ListBuffer[StructField] val metaProvider = new MetaProvider(nebulaOptions.getMetaAddress) import scala.collection.JavaConverters._ var schemaCols: Seq[ColumnDef] = Seq() val isVertex = DataTypeEnum.VERTEX.toString.equalsIgnoreCase(nebulaOptions.dataType) // construct vertex or edge default prop if (isVertex) { fields.append(DataTypes.createStructField("_vertexId", DataTypes.StringType, false)) } else { fields.append(DataTypes.createStructField("_srcId", DataTypes.StringType, false)) fields.append(DataTypes.createStructField("_dstId", DataTypes.StringType, false)) fields.append(DataTypes.createStructField("_rank", DataTypes.LongType, false)) } var dataSchema: StructType = null // read no column if (noColumn) { dataSchema = new StructType(fields.toArray) return dataSchema } // get tag schema or edge schema val schema = if (isVertex) { metaProvider.getTag(nebulaOptions.spaceName, nebulaOptions.label) } else { metaProvider.getEdge(nebulaOptions.spaceName, nebulaOptions.label) } schemaCols = schema.columns.asScala // read all columns if (returnCols.isEmpty) { schemaCols.foreach(columnDef => { LOG.info(s"prop name ${new String(columnDef.getName)}, type ${columnDef.getType.getType} ") fields.append( DataTypes.createStructField(new String(columnDef.getName), NebulaUtils.convertDataType(columnDef.getType), true)) }) } else { for (col: String <- returnCols) { fields.append( DataTypes .createStructField(col, NebulaUtils.getColDataType(schemaCols.toList, col), true)) } } dataSchema = new StructType(fields.toArray) dataSchema } } /** * DataSourceReader for Nebula Vertex */ class NebulaDataSourceVertexReader(nebulaOptions: NebulaOptions) extends NebulaSourceReader(nebulaOptions) { override def planInputPartitions(): util.List[InputPartition[InternalRow]] = { val partitionNum = nebulaOptions.partitionNums.toInt val partitions = for (index <- 1 to partitionNum) yield { new NebulaVertexPartition(index, nebulaOptions, getSchema) } partitions.map(_.asInstanceOf[InputPartition[InternalRow]]).asJava } } /** * DataSourceReader for Nebula Edge */ class NebulaDataSourceEdgeReader(nebulaOptions: NebulaOptions) extends NebulaSourceReader(nebulaOptions) { override def planInputPartitions(): util.List[InputPartition[InternalRow]] = { val partitionNum = nebulaOptions.partitionNums.toInt val 
partitions = for (index <- 1 to partitionNum) yield new NebulaEdgePartition(index, nebulaOptions, getSchema) partitions.map(_.asInstanceOf[InputPartition[InternalRow]]).asJava } }
Create the NebulaSourceWriter class
import com.vesoft.nebula.connector.NebulaOptions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.sources.v2.writer.{ DataSourceWriter, DataWriter, DataWriterFactory, WriterCommitMessage } import org.apache.spark.sql.types.StructType import org.slf4j.LoggerFactory /** * creating and initializing the actual Nebula vertex writer at executor side */ class NebulaVertexWriterFactory(nebulaOptions: NebulaOptions, vertexIndex: Int, schema: StructType) extends DataWriterFactory[InternalRow] { override def createDataWriter(partitionId: Int, taskId: Long, epochId: Long): DataWriter[InternalRow] = { new NebulaVertexWriter(nebulaOptions, vertexIndex, schema) } } /** * creating and initializing the actual Nebula edge writer at executor side */ class NebulaEdgeWriterFactory(nebulaOptions: NebulaOptions, srcIndex: Int, dstIndex: Int, rankIndex: Option[Int], schema: StructType) extends DataWriterFactory[InternalRow] { override def createDataWriter(partitionId: Int, taskId: Long, epochId: Long): DataWriter[InternalRow] = { new NebulaEdgeWriter(nebulaOptions, srcIndex, dstIndex, rankIndex, schema) } } /** * nebula vertex writer to create factory */ class NebulaDataSourceVertexWriter(nebulaOptions: NebulaOptions, vertexIndex: Int, schema: StructType) extends DataSourceWriter { private val LOG = LoggerFactory.getLogger(this.getClass) override def createWriterFactory(): DataWriterFactory[InternalRow] = { new NebulaVertexWriterFactory(nebulaOptions, vertexIndex, schema) } override def commit(messages: Array[WriterCommitMessage]): Unit = { LOG.debug(s"${messages.length}") for (msg <- messages) { val nebulaMsg = msg.asInstanceOf[NebulaCommitMessage] LOG.info(s"failed execs:\n ${nebulaMsg.executeStatements.toString()}") } } override def abort(messages: Array[WriterCommitMessage]): Unit = { LOG.error("NebulaDataSourceVertexWriter abort") } } /** * nebula edge writer to create factory */ class NebulaDataSourceEdgeWriter(nebulaOptions: NebulaOptions, srcIndex: Int, dstIndex: Int, rankIndex: Option[Int], schema: StructType) extends DataSourceWriter { private val LOG = LoggerFactory.getLogger(this.getClass) override def createWriterFactory(): DataWriterFactory[InternalRow] = { new NebulaEdgeWriterFactory(nebulaOptions, srcIndex, dstIndex, rankIndex, schema) } override def commit(messages: Array[WriterCommitMessage]): Unit = { LOG.debug(s"${messages.length}") for (msg <- messages) { val nebulaMsg = msg.asInstanceOf[NebulaCommitMessage] LOG.info(s"failed execs:\n ${nebulaMsg.executeStatements.toString()}") } } override def abort(messages: Array[WriterCommitMessage]): Unit = { LOG.error("NebulaDataSourceEdgeWriter abort") } }
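Both writers cast their commit messages to NebulaCommitMessage, which is not shown in this excerpt. Conceptually it is just a WriterCommitMessage carrying the statements that failed on the executor, roughly as in this sketch (not the connector's actual definition; see the full source linked below):

import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage

// Sketch only: the real class lives in the nebula-spark-connector writer package
case class NebulaCommitMessage(executeStatements: List[String]) extends WriterCommitMessage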
For the complete code, see:
https://github.com/vesoft-inc/nebula-spark-utils/tree/master/nebula-spark-connector/