ML之HierarchicalClustering:自定义HierarchicalClustering层次聚类算法-阿里云开发者社区

开发者社区> 开发与运维> 正文
登录阅读全文

ML之HierarchicalClustering:自定义HierarchicalClustering层次聚类算法

简介: ML之HierarchicalClustering:自定义HierarchicalClustering层次聚类算法

输出结果

更新中

实现代码

# -*- encoding=utf-8 -*-

from numpy import *

class cluster_node:
    """One node of the binary tree produced by hierarchical clustering.

    A node either wraps a single input vector (no children) or joins two
    child subtrees, remembering the distance at which they were joined.
    """

    def __init__(self,vec,left=None,right=None,distance=0.0,id=None,count=1):
        # Vector representing this node (an input row or a merged vector).
        self.vec = vec
        # Child subtrees; both are None for an endpoint.
        self.left, self.right = left, right
        # Distance between the two children when they were merged
        # (defaults to 0.0).
        self.distance = distance
        # Node identifier (defaults to None).
        self.id = id
        # Weight of the node (defaults to 1); intended for weighted averaging.
        self.count = count

def L2dist(v1,v2):
    """Return the Euclidean (L2) distance between vectors ``v1`` and ``v2``."""
    delta = v1 - v2
    squared_total = sum(delta ** 2)
    return sqrt(squared_total)

   

def L1dist(v1,v2):
    """Return the Manhattan (L1) distance between vectors ``v1`` and ``v2``."""
    delta = v1 - v2
    magnitudes = abs(delta)
    return sum(magnitudes)

def hcluster(features,distance=L2dist):
    """Agglomeratively cluster the rows of ``features`` into a binary tree.

    Every row starts as its own leaf ``cluster_node`` (id = row index). The
    closest pair of current clusters is merged repeatedly until a single root
    remains. A merged node gets:

    * ``vec``: the unweighted midpoint of its two children's vectors,
    * ``distance``: the distance between those two children,
    * ``id``: a fresh negative id (-1, -2, ...),
    * ``count``: the number of original rows in its subtree.

    Args:
        features: Indexable sequence of rows; each row is converted with
            ``array``. Rows may also be scalars.
        distance: Callable ``(vec_a, vec_b) -> number`` used to compare
            clusters; defaults to ``L2dist``.

    Returns:
        The root ``cluster_node`` of the clustering tree.

    Raises:
        IndexError: if ``features`` is empty (there is no root to return).
    """
    # Cache of pairwise distances keyed by (id_a, id_b). Ids are never
    # reused, so cached entries stay valid across merges.
    distances = {}
    currentclustid = -1

    # Clusters are initially just the individual rows.
    clust = [cluster_node(array(features[i]), id=i) for i in range(len(features))]

    while len(clust) > 1:
        lowestpair = (0, 1)
        closest = distance(clust[0].vec, clust[1].vec)

        # Scan all pairs of current clusters for the smallest distance.
        for i in range(len(clust)):
            for j in range(i + 1, len(clust)):
                key = (clust[i].id, clust[j].id)
                if key not in distances:
                    distances[key] = distance(clust[i].vec, clust[j].vec)
                d = distances[key]
                if d < closest:
                    closest = d
                    lowestpair = (i, j)

        left = clust[lowestpair[0]]
        right = clust[lowestpair[1]]

        # Midpoint of the two children, computed with NumPy arithmetic. The
        # former per-element loop relied on len(vec), which fails for 0-d
        # (scalar) rows.
        mergevec = (left.vec + right.vec) / 2.0

        newcluster = cluster_node(mergevec, left=left, right=right,
                                  distance=closest, id=currentclustid,
                                  # Keep the subtree size up to date so the
                                  # weight field is meaningful on merged nodes.
                                  count=left.count + right.count)
        currentclustid -= 1

        # lowestpair[0] < lowestpair[1]: delete the higher index first so the
        # lower index still points at the right element.
        del clust[lowestpair[1]]
        del clust[lowestpair[0]]
        clust.append(newcluster)

    return clust[0]

