MaxCompute User Guide: Graph Model: Sample Programs: Single-Source Shortest Path



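Dijkstra's algorithm is the classic algorithm for solving the single-source shortest path (SSSP) problem on a directed graph with non-negative edge weights.
Shortest distance: in a weighted directed graph G=(V,E) there are usually many paths from a source vertex s to a target vertex v. The smallest total edge weight among these paths is called the shortest distance from s to v.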
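For reference, here is a minimal single-machine sketch of textbook Dijkstra with a priority queue. It assumes non-negative edge weights, vertices numbered 0..n-1, and an adjacency list of {destination, weight} pairs. The class `DijkstraSketch`, its methods, and the sample graph are illustrative only and are not part of MaxCompute.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.PriorityQueue;

// Textbook Dijkstra for a directed graph with non-negative weights.
// adj.get(u) holds {v, weight} pairs; vertices are numbered 0..n-1 (illustrative layout).
class DijkstraSketch {
    static final long INF = Long.MAX_VALUE; // stands in for "infinity"

    static long[] shortestDistances(List<List<long[]>> adj, int source) {
        long[] dist = new long[adj.size()];
        Arrays.fill(dist, INF);   // d[u] = infinity for all u
        dist[source] = 0;         // d[s] = 0
        // Min-heap of {distance, vertex} entries, ordered by distance.
        PriorityQueue<long[]> pq = new PriorityQueue<>((a, b) -> Long.compare(a[0], b[0]));
        pq.add(new long[] {0, source});
        while (!pq.isEmpty()) {
            long[] top = pq.poll();
            int u = (int) top[1];
            if (top[0] > dist[u]) {
                continue;         // stale entry: u was already settled with a smaller distance
            }
            for (long[] edge : adj.get(u)) {
                int v = (int) edge[0];
                long candidate = dist[u] + edge[1];
                if (candidate < dist[v]) { // d[v] = min(d[v], d[u] + weight(u, v))
                    dist[v] = candidate;
                    pq.add(new long[] {candidate, v});
                }
            }
        }
        return dist;
    }

    // Hypothetical sample graph: 0->1 (4), 0->2 (1), 2->1 (2), 1->3 (1).
    static List<List<long[]>> sampleGraph() {
        List<List<long[]>> adj = new ArrayList<>();
        for (int i = 0; i < 4; i++) {
            adj.add(new ArrayList<>());
        }
        adj.get(0).add(new long[] {1, 4});
        adj.get(0).add(new long[] {2, 1});
        adj.get(2).add(new long[] {1, 2});
        adj.get(1).add(new long[] {3, 1});
        return adj;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(shortestDistances(sampleGraph(), 0)));
    }
}
```

Running `main` prints `[0, 3, 1, 4]`: vertex 1 is reached more cheaply through vertex 2 (1 + 2 = 3) than over the direct edge (4).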
The basic principle of the algorithm is as follows:


  • Initialization: the distance from the source s to itself is d[s]=0, and the distance from s to every other vertex u is infinite (d[u]=∞).

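  • Iteration: for every edge from u to v, the shortest distance from s to v is updated as d[v]=min(d[v], d[u]+weight(u, v)). The iteration ends when the distance from s to every vertex no longer changes.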
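The initialization and iteration steps above can be sketched on a single machine as follows. Note that relaxing every edge repeatedly until nothing changes is the Bellman-Ford scheme rather than priority-queue Dijkstra; for non-negative weights both produce the same distances, and this round-based form is the one that maps naturally onto Graph supersteps. The sketch assumes vertices numbered 0..n-1 and a hypothetical {u, v, weight} edge array; `RelaxUntilStable` is an illustrative name, not a MaxCompute API.

```java
import java.util.Arrays;

// Repeated edge relaxation until no distance changes (Bellman-Ford style).
// edges[k] = {u, v, weight}; vertices are numbered 0..n-1 (illustrative layout).
class RelaxUntilStable {
    static final long INF = Long.MAX_VALUE; // stands in for "infinity"

    static long[] shortestDistances(int n, long[][] edges, int source) {
        long[] d = new long[n];
        Arrays.fill(d, INF);      // initialization: d[u] = infinity
        d[source] = 0;            // d[s] = 0
        boolean changed = true;
        int rounds = 0;
        while (changed) {         // stop once a full pass changes nothing
            if (++rounds > n) {
                // without negative cycles, a pass with no change happens within n passes
                throw new IllegalStateException("negative cycle reachable from source");
            }
            changed = false;
            for (long[] e : edges) {
                int u = (int) e[0];
                int v = (int) e[1];
                if (d[u] == INF) {
                    continue;     // u not reached yet; INF + weight would overflow
                }
                long candidate = d[u] + e[2];
                if (candidate < d[v]) { // d[v] = min(d[v], d[u] + weight(u, v))
                    d[v] = candidate;
                    changed = true;
                }
            }
        }
        return d;
    }

    public static void main(String[] args) {
        // Hypothetical graph: 0->1 (4), 0->2 (1), 2->1 (2), 1->3 (1); vertex 4 is unreachable.
        long[][] edges = {{0, 1, 4}, {0, 2, 1}, {2, 1, 2}, {1, 3, 1}};
        System.out.println(Arrays.toString(shortestDistances(5, edges, 0)));
    }
}
```

Unreachable vertices keep the value `Long.MAX_VALUE`. The round counter turns a negative cycle into an exception instead of an endless loop.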

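From this principle it follows that the algorithm is well suited to a MaxCompute Graph program: each vertex maintains its current shortest distance to the source. Whenever this value changes, the vertex adds each outgoing edge's weight to the new value and sends the result as a message to the corresponding neighbor. In the next iteration, each neighbor updates its current shortest distance based on the messages it received. The iteration ends when no vertex's current shortest distance changes any more.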
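The message-passing scheme just described can be simulated in one process to see how supersteps, min-combining, and halting interact. This is a minimal sketch, not the MaxCompute Graph API: `VertexCentricSssp`, the `long[][][]` adjacency layout, and the sample graph are hypothetical, and `Map.merge` with `Long::min` plays the role of a combiner that keeps only the smallest message per destination.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

// Single-process simulation of the vertex-centric SSSP loop.
// edges[u] lists {v, weight} pairs for vertex u (illustrative layout, not the MaxCompute API).
class VertexCentricSssp {
    static final long INF = Long.MAX_VALUE; // initial vertex value, i.e. "infinity"

    static long[] run(long[][][] edges, int source) {
        long[] value = new long[edges.length];
        Arrays.fill(value, INF);
        // Messages for the current superstep; the source starts by "receiving" distance 0.
        Map<Integer, Long> inbox = new HashMap<>();
        inbox.put(source, 0L);
        while (!inbox.isEmpty()) {             // every vertex halted and no message in flight
            Map<Integer, Long> outbox = new HashMap<>();
            for (Map.Entry<Integer, Long> msg : inbox.entrySet()) {
                int u = msg.getKey();
                long minDist = msg.getValue();  // already reduced to the minimum by the combiner
                if (minDist < value[u]) {
                    value[u] = minDist;         // distance improved: notify neighbors
                    for (long[] e : edges[u]) {
                        // combiner step: keep only the smallest message per destination
                        outbox.merge((int) e[0], minDist + e[1], Long::min);
                    }
                }
                // otherwise the vertex sends nothing, i.e. it stays halted
            }
            inbox = outbox;                     // superstep barrier
        }
        return value;
    }

    public static void main(String[] args) {
        // Hypothetical graph: 0->1 (4), 0->2 (1), 2->1 (2), 1->3 (1).
        long[][][] edges = {{{1, 4}, {2, 1}}, {{3, 1}}, {{1, 2}}, {}};
        System.out.println(Arrays.toString(run(edges, 0)));
    }
}
```

Each pass of the `while` loop corresponds to one superstep; the sample graph settles to `[0, 3, 1, 4]` after four supersteps. Only vertices that receive messages do any work, mirroring how a vertex whose value does not improve calls `voteToHalt()` in the sample program.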

Code example


The SSSP code is as follows:
  1. import java.io.IOException;
  2. import com.aliyun.odps.io.WritableRecord;
  3. import com.aliyun.odps.graph.Combiner;
  4. import com.aliyun.odps.graph.ComputeContext;
  5. import com.aliyun.odps.graph.Edge;
  6. import com.aliyun.odps.graph.GraphJob;
  7. import com.aliyun.odps.graph.GraphLoader;
  8. import com.aliyun.odps.graph.MutationContext;
  9. import com.aliyun.odps.graph.Vertex;
  10. import com.aliyun.odps.graph.WorkerContext;
  11. import com.aliyun.odps.io.LongWritable;
  12. import com.aliyun.odps.data.TableInfo;
  13. public class SSSP {
  14.   public static final String START_VERTEX = "sssp.start.vertex.id";
  15.   public static class SSSPVertex extends
  16.       Vertex<LongWritable, LongWritable, LongWritable, LongWritable> {
  17.     private static long startVertexId = -1;
  18.     public SSSPVertex() {
  19.       this.setValue(new LongWritable(Long.MAX_VALUE));
  20.     }
  21.     public boolean isStartVertex(
  22.         ComputeContext<LongWritable, LongWritable, LongWritable, LongWritable> context) {
  23.       if (startVertexId == -1) {
  24.         String s = context.getConfiguration().get(START_VERTEX);
  25.         startVertexId = Long.parseLong(s);
  26.       }
  27.       return getId().get() == startVertexId;
  28.     }
  29.     @Override
  30.     public void compute(
  31.         ComputeContext<LongWritable, LongWritable, LongWritable, LongWritable> context,
  32.         Iterable<LongWritable> messages) throws IOException {
  33.       long minDist = isStartVertex(context) ? 0 : Long.MAX_VALUE; // "infinity", same as the initial vertex value
  34.       for (LongWritable msg : messages) {
  35.         if (msg.get() < minDist) {
  36.           minDist = msg.get();
  37.         }
  38.       }
  39.       if (minDist < this.getValue().get()) {
  40.         this.setValue(new LongWritable(minDist));
  41.         if (hasEdges()) {
  42.           for (Edge<LongWritable, LongWritable> e : this.getEdges()) {
  43.             context.sendMessage(e.getDestVertexId(), new LongWritable(minDist
  44.                 + e.getValue().get()));
  45.           }
  46.         }
  47.       } else {
  48.         voteToHalt();
  49.       }
  50.     }
  51.     @Override
  52.     public void cleanup(
  53.         WorkerContext<LongWritable, LongWritable, LongWritable, LongWritable> context)
  54.         throws IOException {
  55.       context.write(getId(), getValue());
  56.     }
  57.   }
  58.   public static class MinLongCombiner extends
  59.       Combiner<LongWritable, LongWritable> {
  60.     @Override
  61.     public void combine(LongWritable vertexId, LongWritable combinedMessage,
  62.         LongWritable messageToCombine) throws IOException {
  63.       if (combinedMessage.get() > messageToCombine.get()) {
  64.         combinedMessage.set(messageToCombine.get());
  65.       }
  66.     }
  67.   }
  68.   public static class SSSPVertexReader extends
  69.       GraphLoader<LongWritable, LongWritable, LongWritable, LongWritable> {
  70.     @Override
  71.     public void load(
  72.         LongWritable recordNum,
  73.         WritableRecord record,
  74.         MutationContext<LongWritable, LongWritable, LongWritable, LongWritable> context)
  75.         throws IOException {
  76.       SSSPVertex vertex = new SSSPVertex();
  77.       vertex.setId((LongWritable) record.get(0));
  78.       String[] edges = record.get(1).toString().split(",");
  79.       for (int i = 0; i < edges.length; i++) {
  80.         String[] ss = edges[i].split(":"); // "destinationId:weight"
  81.         vertex.addEdge(new LongWritable(Long.parseLong(ss[0])),
  82.             new LongWritable(Long.parseLong(ss[1])));
  83.       }
  84.       context.addVertexRequest(vertex);
  85.     }
  86.   }
  87.   public static void main(String[] args) throws IOException {
  88.     if (args.length < 3) {
  89.       System.out.println("Usage: <startnode> <input> <output>");
  90.       System.exit(-1);
  91.     }
  92.     GraphJob job = new GraphJob();
  93.     job.setGraphLoaderClass(SSSPVertexReader.class);
  94.     job.setVertexClass(SSSPVertex.class);
  95.     job.setCombinerClass(MinLongCombiner.class);
  96.     job.set(START_VERTEX, args[0]);
  97.     job.addInput(TableInfo.builder().tableName(args[1]).build());
  98.     job.addOutput(TableInfo.builder().tableName(args[2]).build());
  99.     long startTime = System.currentTimeMillis();
  100.     job.run();
  101.     System.out.println("Job Finished in "
  102.         + (System.currentTimeMillis() - startTime) / 1000.0 + " seconds");
  103.   }
  104. }

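The code above works as follows:

  • Line 15: defines SSSPVertex, where:
    • The vertex value is the current shortest distance from the source vertex startVertexId to this vertex.
    • compute() updates the vertex value with the iteration formula d[v]=min(d[v], d[u]+weight(u, v)).
    • cleanup() writes each vertex and its shortest distance from the source to the result table.
  • Line 48: when the vertex value does not change, voteToHalt() tells the framework that this vertex enters the halt state. The computation ends when all vertices are in the halt state.
  • Line 58: defines MinLongCombiner, which merges the messages sent to the same vertex by keeping only the smallest one. This improves performance and reduces memory usage.
  • Line 68: defines the SSSPVertexReader class, which loads the graph by parsing each table record into one vertex. The first column of a record is the vertex ID, and the second column stores all edges starting from that vertex as comma-separated destination:weight pairs, for example 2:2,3:1,4:4.
  • Line 87: the main program (main function) defines the GraphJob, specifies the Vertex, GraphLoader and Combiner implementations, and specifies the input and output tables. It takes three arguments: the start vertex ID, the input table and the output table.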
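The record format handled by SSSPVertexReader can be checked outside MaxCompute with a small parser like the one below. It is a sketch: `EdgeListParser` is a hypothetical helper, and treating an empty second column as "no outgoing edges" is an assumption (the reader in the sample does not special-case it and would fail on `Long.parseLong("")`).

```java
import java.util.ArrayList;
import java.util.List;

// Parses the edge column of one input record, e.g. "2:2,3:1,4:4",
// into {destinationVertexId, weight} pairs, like the loop in SSSPVertexReader.load().
class EdgeListParser {
    static List<long[]> parse(String column) {
        List<long[]> edges = new ArrayList<>();
        if (column == null || column.trim().isEmpty()) {
            return edges; // assumption: an empty column means "no outgoing edges"
        }
        for (String item : column.split(",")) {
            String[] ss = item.trim().split(":"); // ss[0] = destination ID, ss[1] = weight
            edges.add(new long[] {Long.parseLong(ss[0]), Long.parseLong(ss[1])});
        }
        return edges;
    }

    public static void main(String[] args) {
        for (long[] e : parse("2:2,3:1,4:4")) {
            System.out.println("edge to " + e[0] + " with weight " + e[1]);
        }
    }
}
```

For the example value 2:2,3:1,4:4 this yields three edges: to vertex 2 with weight 2, to vertex 3 with weight 1, and to vertex 4 with weight 4.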

行者武松, 2017-10-24 10:26:44