在 Spark 上用 Scala 实现基于物品的协同过滤（Item-based CF）算法
计算流程：用户-物品（UI）评分矩阵 → 物品-物品（II）相似度矩阵 → 按相似度降序排序并取 Top-N
package spark.example

import java.util.Locale

import org.apache.spark._
import org.apache.spark.rdd.RDD
import SparkContext._
import scala.collection.mutable.ArrayBuffer
import scala.math._

/**
 * Item-based collaborative filtering on Spark.
 *
 * Pipeline: user-item (UI) matrix -> item-item (II) cosine-similarity matrix -> per-item top-N.
 *
 * Input: text lines `user item score`, whitespace-separated. Blank lines are skipped; any other
 * line with fewer than three fields fails the job with a descriptive error.
 *
 * Output: one line per item, `item\t\tother:score,other:score,...`, neighbours sorted by
 * descending similarity and truncated to `topn`, scores printed with 8 decimals and '.' as the
 * decimal separator.
 */
object CollaborativeFiltering {

  /**
   * Entry point.
   *
   * @param args `<input> <output> <topn> <max_prefs_per_user> <score_threshold>`
   */
  def main(args: Array[String]): Unit = {
    if (args.length != 5) {
      System.err.println("Usage: spark.example.CollaborativeFiltering <1:input> <2:output> <3:topn> <4:max_prefs_per_user> <5:score_threshold>")
      System.exit(1)
    }
    // Parse the numeric arguments before starting Spark so bad arguments fail fast.
    val inputPath = args(0)
    val outputPath = args(1)
    val topn = args(2).toInt
    val maxPrefsPerUser = args(3).toInt
    val scoreThreshold = args(4).toDouble

    val conf = new SparkConf().setAppName("CollaborativeFiltering")
    conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    val sc = new SparkContext(conf)
    try {
      // Only preferences strictly above the threshold take part in the computation.
      val prefs = parsePrefs(sc.textFile(inputPath))
        .filter { case (_, (_, score)) => score > scoreThreshold }
      val userItems = normalizedUserItems(prefs, maxPrefsPerUser)
      formatTopN(itemSimilarities(userItems), topn).saveAsTextFile(outputPath)
    } finally {
      // Release cluster resources even when the job fails.
      sc.stop()
    }
  }

  /** Parses `user item score` lines into `(user, (item, score))`; blank lines are skipped. */
  private def parsePrefs(lines: RDD[String]): RDD[(String, (String, Double))] =
    lines.flatMap { line =>
      val trimmed = line.trim
      if (trimmed.isEmpty) {
        Iterator.empty
      } else {
        // Split on any run of whitespace so tabs or repeated spaces do not shift the fields.
        val fields = trimmed.split("\\s+")
        require(fields.length >= 3, s"Malformed preference line (expected 'user item score'): '$line'")
        Iterator.single((fields(0), (fields(1), fields(2).toDouble)))
      }
    }

  /**
   * Step 1: builds the UI matrix, grouped by user, with every item's column scaled to unit
   * L2 norm.
   *
   * @return `(user, [(item, normalizedScore)])`
   */
  private def normalizedUserItems(
      prefs: RDD[(String, (String, Double))],
      maxPrefsPerUser: Int): RDD[(String, Iterable[(String, Double)])] =
    prefs
      // NOTE(review): the partition count 4 is hard-coded, as in the original job.
      .groupByKey(4)
      // Cap each user's preferences. groupByKey gives no ordering guarantee, so which
      // preferences survive the cap is arbitrary. Re-key by item for normalization.
      .flatMap { case (user, itemScores) =>
        itemScores.take(maxPrefsPerUser).map { case (item, score) => (item, (user, score)) }
      }
      .groupByKey()
      .flatMap { case (item, userScores) =>
        val norm = sqrt(userScores.map { case (_, score) => pow(score, 2) }.sum)
        // An all-zero (or underflowing) item vector has no direction. Dividing by a zero norm
        // would produce NaN/Infinity similarities, so such items are dropped.
        if (norm > 0.0) userScores.map { case (user, score) => (user, (item, score / norm)) }
        else Nil
      }
      .groupByKey()

  /**
   * Step 2: builds the II matrix. For unit-length item vectors, the sum over users of the
   * products of two items' scores is their cosine similarity.
   *
   * @return `(item, (otherItem, similarity))`, emitted in both directions for every pair
   */
  private def itemSimilarities(
      userItems: RDD[(String, Iterable[(String, Double)])]): RDD[(String, (String, Double))] =
    userItems
      .flatMap { case (_, itemScores) =>
        // Sorting by item id emits each unordered pair once as (smaller, larger), so
        // contributions from different users land on the same key.
        val sorted = itemScores.toArray.sortBy(_._1)
        for {
          i <- sorted.indices.iterator
          j <- (i + 1) until sorted.length
        } yield ((sorted(i)._1, sorted(j)._1), sorted(i)._2 * sorted(j)._2)
      }
      // reduceByKey combines partial sums map-side, unlike groupByKey followed by a sum.
      .reduceByKey(_ + _)
      .flatMap { case ((a, b), sim) => Iterator((a, (b, sim)), (b, (a, sim))) }

  /** Formats each item's `topn` most similar items as one output line. */
  private def formatTopN(similarities: RDD[(String, (String, Double))], topn: Int): RDD[String] =
    similarities.groupByKey().map { case (item, neighbours) =>
      val top = neighbours.toArray.sortWith(_._2 > _._2).take(topn)
      // Locale.ROOT keeps '.' as the decimal separator regardless of the JVM default locale,
      // so the ':' / ',' delimited output stays parseable.
      val formatted = top
        .map { case (other, score) => other + ":" + "%1.8f".formatLocal(Locale.ROOT, score) }
        .mkString(",")
      // NOTE(review): the doubled tab leaves an empty middle column. It is kept so existing
      // consumers of this output format are not broken.
      item + "\t" + "\t" + formatted
    }
}
输出结果（每行格式为 `物品ID<Tab><Tab>相似物品ID:相似度,相似物品ID:相似度,…`，按相似度降序排列，最多 topn 个）：