
MaxCompute User Guide: Graph Model: Sample Programs: K-Means Clustering



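K-means clustering (Kmeans) is a very basic and widely used clustering algorithm.
Basic principle: the algorithm clusters around k points in space, assigning each sample to the center closest to it. It then updates the value of each cluster center iteratively until the centers stop moving (or move less than a threshold), which gives the final clustering.
Assuming the sample set is to be divided into k classes, the algorithm works as follows: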


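  1. Choose suitable initial centers for the k classes.

  2. In the i-th iteration, for each sample, compute its distance to each of the k centers and assign the sample to the class whose center is closest.

  3. Update the center of each class, for example by taking the mean of the samples assigned to it.

  4. For all k cluster centers, if the values are unchanged after the update in steps 2 and 3, or change by less than some threshold, the iteration ends. Otherwise, continue iterating.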
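
The four steps above can be sketched in plain Java, independent of MaxCompute. This is only a minimal single-machine illustration: the class name `KMeansSketch`, its method names, and the choice to keep the old center when a cluster receives no points are assumptions made for this sketch, not part of the MaxCompute sample.

```java
import java.util.Arrays;

public class KMeansSketch {
    /** Squared Euclidean distance; the square root is not needed to compare distances. */
    public static double squaredDistance(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double d = a[i] - b[i];
            sum += d * d;
        }
        return sum;
    }

    /** Step 2: index of the center closest to the given point. */
    public static int nearest(double[] point, double[][] centers) {
        int best = 0;
        double bestDist = Double.MAX_VALUE;
        for (int c = 0; c < centers.length; c++) {
            double dist = squaredDistance(point, centers[c]);
            if (dist < bestDist) {
                bestDist = dist;
                best = c;
            }
        }
        return best;
    }

    /** Repeats steps 2-4 until every center moves less than threshold, or maxIterations is reached. */
    public static double[][] cluster(double[][] points, double[][] initialCenters,
                                     double threshold, int maxIterations) {
        int k = initialCenters.length;
        int dim = initialCenters[0].length;
        double[][] centers = new double[k][];
        for (int c = 0; c < k; c++) {
            centers[c] = initialCenters[c].clone(); // step 1: start from the given centers
        }
        for (int iter = 0; iter < maxIterations; iter++) {
            double[][] sums = new double[k][dim];
            long[] counts = new long[k];
            for (double[] p : points) {             // step 2: assign each point
                int c = nearest(p, centers);
                counts[c]++;
                for (int j = 0; j < dim; j++) {
                    sums[c][j] += p[j];
                }
            }
            boolean converged = true;
            for (int c = 0; c < k; c++) {           // step 3: update centers by the mean
                if (counts[c] == 0) {
                    continue;                       // assumption: an empty cluster keeps its old center
                }
                double[] updated = new double[dim];
                for (int j = 0; j < dim; j++) {
                    updated[j] = sums[c][j] / counts[c];
                }
                if (Math.sqrt(squaredDistance(updated, centers[c])) >= threshold) {
                    converged = false;              // step 4: this center still moved too far
                }
                centers[c] = updated;
            }
            if (converged) {
                break;
            }
        }
        return centers;
    }

    public static void main(String[] args) {
        double[][] points = {{1, 1}, {1, 2}, {9, 9}, {10, 9}};
        double[][] centers = cluster(points, new double[][]{{0, 0}, {10, 10}}, 0.05, 30);
        System.out.println(Arrays.deepToString(centers)); // prints [[1.0, 1.5], [9.5, 9.0]]
    }
}
```

With the four sample points in `main`, the first pass moves the centers to the two cluster means and the second pass finds no further movement, so the loop stops after two iterations.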


Code example


The code for the k-means clustering algorithm is as follows:
```java
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import org.apache.log4j.Logger;
import com.aliyun.odps.io.WritableRecord;
import com.aliyun.odps.graph.Aggregator;
import com.aliyun.odps.graph.ComputeContext;
import com.aliyun.odps.graph.GraphJob;
import com.aliyun.odps.graph.GraphLoader;
import com.aliyun.odps.graph.MutationContext;
import com.aliyun.odps.graph.Vertex;
import com.aliyun.odps.graph.WorkerContext;
import com.aliyun.odps.io.DoubleWritable;
import com.aliyun.odps.io.LongWritable;
import com.aliyun.odps.io.NullWritable;
import com.aliyun.odps.data.TableInfo;
import com.aliyun.odps.io.Text;
import com.aliyun.odps.io.Tuple;
import com.aliyun.odps.io.Writable;

public class Kmeans {
  private final static Logger LOG = Logger.getLogger(Kmeans.class);

  // Each vertex holds one sample point (a Tuple of coordinates) and has no edges.
  public static class KmeansVertex extends
      Vertex<Text, Tuple, NullWritable, NullWritable> {
    @Override
    public void compute(
        ComputeContext<Text, Tuple, NullWritable, NullWritable> context,
        Iterable<NullWritable> messages) throws IOException {
      // hand the point to the aggregator; all k-means logic lives there
      context.aggregate(getValue());
    }
  }

  // Turns every input record into one vertex.
  public static class KmeansVertexReader extends
      GraphLoader<Text, Tuple, NullWritable, NullWritable> {
    @Override
    public void load(LongWritable recordNum, WritableRecord record,
        MutationContext<Text, Tuple, NullWritable, NullWritable> context)
        throws IOException {
      KmeansVertex vertex = new KmeansVertex();
      // the vertex ID is irrelevant here, so the record number is used
      vertex.setId(new Text(String.valueOf(recordNum.get())));
      vertex.setValue(new Tuple(record.getAll()));
      context.addVertexRequest(vertex);
    }
  }

  // Aggregated state: current centers plus per-cluster coordinate sums and point counts.
  public static class KmeansAggrValue implements Writable {
    Tuple centers = new Tuple();
    Tuple sums = new Tuple();
    Tuple counts = new Tuple();

    @Override
    public void write(DataOutput out) throws IOException {
      centers.write(out);
      sums.write(out);
      counts.write(out);
    }

    @Override
    public void readFields(DataInput in) throws IOException {
      centers = new Tuple();
      centers.readFields(in);
      sums = new Tuple();
      sums.readFields(in);
      counts = new Tuple();
      counts.readFields(in);
    }

    @Override
    public String toString() {
      return "centers " + centers.toString() + ", sums " + sums.toString()
          + ", counts " + counts.toString();
    }
  }

  public static class KmeansAggregator extends Aggregator<KmeansAggrValue> {
    @SuppressWarnings("rawtypes")
    @Override
    public KmeansAggrValue createInitialValue(WorkerContext context)
        throws IOException {
      KmeansAggrValue aggrVal = null;
      if (context.getSuperstep() == 0) {
        // first iteration: read the initial centers from the "centers" cache file,
        // one center per line, coordinates separated by commas
        aggrVal = new KmeansAggrValue();
        aggrVal.centers = new Tuple();
        aggrVal.sums = new Tuple();
        aggrVal.counts = new Tuple();
        byte[] centers = context.readCacheFile("centers");
        String[] lines = new String(centers).split("\n");
        for (int i = 0; i < lines.length; i++) {
          // fixed: split the i-th line, not the array itself
          String[] ss = lines[i].split(",");
          Tuple center = new Tuple();
          Tuple sum = new Tuple();
          for (int j = 0; j < ss.length; ++j) {
            center.append(new DoubleWritable(Double.valueOf(ss[j].trim())));
            sum.append(new DoubleWritable(0.0));
          }
          LongWritable count = new LongWritable(0);
          aggrVal.sums.append(sum);
          aggrVal.counts.append(count);
          aggrVal.centers.append(center);
        }
      } else {
        // later iterations: start from the centers computed in the previous superstep
        aggrVal = (KmeansAggrValue) context.getLastAggregatedValue(0);
      }
      return aggrVal;
    }

    @Override
    public void aggregate(KmeansAggrValue value, Object item) {
      int min = 0;
      double mindist = Double.MAX_VALUE;
      Tuple point = (Tuple) item;
      // find the nearest center
      for (int i = 0; i < value.centers.size(); i++) {
        Tuple center = (Tuple) value.centers.get(i);
        // use Euclidean Distance, no need to calculate sqrt
        double dist = 0.0d;
        for (int j = 0; j < center.size(); j++) {
          double v = ((DoubleWritable) point.get(j)).get()
              - ((DoubleWritable) center.get(j)).get();
          dist += v * v;
        }
        if (dist < mindist) {
          mindist = dist;
          min = i;
        }
      }
      // update sum and count
      Tuple sum = (Tuple) value.sums.get(min);
      for (int i = 0; i < point.size(); i++) {
        DoubleWritable s = (DoubleWritable) sum.get(i);
        s.set(s.get() + ((DoubleWritable) point.get(i)).get());
      }
      LongWritable count = (LongWritable) value.counts.get(min);
      count.set(count.get() + 1);
    }

    @Override
    public void merge(KmeansAggrValue value, KmeansAggrValue partial) {
      // add the partial sums and counts collected by another worker
      for (int i = 0; i < value.sums.size(); i++) {
        Tuple sum = (Tuple) value.sums.get(i);
        Tuple that = (Tuple) partial.sums.get(i);
        for (int j = 0; j < sum.size(); j++) {
          DoubleWritable s = (DoubleWritable) sum.get(j);
          s.set(s.get() + ((DoubleWritable) that.get(j)).get());
        }
      }
      for (int i = 0; i < value.counts.size(); i++) {
        LongWritable count = (LongWritable) value.counts.get(i);
        count.set(count.get() + ((LongWritable) partial.counts.get(i)).get());
      }
    }

    @SuppressWarnings("rawtypes")
    @Override
    public boolean terminate(WorkerContext context, KmeansAggrValue value)
        throws IOException {
      // compute new centers
      Tuple newCenters = new Tuple(value.sums.size());
      for (int i = 0; i < value.sums.size(); i++) {
        Tuple sum = (Tuple) value.sums.get(i);
        Tuple newCenter = new Tuple(sum.size());
        LongWritable c = (LongWritable) value.counts.get(i);
        for (int j = 0; j < sum.size(); j++) {
          DoubleWritable s = (DoubleWritable) sum.get(j);
          // note: if a cluster received no points, this is 0.0 / 0 and yields NaN
          double val = s.get() / c.get();
          newCenter.set(j, new DoubleWritable(val));
          // reset sum for next iteration
          s.set(0.0d);
        }
        // reset count for next iteration
        c.set(0);
        newCenters.set(i, newCenter);
      }
      // update centers
      Tuple oldCenters = value.centers;
      value.centers = newCenters;
      LOG.info("old centers: " + oldCenters + ", new centers: " + newCenters);
      // compare new/old centers
      boolean converged = true;
      for (int i = 0; i < value.centers.size() && converged; i++) {
        Tuple oldCenter = (Tuple) oldCenters.get(i);
        Tuple newCenter = (Tuple) newCenters.get(i);
        double sum = 0.0d;
        for (int j = 0; j < newCenter.size(); j++) {
          double v = ((DoubleWritable) newCenter.get(j)).get()
              - ((DoubleWritable) oldCenter.get(j)).get();
          sum += v * v;
        }
        double dist = Math.sqrt(sum);
        LOG.info("old center: " + oldCenter + ", new center: " + newCenter
            + ", dist: " + dist);
        // converge threshold for each center: 0.05
        converged = dist < 0.05d;
      }
      if (converged || context.getSuperstep() == context.getMaxIteration() - 1) {
        // converged or reach max iteration, output centers
        for (int i = 0; i < value.centers.size(); i++) {
          context.write(((Tuple) value.centers.get(i)).toArray());
        }
        // true means to terminate iteration
        return true;
      }
      // false means to continue iteration
      return false;
    }
  }

  private static void printUsage() {
    System.out.println("Usage: <in> <out> [Max iterations (default 30)]");
    System.exit(-1);
  }

  public static void main(String[] args) throws IOException {
    if (args.length < 2)
      printUsage();
    GraphJob job = new GraphJob();
    job.setGraphLoaderClass(KmeansVertexReader.class);
    // vertices do not need to be redistributed after loading
    job.setRuntimePartitioning(false);
    job.setVertexClass(KmeansVertex.class);
    job.setAggregatorClass(KmeansAggregator.class);
    job.addInput(TableInfo.builder().tableName(args[0]).build());
    job.addOutput(TableInfo.builder().tableName(args[1]).build());
    // default max iteration is 30
    job.setMaxIteration(30);
    if (args.length >= 3)
      job.setMaxIteration(Integer.parseInt(args[2]));
    long start = System.currentTimeMillis();
    job.run();
    System.out.println("Job Finished in "
        + (System.currentTimeMillis() - start) / 1000.0 + " seconds");
  }
}
```

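Notes on the code above:

  • KmeansVertex: the compute() method is very simple. It only calls the aggregate method of the context object, passing in the value of the current vertex (a Tuple that represents a vector).

  • KmeansVertexReader: loads the graph, parsing each record of the table into one vertex. The vertex ID does not matter, so the incoming recordNum sequence number is used as the ID. The vertex value is a Tuple made up of all columns of the record.

  • KmeansAggregator: this class encapsulates the main logic of the Kmeans algorithm, where:
    createInitialValue creates the initial value (the k cluster centers) for each iteration. In the first iteration (superstep=0), the value is the set of initial centers read from the centers cache file. Otherwise, it is the set of new centers computed at the end of the previous iteration.

  • The aggregate method computes the distance from each point to every cluster center, assigns the point to the class with the shortest distance, and updates the sum and count of that class.

  • The merge method merges the sum and count values collected by the individual workers.

  • The terminate method computes the new centers from the sum and count of each class. If every new center is closer than a threshold (0.05 in this code) to its previous position, or the number of iterations reaches the configured maximum, it writes the final centers to the result table and terminates the iteration (returns true). Otherwise it returns false and the iteration continues.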
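
The division of work between aggregate, merge, and terminate can be illustrated without the Graph SDK. The following is a minimal plain-Java sketch of the same sum-and-count pattern. The class `PartialAggregateSketch`, its nested `Partial` type, and the method names are hypothetical and exist only for this illustration. Unlike the listing above, the sketch keeps the old center when a cluster receives no points.

```java
import java.util.Arrays;

public class PartialAggregateSketch {
    /** Per-worker partial state: per-cluster coordinate sums and point counts. */
    public static final class Partial {
        public final double[][] sums;
        public final long[] counts;

        public Partial(int k, int dim) {
            sums = new double[k][dim];
            counts = new long[k];
        }
    }

    /** aggregate: fold one point into the sum and count of its nearest center. */
    public static void aggregate(Partial partial, double[][] centers, double[] point) {
        int min = 0;
        double minDist = Double.MAX_VALUE;
        for (int i = 0; i < centers.length; i++) {
            double dist = 0.0;
            for (int j = 0; j < point.length; j++) {
                double v = point[j] - centers[i][j];
                dist += v * v; // squared distance is enough for comparison
            }
            if (dist < minDist) {
                minDist = dist;
                min = i;
            }
        }
        for (int j = 0; j < point.length; j++) {
            partial.sums[min][j] += point[j];
        }
        partial.counts[min]++;
    }

    /** merge: add another worker's sums and counts into this one. */
    public static void merge(Partial into, Partial other) {
        for (int i = 0; i < into.counts.length; i++) {
            for (int j = 0; j < into.sums[i].length; j++) {
                into.sums[i][j] += other.sums[i][j];
            }
            into.counts[i] += other.counts[i];
        }
    }

    /** terminate, part 1: turn merged sums and counts into new centers. */
    public static double[][] newCenters(Partial merged, double[][] oldCenters) {
        double[][] centers = new double[oldCenters.length][];
        for (int i = 0; i < oldCenters.length; i++) {
            centers[i] = oldCenters[i].clone();
            if (merged.counts[i] == 0) {
                continue; // assumption: an empty cluster keeps its old center
            }
            for (int j = 0; j < centers[i].length; j++) {
                centers[i][j] = merged.sums[i][j] / merged.counts[i];
            }
        }
        return centers;
    }

    /** terminate, part 2: true when every center moved less than the threshold. */
    public static boolean converged(double[][] oldCenters, double[][] newCenters, double threshold) {
        for (int i = 0; i < oldCenters.length; i++) {
            double sum = 0.0;
            for (int j = 0; j < oldCenters[i].length; j++) {
                double v = newCenters[i][j] - oldCenters[i][j];
                sum += v * v;
            }
            if (Math.sqrt(sum) >= threshold) {
                return false;
            }
        }
        return true;
    }

    public static void main(String[] args) {
        double[][] centers = {{0, 0}, {10, 10}};
        Partial workerA = new Partial(2, 2);
        Partial workerB = new Partial(2, 2);
        aggregate(workerA, centers, new double[]{1, 1});
        aggregate(workerA, centers, new double[]{9, 9});
        aggregate(workerB, centers, new double[]{1, 2});
        aggregate(workerB, centers, new double[]{10, 9});
        merge(workerA, workerB);
        double[][] updated = newCenters(workerA, centers);
        System.out.println(Arrays.deepToString(updated));       // prints [[1.0, 1.5], [9.5, 9.0]]
        System.out.println(converged(centers, updated, 0.05)); // prints false
    }
}
```

Because only sums and counts cross worker boundaries, each worker can process its own points independently, and the merged result is the same as if a single worker had seen all points.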

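Main program (the main function): defines the GraphJob, specifies the Vertex, GraphLoader, and Aggregator implementations and the maximum number of iterations (30 by default), and specifies the input and output tables.
job.setRuntimePartitioning(false): for the Kmeans algorithm, vertices do not need to be distributed when the graph is loaded, so RuntimePartitioning is set to false to improve graph loading performance.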
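
Besides the input and output tables, the job reads its initial centers from a cache resource named centers. Judging from the parsing in createInitialValue, each line of that resource holds one center as comma-separated numbers. The sketch below reproduces that parsing in plain Java so the expected format can be checked locally. The class `CentersFileSketch` and the method `parseCenters` are hypothetical names used only here.

```java
import java.util.Arrays;

public class CentersFileSketch {
    /**
     * Parses text in which each line is one center and the coordinates are
     * separated by commas, mirroring the split("\n") and split(",") calls
     * in createInitialValue. Whitespace around each number is trimmed.
     */
    public static double[][] parseCenters(String content) {
        String[] lines = content.split("\n");
        double[][] centers = new double[lines.length][];
        for (int i = 0; i < lines.length; i++) {
            String[] fields = lines[i].split(",");
            centers[i] = new double[fields.length];
            for (int j = 0; j < fields.length; j++) {
                centers[i][j] = Double.parseDouble(fields[j].trim());
            }
        }
        return centers;
    }

    public static void main(String[] args) {
        // two 2-dimensional centers; a trailing newline is tolerated by split
        double[][] centers = parseCenters("1.0, 2.0\n3.5,4.5\n");
        System.out.println(Arrays.deepToString(centers)); // prints [[1.0, 2.0], [3.5, 4.5]]
    }
}
```

As in the original parsing, a blank line in the middle of the content would make `Double.parseDouble` throw a `NumberFormatException`, so the resource should contain only data lines.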

Posted by 行者武松 on 2017-10-24 10:27:57