def extract_clusters(clust,dist):
    """Split a clustering tree into the maximal subtrees with distance < ``dist``.

    Starting at ``clust``, any node whose ``distance`` is below ``dist`` is
    returned as a cluster and not descended into. Otherwise its left and
    right children (when present) are examined in turn. A node that fails
    the test and has no children contributes nothing.

    Args:
        clust: Root of the tree (nodes expose ``distance``, ``left``,
            ``right``).
        dist: Distance threshold.

    Returns:
        List of subtree roots, in left-to-right order.
    """
    # Explicit stack instead of recursion: hcluster can produce chain-shaped
    # trees whose depth is close to the number of rows, which would exceed
    # Python's recursion limit.
    found = []
    stack = [clust]
    while stack:
        node = stack.pop()
        if node.distance < dist:
            # We have found a cluster subtree; don't look inside it.
            found.append(node)
            continue
        # Push right before left so the left branch is handled first,
        # preserving the left-then-right order of a recursive traversal.
        if node.right is not None:
            stack.append(node.right)
        if node.left is not None:
            stack.append(node.left)
    return found

       

def get_cluster_elements(clust):
    """Return the ids of all leaves in the subtree rooted at ``clust``.

    A node with ``id >= 0`` is treated as a leaf and its id is collected.
    Any other node is expanded into its left and right children (when
    present). In ``hcluster`` leaves carry their row index and merged nodes
    carry negative ids.

    Returns:
        List of leaf ids in left-to-right order.
    """
    # Explicit stack instead of recursion so very deep (chain-shaped) trees
    # do not hit Python's recursion limit.
    ids = []
    stack = [clust]
    while stack:
        node = stack.pop()
        if node.id >= 0:
            # Non-negative id means this is a leaf.
            ids.append(node.id)
            continue
        # Push right before left so left-branch ids come first.
        if node.right is not None:
            stack.append(node.right)
        if node.left is not None:
            stack.append(node.left)
    return ids

def printclust(clust,labels=None,n=0):
    """Print the clustering tree as an indented outline, one node per line.

    Each level of depth adds one leading space. Branch nodes (negative id)
    are shown as ``-``. Endpoints show their id, or ``labels[id]`` when
    ``labels`` is given. The left subtree is printed before the right.

    Args:
        clust: Root node (exposes ``id``, ``left``, ``right``).
        labels: Optional indexable mapping from leaf id to display label
            (a list, tuple or NumPy array all work).
        n: Indentation depth of ``clust``.
    """
    if clust.id < 0:
        # Negative id means that this is a branch.
        text = '-'
    elif labels is None:
        # ``is None`` rather than ``== None``: the latter is elementwise (and
        # then ambiguous in ``if``) when labels is a NumPy array.
        text = clust.id
    else:
        text = labels[clust.id]

    # Build the indentation into the line itself. The Python 2 idiom
    # ``print (' '),`` emits a separate line per level under Python 3.
    print(' ' * n + str(text))

    if clust.left is not None:
        printclust(clust.left, labels=labels, n=n + 1)
    if clust.right is not None:
        printclust(clust.right, labels=labels, n=n + 1)

def getheight(clust):
    """Return the height of the tree rooted at ``clust``, measured in leaves.

    An endpoint (a node with neither child) has height 1. A branch's height
    is the sum of its two subtrees' heights, so the result equals the number
    of endpoints in the tree.
    """
    leaves = 0
    pending = [clust]
    while pending:
        node = pending.pop()
        if node.left is None and node.right is None:
            leaves += 1
        else:
            pending.extend((node.left, node.right))
    return leaves

def getdepth(clust):
    """Return the largest accumulated merge distance below ``clust``.

    An endpoint (no children) has depth 0. A branch adds its own
    ``distance`` to the larger of its two subtrees' depths.
    """
    is_endpoint = clust.left is None and clust.right is None
    if is_endpoint:
        return 0
    deeper = max(getdepth(clust.left), getdepth(clust.right))
    return deeper + clust.distance

     

     


版权声明:本文内容由阿里云实名注册用户自发贡献,版权归原作者所有,阿里云开发者社区不拥有其著作权,亦不承担相应法律责任。具体规则请查看《阿里云开发者社区用户服务协议》和《阿里云开发者社区知识产权保护指引》。如果您发现本社区中有涉嫌抄袭的内容,填写侵权投诉表单进行举报,一经查实,本社区将立刻删除涉嫌侵权内容。

分享:
开发与运维
使用钉钉扫一扫加入圈子
+ 订阅

集结各类场景实战经验,助你开发运维畅行无忧

其他文章
最新文章
相关文章