# Graph Computing on MaxCompute in Practice: An Introduction to the Aggregator Mechanism


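Aggregator is one of the most commonly used features in MaxCompute-GRAPH jobs, especially for machine learning problems. In MaxCompute-GRAPH, an Aggregator collects and processes global information. This article explains how Aggregators execute and which APIs they expose. It then uses K-means clustering as an example to show how to use them in practice.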

## Aggregator API

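Aggregator provides five APIs for users to implement. The following describes, for each of the five, when it is called and what it is typically used for.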

##### 5. terminate(context, value)

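terminate() is special in one respect: if it returns true, the whole job stops iterating; otherwise iteration continues. In machine learning scenarios it typically returns true once the algorithm has converged, which ends the job.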

## K-means Clustering Example

##### 1. GraphLoader

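```java
public static class KmeansValue implements Writable {

  // One sample: an m-dimensional feature vector read from one input record.
  DenseVector sample;

  public KmeansValue() {
  }

  public KmeansValue(DenseVector v) {
    this.sample = v;
  }

  @Override
  public void write(DataOutput out) throws IOException {
    // Serialization helper defined in the attachment.
    wirteForDenseVector(out, sample);
  }

  @Override
  public void readFields(DataInput in) throws IOException {
    // Must read back exactly what write() wrote; readFieldsForDenseVector() is assumed
    // to be the attachment's counterpart of wirteForDenseVector().
    sample = readFieldsForDenseVector(in);
  }
}
```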
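A Writable object is rebuilt by readFields() from the bytes its write() produced, so both methods must handle the same fields in the same order. The attachment's vector helpers are not shown, so the sketch below assumes a simple format (vector length followed by its values) using only `java.io`; `VectorIO`, `writeVector`, `readVector` and `roundTrip` are hypothetical names, not part of MaxCompute.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInput;
import java.io.DataInputStream;
import java.io.DataOutput;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

// Hypothetical stand-ins for the attachment's vector helpers (assumed format:
// the vector length followed by its values).
final class VectorIO {
  static void writeVector(DataOutput out, double[] v) throws IOException {
    out.writeInt(v.length);
    for (double x : v) {
      out.writeDouble(x);
    }
  }

  // Reads the fields in exactly the order writeVector() wrote them.
  static double[] readVector(DataInput in) throws IOException {
    int n = in.readInt();
    double[] v = new double[n];
    for (int i = 0; i < n; i++) {
      v[i] = in.readDouble();
    }
    return v;
  }

  // Serializes and then deserializes a vector, as happens whenever a value is shipped as bytes.
  static double[] roundTrip(double[] v) {
    try {
      ByteArrayOutputStream bytes = new ByteArrayOutputStream();
      writeVector(new DataOutputStream(bytes), v);
      return readVector(new DataInputStream(new ByteArrayInputStream(bytes.toByteArray())));
    } catch (IOException e) {
      throw new UncheckedIOException(e);
    }
  }
}
```

If readFields() did not read the sample back, a deserialized KmeansValue would carry a null sample and aggregate() would fail on it.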


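```java
public static class KmeansReader extends
    GraphLoader<LongWritable, KmeansValue, NullWritable, NullWritable> {

  @Override
  public void load(
      LongWritable recordNum,
      WritableRecord record,
      MutationContext<LongWritable, KmeansValue, NullWritable, NullWritable> context)
      throws IOException {
    // Each input record becomes one vertex; the record number serves as the vertex ID.
    KmeansVertex v = new KmeansVertex();
    v.setId(recordNum);

    // Every column of the record is one dimension of the sample vector.
    int n = record.size();
    DenseVector dv = new DenseVector(n);
    for (int i = 0; i < n; i++) {
      dv.set(i, ((DoubleWritable) record.get(i)).get());
    }
    v.setValue(new KmeansValue(dv));

    // Ask the framework to add the new vertex to the graph.
    context.addVertexRequest(v);
  }
}
```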


##### 2. Vertex

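```java
public static class KmeansVertex extends
    Vertex<LongWritable, KmeansValue, NullWritable, NullWritable> {

  @Override
  public void compute(
      ComputeContext<LongWritable, KmeansValue, NullWritable, NullWritable> context,
      Iterable<NullWritable> messages) throws IOException {
    // Vertices exchange no messages: in every superstep each vertex only hands its
    // sample to the aggregator, which invokes KmeansAggregator.aggregate() on this worker.
    context.aggregate(getValue());
  }
}
```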

##### 3. Aggregator

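```java
public static class KmeansAggrValue implements Writable {

  DenseMatrix centroids; // current K centroids, one per row
  DenseMatrix sums;      // used to recalculate new centroids
  DenseVector counts;    // used to recalculate new centroids

  @Override
  public void write(DataOutput out) throws IOException {
    wirteForDenseDenseMatrix(out, centroids);
    wirteForDenseDenseMatrix(out, sums);
    wirteForDenseVector(out, counts);
  }

  @Override
  public void readFields(DataInput in) throws IOException {
    // Read the fields back in the order write() wrote them; readFieldsForDenseMatrix() and
    // readFieldsForDenseVector() are assumed counterparts of the attachment's write helpers.
    centroids = readFieldsForDenseMatrix(in);
    sums = readFieldsForDenseMatrix(in);
    counts = readFieldsForDenseVector(in);
  }
}
```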


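KmeansAggrValue maintains three objects. centroids holds the current K centroids; if the samples are m-dimensional, centroids is a K×m matrix. sums is a matrix of the same size as centroids, where each element accumulates one dimension of the samples nearest to one centroid: sums(i,j) is the sum of the j-th dimension over all samples whose nearest centroid is the i-th. counts is a K-dimensional vector recording how many samples are nearest to each centroid. Together, sums and counts are used to compute the new centroids, and they are the main content being aggregated.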
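The update rule described above is, for every centroid i and dimension j, centroids(i, j) = sums(i, j) / counts(i). calculateNewCentroids() itself is in the attachment, so the sketch below is a hypothetical plain-array version (`double[][]` instead of DenseMatrix). It assumes that a centroid with counts(i) = 0 keeps its old position, which would explain why the original call also passes value.centroids.

```java
// Hypothetical plain-array sketch of calculateNewCentroids() (not the attachment's code).
final class CentroidUpdate {
  // New centroid i = row i of sums divided by counts[i].
  // Assumption: a centroid that attracted no samples keeps its previous position.
  static double[][] calculateNewCentroids(double[][] sums, double[] counts, double[][] oldCentroids) {
    int k = sums.length;
    double[][] result = new double[k][];
    for (int i = 0; i < k; i++) {
      if (counts[i] == 0) {
        result[i] = oldCentroids[i].clone(); // empty cluster: reuse the old centroid
        continue;
      }
      result[i] = new double[sums[i].length];
      for (int j = 0; j < sums[i].length; j++) {
        result[i][j] = sums[i][j] / counts[i];
      }
    }
    return result;
  }
}
```

For example, sums = {{4, 6}, {0, 0}} with counts = {2, 0} yields the first centroid (2, 3), while the second centroid stays where it was.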

```java
public static class KmeansAggregator extends Aggregator<KmeansAggrValue> {

  // Called once on each worker at startup: build the initial centroids and zero the accumulators.
  @Override
  public KmeansAggrValue createStartupValue(WorkerContext context) throws IOException {
    KmeansAggrValue av = new KmeansAggrValue();

    // `centers` holds the initial centroids as text, one centroid per line with
    // comma-separated coordinates; the code that loads it is omitted here.
    String[] lines = new String(centers).split("\n");

    int rows = lines.length;
    int cols = lines[0].split(",").length; // assumption rows >= 1

    av.centroids = new DenseMatrix(rows, cols);
    av.sums = new DenseMatrix(rows, cols);
    av.sums.zero();
    av.counts = new DenseVector(rows);
    av.counts.zero();

    for (int i = 0; i < lines.length; i++) {
      String[] ss = lines[i].split(",");
      for (int j = 0; j < ss.length; j++) {
        av.centroids.set(i, j, Double.parseDouble(ss[j]));
      }
    }
    return av;
  }
```


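```java
  // Called on every worker before each superstep.
  @Override
  public KmeansAggrValue createInitialValue(WorkerContext context)
      throws IOException {
    // Start from the previous superstep's global result, which carries the current
    // centroids (index 0 is the first, and here only, aggregator).
    KmeansAggrValue av = (KmeansAggrValue) context.getLastAggregatedValue(0);

    // reset for next iteration
    av.sums.zero();
    av.counts.zero();

    return av;
  }
```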


```java
  // Called once for every context.aggregate() call on this worker.
  @Override
  public void aggregate(KmeansAggrValue value, Object item)
      throws IOException {
    DenseVector sample = ((KmeansValue) item).sample;

    // find the nearest centroid
    int min = findNearestCentroid(value.centroids, sample);

    // update sum and count
    for (int i = 0; i < sample.size(); i++) {
      value.sums.add(min, i, sample.get(i));
    }
    value.counts.add(min, 1.0d);
  }
```
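aggregate() assigns each sample to its nearest centroid and adds the sample into that centroid's row of sums. findNearestCentroid() lives in the attachment; the sketch below is a hypothetical plain-array version that assumes squared Euclidean distance (enough for choosing the minimum) and breaks ties toward the lower index. `NearestCentroid` and `accumulate` are illustrative names only.

```java
// Hypothetical plain-array sketch of findNearestCentroid() and the update in aggregate().
final class NearestCentroid {
  // Index of the centroid with the smallest squared Euclidean distance to the sample;
  // the strict '<' keeps the lower index on ties.
  static int findNearestCentroid(double[][] centroids, double[] sample) {
    int best = 0;
    double bestDist = Double.POSITIVE_INFINITY;
    for (int i = 0; i < centroids.length; i++) {
      double d = 0;
      for (int j = 0; j < sample.length; j++) {
        double diff = centroids[i][j] - sample[j];
        d += diff * diff;
      }
      if (d < bestDist) {
        bestDist = d;
        best = i;
      }
    }
    return best;
  }

  // Mirrors aggregate(): add the sample to its nearest centroid's row of sums and bump its count.
  static void accumulate(double[][] centroids, double[][] sums, double[] counts, double[] sample) {
    int min = findNearestCentroid(centroids, sample);
    for (int j = 0; j < sample.length; j++) {
      sums[min][j] += sample[j];
    }
    counts[min] += 1.0;
  }
}
```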


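```java
  // Called on the AggregatorOwner's worker, once per partial result received.
  @Override
  public void merge(KmeansAggrValue value, KmeansAggrValue partial)
      throws IOException {
    // Centroids are identical on all workers; only the accumulators need combining.
    value.sums.add(partial.sums);
    value.counts.add(partial.counts);
  }
```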


The logic of merge() is simple: it adds the sums and counts aggregated by each worker into the running totals.
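Because merging is plain element-wise addition, which is commutative and associative, the owner gets the same totals no matter in which order the workers' partial results arrive. The hypothetical plain-array sketch below (`PartialMerge` is an illustrative name) shows this addition on two workers' partials.

```java
// Hypothetical plain-array sketch of merge(): element-wise addition of partial results.
final class PartialMerge {
  // Adds one worker's partial sums and counts into the running totals in place.
  static void merge(double[][] sums, double[] counts, double[][] partialSums, double[] partialCounts) {
    for (int i = 0; i < sums.length; i++) {
      for (int j = 0; j < sums[i].length; j++) {
        sums[i][j] += partialSums[i][j];
      }
      counts[i] += partialCounts[i];
    }
  }
}
```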

```java
  // Called on the AggregatorOwner's worker after all partial results have been merged.
  @Override
  public boolean terminate(WorkerContext context, KmeansAggrValue value)
      throws IOException {
    // New centroid i = row i of sums divided by counts(i) (helper defined in the attachment)
    DenseMatrix newCentroids = calculateNewCentroids(value.sums, value.counts, value.centroids);

    // print old centroids and new centroids for debugging
    System.out.println("\nsuperstep: " + context.getSuperstep()
        + "\nold centroids:\n" + value.centroids + " new centroids:\n" + newCentroids);

    boolean converged = isConverged(newCentroids, value.centroids, 0.05d);
    System.out.println("superstep: " + context.getSuperstep() + "/"
        + (context.getMaxIteration() - 1) + " converged: " + converged);
    if (converged || context.getSuperstep() == context.getMaxIteration() - 1) {
      // converged or reached max iteration: output the centroids, one record per centroid
      for (int i = 0; i < newCentroids.numRows(); i++) {
        Writable[] centroid = new Writable[newCentroids.numColumns()];
        for (int j = 0; j < newCentroids.numColumns(); j++) {
          centroid[j] = new DoubleWritable(newCentroids.get(i, j));
        }
        context.write(centroid);
      }

      // true means to terminate iteration
      return true;
    }

    // update centroids
    value.centroids.set(newCentroids);
    // false means to continue iteration
    return false;
  }
} // end of KmeansAggregator
```


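terminate() first calls calculateNewCentroids() to average sums by counts, which yields the new centroids. It then calls isConverged() to decide, based on the Euclidean distance between the new and old centroids, whether the algorithm has converged. If it has converged or the maximum number of iterations has been reached, terminate() writes out the new centroids and returns true to end the iteration. Otherwise it updates the centroids and returns false to continue. The implementations of calculateNewCentroids() and isConverged() are in the attachment.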
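The attachment's isConverged() is not shown, so the exact convergence rule is unknown. The hypothetical sketch below assumes one plausible reading: the job has converged when every centroid moved by less than the threshold (0.05 in the original call) in Euclidean distance. `Convergence` is an illustrative name.

```java
// Hypothetical sketch of isConverged(); the per-centroid threshold rule is an assumption.
final class Convergence {
  // True when every centroid moved less than epsilon (Euclidean distance) since the last superstep.
  static boolean isConverged(double[][] newCentroids, double[][] oldCentroids, double epsilon) {
    for (int i = 0; i < newCentroids.length; i++) {
      double sq = 0;
      for (int j = 0; j < newCentroids[i].length; j++) {
        double diff = newCentroids[i][j] - oldCentroids[i][j];
        sq += diff * diff;
      }
      if (Math.sqrt(sq) >= epsilon) {
        return false; // this centroid still moved too far
      }
    }
    return true;
  }
}
```

Alternatives such as comparing the sum or the maximum of all centroid shifts would fit the same description; the choice only changes how strict the stopping test is.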

##### 4. The main method

The main method constructs a GraphJob, sets the relevant configuration, and submits the job:

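```java
public static void main(String[] args) throws IOException {
  // args[0]: input table of samples, args[1]: output table for the centroids,
  // args[2] (optional): maximum number of iterations
  if (args.length < 2) {
    printUsage();
  }

  GraphJob job = new GraphJob();

  job.setGraphLoaderClass(KmeansReader.class);
  job.setRuntimePartitioning(false);
  job.setVertexClass(KmeansVertex.class);
  job.setAggregatorClass(KmeansAggregator.class);
  job.addInput(TableInfo.builder().tableName(args[0]).build());
  job.addOutput(TableInfo.builder().tableName(args[1]).build());

  // default max iteration is 30
  job.setMaxIteration(30);
  if (args.length >= 3) {
    job.setMaxIteration(Integer.parseInt(args[2]));
  }

  long start = System.currentTimeMillis();
  job.run();
  System.out.println("Job Finished in "
      + (System.currentTimeMillis() - start) / 1000.0 + " seconds");
}
```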


## Summary

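1) When each worker starts, it calls createStartupValue() to create the AggregatorValue.
2) Before each superstep, every worker calls createInitialValue() to initialize that superstep's AggregatorValue.
3) During a superstep, each vertex calls context.aggregate(), which runs aggregate() and performs local aggregation within the worker.
4) Each worker sends its local aggregation result to the worker that hosts the AggregatorOwner.
5) The AggregatorOwner's worker calls merge() once per received partial result to perform the global aggregation.
6) The AggregatorOwner's worker calls terminate() to process the global aggregation result and decide whether to end the iteration.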
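The six steps can be traced in a small, framework-free simulation. `SimpleAggregator`, `SumAggregator`, `SumValue` and `LifecycleDemo` are hypothetical classes, not part of the MaxCompute SDK. Worker 0 plays the AggregatorOwner, and each toy vertex halves its value after aggregating it, so the global sum shrinks until terminate() stops the job. The sketch also assumes, consistent with the K-means code, that in superstep 0 the "last aggregated value" is the startup value.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical, framework-free mirror of the five Aggregator callbacks; NOT the MaxCompute API.
abstract class SimpleAggregator<V> {
  abstract V createStartupValue();                    // once per worker at startup
  abstract V createInitialValue(V lastAggregated);    // once per worker before each superstep
  abstract void aggregate(V value, double item);      // once per context.aggregate() call
  abstract void merge(V value, V partial);            // on the owner, once per other worker
  abstract boolean terminate(int superstep, V value); // on the owner; true ends the job
}

// Toy aggregated value: a global sum.
final class SumValue {
  double sum;
}

// Toy aggregator: sums vertex values and stops once the global sum drops below 1.
final class SumAggregator extends SimpleAggregator<SumValue> {
  @Override SumValue createStartupValue() { return new SumValue(); }
  @Override SumValue createInitialValue(SumValue last) { return new SumValue(); } // fresh each superstep
  @Override void aggregate(SumValue value, double item) { value.sum += item; }
  @Override void merge(SumValue value, SumValue partial) { value.sum += partial.sum; }
  @Override boolean terminate(int superstep, SumValue value) { return value.sum < 1.0; }
}

final class LifecycleDemo {
  // workers[w] holds worker w's vertex values (modified in place by the toy vertex update).
  // Returns the superstep in which the job stopped.
  static <V> int run(SimpleAggregator<V> agg, double[][] workers, int maxIteration) {
    int n = workers.length;
    List<V> last = new ArrayList<>();
    for (int w = 0; w < n; w++) {
      last.add(agg.createStartupValue());              // 1) startup value on every worker
    }
    for (int superstep = 0; superstep < maxIteration; superstep++) {
      List<V> partials = new ArrayList<>();
      for (int w = 0; w < n; w++) {
        V value = agg.createInitialValue(last.get(w)); // 2) per-superstep initial value
        for (int v = 0; v < workers[w].length; v++) {
          agg.aggregate(value, workers[w][v]);         // 3) local aggregation inside the worker
          workers[w][v] /= 2;                          // toy vertex update for the next superstep
        }
        partials.add(value);                           // 4) partial result sent to the owner (worker 0)
      }
      V global = partials.get(0);
      for (int w = 1; w < n; w++) {
        agg.merge(global, partials.get(w));            // 5) owner merges the partials
      }
      if (agg.terminate(superstep, global)) {         // 6) owner decides whether to stop
        return superstep;
      }
      for (int w = 0; w < n; w++) {
        last.set(w, global);                           // broadcast: next getLastAggregatedValue()
      }
    }
    return maxIteration - 1;
  }
}
```

With vertex values {8, 4} on worker 0 and {4} on worker 1, the global sums are 16, 8, 4, 2, 1 and 0.5, so the job stops in superstep 5. In the real framework, each worker keeps its own deserialized copy of the broadcast value, which is why KmeansAggregator.createInitialValue() can reset sums and counts in place. This simulation sidesteps sharing by having SumAggregator return a fresh object.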
