Decision Tree ID3 and C4.5 Algorithms in Practice

Summary: a hands-on implementation of the ID3 and C4.5 decision tree algorithms.

The assignment given by the instructor: build decision trees for the weather/activity data set below using both ID3 and C4.5, and classify samples with the resulting trees.

Code implementation (both algorithms in one file):

from numpy import log2  # only log2 is needed; the original wildcard import is avoided

# A 14-sample weather/activity data set (9 'yes', 5 'no'):
# four discrete features, with the class label in the last column.
def createDataSet():
    dataSet = [[ 1,  1, 1, 0, 'no'],
               [ 1,  1, 1, 1, 'no'],
               [ 0,  1, 1, 0, 'yes'],
               [-1,  0, 1, 0, 'yes'],
               [-1, -1, 0, 0, 'yes'],
               [-1, -1, 0, 1, 'no'],
               [ 0, -1, 0, 1, 'yes'],
               [ 1,  0, 1, 0, 'no'],
               [ 1, -1, 0, 0, 'yes'],
               [-1,  0, 0, 0, 'yes'],
               [ 1,  0, 0, 1, 'yes'],
               [ 0,  0, 1, 1, 'yes'],
               [ 0,  1, 0, 0, 'yes'],
               [-1,  0, 1, 1, 'no']]
    labels = ['weather', 'temperature', 'humidity', 'wind speed', 'activity']
    return dataSet, labels

# Compute the entropy of a data set from the class labels in the last column
def calcEntropy(dataSet):
    totalNum = len(dataSet)
    labelNum = {}
    entropy = 0
    # count the occurrences of each class label
    for data in dataSet:
        label = data[-1]
        if label in labelNum:
            labelNum[label] += 1
        else:
            labelNum[label] = 1

    # H(D) = -Σ_k p_k · log2(p_k) over the classes k
    for key in labelNum:
        p = labelNum[key] / totalNum
        entropy -= p * log2(p)
    return entropy

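As a quick sanity check (my addition, not part of the original script): with 9 'yes' and 5 'no' samples, the entropy of the data set is -(9/14)·log2(9/14) - (5/14)·log2(5/14) ≈ 0.9403:

dataSet, labels = createDataSet()
print(calcEntropy(dataSet))  # ≈ 0.9403
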
# Compute the entropy of a feature's own value distribution.
# In C4.5 terms this is the feature's "split information" (intrinsic value).
def calcEntropyForFeature(featureList):
    totalNum = len(featureList)
    dataNum = {}
    entropy = 0
    # count the occurrences of each feature value
    for data in featureList:
        if data in dataNum:
            dataNum[data] += 1
        else:
            dataNum[data] = 1

    for key in dataNum:
        p = dataNum[key] / totalNum
        entropy -= p * log2(p)
    return entropy

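For reference, the two selection criteria implemented below are the standard ones (textbook definitions, not taken from the original post). For a feature A that splits the data set D into subsets D_v by value:

    Gain(D, A)      = H(D) - Σ_v (|D_v| / |D|) · H(D_v)
    SplitInfo_A(D)  = -Σ_v (|D_v| / |D|) · log2(|D_v| / |D|)
    GainRatio(D, A) = Gain(D, A) / SplitInfo_A(D)

calcEntropy computes H(·) and calcEntropyForFeature computes SplitInfo_A(D); ID3 picks the feature with the largest Gain, C4.5 the one with the largest GainRatio.
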
# Choose the best splitting feature by information gain (ID3).
# The last entry of labels is the class label ('activity'), so only the
# first len(labels) - 1 columns are candidate features; the original
# range(len(labels)) would let the class column itself win every split.
def chooseBestFeatureID3(dataSet, labels):
    bestFeature = 0
    initialEntropy = calcEntropy(dataSet)
    biggestEntropyG = 0
    for i in range(len(labels) - 1):
        currentEntropy = 0
        feature = [data[i] for data in dataSet]
        subSet = splitDataSetByFeature(i, dataSet)
        totalN = len(feature)
        # conditional entropy of the class after splitting on feature i
        for key in subSet:
            prob = len(subSet[key]) / totalN
            currentEntropy += prob * calcEntropy(subSet[key])
        entropyGain = initialEntropy - currentEntropy
        if biggestEntropyG < entropyGain:
            biggestEntropyG = entropyGain
            bestFeature = i
    return bestFeature

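To see what ID3 chooses between, the root-level gain of every feature can be printed with a small loop (my addition; it assumes the full listing has been run, since splitDataSetByFeature is defined further down):

dataSet, labels = createDataSet()
baseH = calcEntropy(dataSet)
for i in range(len(labels) - 1):  # skip the class column
    subSet = splitDataSetByFeature(i, dataSet)
    condH = sum(len(rows) / len(dataSet) * calcEntropy(rows)
                for rows in subSet.values())
    print(labels[i], 'gain =', baseH - condH)
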
# Choose the best splitting feature by gain ratio (C4.5)
def chooseBestFeatureC45(dataSet, labels):
    bestFeature = 0
    initialEntropy = calcEntropy(dataSet)
    biggestEntropyGR = 0
    for i in range(len(labels) - 1):  # skip the class label column
        currentEntropy = 0
        feature = [data[i] for data in dataSet]
        entropyFeature = calcEntropyForFeature(feature)
        if entropyFeature == 0:  # single-valued feature: useless split, avoid dividing by zero
            continue
        subSet = splitDataSetByFeature(i, dataSet)
        totalN = len(feature)
        for key in subSet:
            prob = len(subSet[key]) / totalN
            currentEntropy += prob * calcEntropy(subSet[key])
        entropyGain = initialEntropy - currentEntropy
        entropyGainRatio = entropyGain / entropyFeature

        if biggestEntropyGR < entropyGainRatio:
            biggestEntropyGR = entropyGainRatio
            bestFeature = i
    return bestFeature

# Split the data set on feature i: return a dict mapping each value of
# feature i to the rows having that value, with column i removed.
def splitDataSetByFeature(i, dataSet):
    subSet = {}
    feature = [data[i] for data in dataSet]
    for j in range(len(feature)):
        if feature[j] not in subSet:
            subSet[feature[j]] = []

        # copy row j without column i
        splittedRow = dataSet[j][:i]
        splittedRow.extend(dataSet[j][i + 1:])
        subSet[feature[j]].append(splittedRow)
    return subSet

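For example (my addition), splitting on feature 0 ('weather', which takes the values -1, 0 and 1) yields a dict with three keys, each holding rows with the weather column removed:

dataSet, labels = createDataSet()
subSet = splitDataSetByFeature(0, dataSet)
print(sorted(subSet.keys()))  # [-1, 0, 1]
print(subSet[0][0])           # [1, 1, 0, 'yes']: the first row with weather == 0, minus column 0
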
# Return True if every sample in the data set has the same class label
def checkIsOneCateg(newDataSet):
    categoryList = [data[-1] for data in newDataSet]
    return len(set(categoryList)) == 1

# Return the most frequent class label in the data set
def majorityCateg(newDataSet):
    categCount = {}
    categList = [data[-1] for data in newDataSet]
    for c in categList:
        if c not in categCount:
            categCount[c] = 1
        else:
            categCount[c] += 1
    sortedCateg = sorted(categCount.items(), key=lambda x: x[1], reverse=True)

    return sortedCateg[0][0]

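A quick illustration of the majority vote (made-up two-column rows, purely for the example):

print(majorityCateg([[0, 'yes'], [1, 'no'], [1, 'yes']]))  # 'yes'
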
# Build the ID3 tree recursively
def createDecisionTreeID3(decisionTree, dataSet, tmplabels):
    labels = list(tmplabels)  # copy so the caller's label list is not mutated
    bestFeature = chooseBestFeatureID3(dataSet, labels)
    currentLabel = labels[bestFeature]
    decisionTree[currentLabel] = {}
    subSet = splitDataSetByFeature(bestFeature, dataSet)
    del labels[bestFeature]
    newLabels = labels[:]
    for key in subSet:
        newDataSet = subSet[key]
        if checkIsOneCateg(newDataSet):
            # pure node: every sample has the same class
            decisionTree[currentLabel][key] = newDataSet[0][-1]
        else:
            if len(newDataSet[0]) == 1:  # only the class column is left: no features to split on
                decisionTree[currentLabel][key] = majorityCateg(newDataSet)
            else:
                decisionTree[currentLabel][key] = {}
                createDecisionTreeID3(decisionTree[currentLabel][key], newDataSet, newLabels)

# Build the C4.5 tree recursively (same structure as ID3, but using gain ratio)
def createDecisionTreeC45(decisionTree, dataSet, tmplabels):
    labels = list(tmplabels)  # copy so the caller's label list is not mutated
    bestFeature = chooseBestFeatureC45(dataSet, labels)
    currentLabel = labels[bestFeature]
    decisionTree[currentLabel] = {}
    subSet = splitDataSetByFeature(bestFeature, dataSet)
    del labels[bestFeature]
    newLabels = labels[:]
    for key in subSet:
        newDataSet = subSet[key]
        if checkIsOneCateg(newDataSet):
            decisionTree[currentLabel][key] = newDataSet[0][-1]
        else:
            if len(newDataSet[0]) == 1:  # no features left to split on
                decisionTree[currentLabel][key] = majorityCateg(newDataSet)
            else:
                decisionTree[currentLabel][key] = {}
                createDecisionTreeC45(decisionTree[currentLabel][key], newDataSet, newLabels)

# Classify a test vector by walking down the decision tree
def classify(inputTree, featLabels, testVec):
    classLabel = None  # stays None if the tree has no branch for the test value
    firstStr = list(inputTree.keys())[0]    # the feature tested at this node, e.g. 'weather'
    secondDict = inputTree[firstStr]        # its children: feature value -> subtree or class label
    featIndex = featLabels.index(firstStr)  # position of the feature in featLabels, to index testVec
    for key in secondDict.keys():
        if testVec[featIndex] == key:
            if isinstance(secondDict[key], dict):  # internal node: keep descending
                classLabel = classify(secondDict[key], featLabels, testVec)
            else:                                  # leaf: take its class label
                classLabel = secondDict[key]
    return classLabel

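As a usage sketch (my addition; the expected output assumes the tree reproduces its training data, which the loops in the main block below confirm), classifying the first training sample returns its label:

dataSetID3, labelsID3 = createDataSet()
decisionTreeID3 = {}
createDecisionTreeID3(decisionTreeID3, dataSetID3, labelsID3)
print(classify(decisionTreeID3, labelsID3, [1, 1, 1, 0]))  # 'no'
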
if __name__ == '__main__':
    # ID3
    dataSetID3, labelsID3 = createDataSet()
    decisionTreeID3 = {}
    createDecisionTreeID3(decisionTreeID3, dataSetID3, labelsID3)
    print("ID3 decision tree: ", decisionTreeID3)

    # re-classify the training samples with the ID3 tree
    for tmp in dataSetID3:
        category = classify(decisionTreeID3, labelsID3, tmp[0:4])
        print(tmp[0:4], ", classified by ID3 as: ", category)

    # C4.5
    dataSetC45, labelsC45 = createDataSet()
    decisionTreeC45 = {}
    createDecisionTreeC45(decisionTreeC45, dataSetC45, labelsC45)
    print("C4.5 decision tree: ", decisionTreeC45)

    # re-classify the training samples with the C4.5 tree
    for tmp in dataSetC45:
        category = classify(decisionTreeC45, labelsC45, tmp[0:4])
        print(tmp[0:4], ", classified by C4.5 as: ", category)

 
