决策树(Decision Tree)的核心思想是:根据训练样本构建这样一棵树,使得其叶节点是分类标签,非叶节点是判断条件,这样对于一个未知样本,能在树上找到一条路径到达叶节点,就得到了它的分类。


举个简单的例子,如何识别有毒的蘑菇?如果能够得到一棵这样的决策树,那么对于一个未知的蘑菇就很容易判断出它是否有毒了。

1
2
3
4
5
6
7
8
9
10
                           
                           它是什么颜色的?
                                |
                  -------鲜艳---------浅色----
                 |                           |
               有毒                      有什么气味?
                                             |
                               -----刺激性--------无味-----
                              |                           |
                             有毒                        安全


构建决策树有很多算法,常用的有ID3、C4.5等。本篇以ID3为研究算法。


构建决策树的关键在于每一次分支时选择哪个特征作为分界条件。这里的原则是:选择最能把数据变得有序的特征作为分界条件。所谓有序,是指划分后,每一个分支集合的分类尽可能一致。用信息论的方式表述,就是选择信息增益最大的方式划分集合。


所谓信息增益(information gain),是指变化前后熵(entropy)的增加量。为了计算熵,需要计算所有类别所有可能值包含的信息期望值,通过下面的公式得到:

wKiom1RR9xfQxUiYAAAdS8omkJM195.jpg

其中H为熵,n为分类数目,p(xi)是选择该分类的概率。


根据公式,计算一个集合熵的方式为:

1
2
3
4
5
6
7
8
计算每个分类出现的次数
foreach (每一个分类)
{
     计算出现概率
     根据概率计算熵
     累加熵
}
return  累加结果


判断如何划分集合,方式为:

1
2
3
4
5
6
7
foreach (每一个特征)
{
     计算按此特征切分时的熵
     计算与切分前相比的信息增益
     保留能产生最大增益的特征为切分方式
}
return  选定的特征


构建树节点的方法为:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
if (集合没有特征可用了)
{
     按多数原则决定此节点的分类
}
else  if (集合中所有样本的分类都一致)
{
     此标签就是节点分类
}
else
{
     以最佳方式切分集合
     每一种可能形成当前节点的一个分支
     递归
}


OK,上C#版代码,DataVector和上篇文章一样,不放了,只放核心算法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
using  System;
using  System.Collections.Generic;
 
namespace  MachineLearning
{
     /// <summary>
     /// 决策树节点
     /// </summary>
     public  class  DecisionNode
     {
         /// <summary>
         /// 此节点的分类标签,为空表示此节点不是叶节点
         /// </summary>
         public  string  Label {  get set ; }
         /// <summary>
         /// 此节点的划分特征,为-1表示此节点是叶节点
         /// </summary>
         public  int  FeatureIndex {  get set ; }
         /// <summary>
         /// 分支
         /// </summary>
         public  Dictionary< string , DecisionNode> Child {  get set ; }
 
         public  DecisionNode()
         {
             this .FeatureIndex = -1;
             this .Child =  new  Dictionary< string , DecisionNode>();
         }
     }
}


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
using  System;
using  System.Collections.Generic;
using  System.Linq;
 
namespace  MachineLearning
{
     /// <summary>
     /// 决策树(ID3算法)
     /// </summary>
     public  class  DecisionTree
     {
         private  DecisionNode m_Tree;
 
         /// <summary>
         /// 训练
         /// </summary>
         /// <param name="trainingSet"></param>
         public  void  Train(List<DataVector< string >> trainingSet)
         {
             var  features =  new  List< int >(trainingSet[0].Dimension);
             for ( int  i = 0;i < trainingSet[0].Dimension;++i)
                 features.Add(i);
                 
             //生成决策树
             m_Tree = CreateTree(trainingSet, features);
         }
 
         /// <summary>
         /// 分类
         /// </summary>
         /// <param name="vector"></param>
         /// <returns></returns>
         public  string  Classify(DataVector< string > vector)
         {
             return  Classify(vector, m_Tree);
         }
 
         /// <summary>
         /// 分类
         /// </summary>
         /// <param name="vector"></param>
         /// <param name="node"></param>
         /// <returns></returns>
         private  string  Classify(DataVector< string > vector, DecisionNode node)
         {
             var  label =  string .Empty;
             
             if (! string .IsNullOrEmpty(node.Label))
             {
                 //是叶节点,直接返回结果
                 label = node.Label;
             }
             else
             {
                 //取需要分类的字段,继续深入
                 var  key = vector.Data[node.FeatureIndex];
                 if (node.Child.ContainsKey(key))
                     label = Classify(vector, node.Child[key]);
                 else
                     label =  "[UNKNOWN]" ;
             }
             return  label;
         }
         
         /// <summary>
         /// 创建决策树
         /// </summary>
         /// <param name="dataSet"></param>
         /// <param name="features"></param>
         /// <returns></returns>
         private  DecisionNode CreateTree(List<DataVector< string >> dataSet, List< int > features)
         {
             var  node =  new  DecisionNode();
             
             if (dataSet[0].Dimension == 0)
             {
                 //所有字段已用完,按多数原则决定Label,结束分类
                 node.Label = GetMajorLabel(dataSet);
             }
             else  if (dataSet.Count == dataSet.Count(d =>  string .Equals(d.Label, dataSet[0].Label)))
             {
                 //如果数据集中的Label相同,结束分类
                 node.Label = dataSet[0].Label;
             }
             else
             {
                 //挑选一个最佳分类,分割集合,递归
                 int  featureIndex = ChooseBestFeature(dataSet);
                 node.FeatureIndex = features[featureIndex];
                 var  uniqueValues = GetUniqueValues(dataSet, featureIndex);
                 features.RemoveAt(featureIndex);
                 foreach ( var  value  in  uniqueValues)
                 {
                     node.Child[value.ToString()] = CreateTree(SplitDataSet(dataSet, featureIndex, value),  new  List< int >(features));
                 }
             }
             
             return  node;
         }
         
         /// <summary>
         /// 计算给定集合的香农熵
         /// </summary>
         /// <param name="dataSet"></param>
         /// <returns></returns>
         private  double  ComputeShannon(List<DataVector< string >> dataSet)
         {
             double  shannon = 0.0;
             
             var  dict =  new  Dictionary< string int >();
             foreach ( var  item  in  dataSet)
             {
                 if (!dict.ContainsKey(item.Label))
                     dict[item.Label] = 0;
                 dict[item.Label] += 1;
             }
             
             foreach ( var  label  in  dict.Keys)
             {
                 double  prob = dict[label] * 1.0 / dataSet.Count;
                 shannon -= prob * Math.Log(prob, 2);
             }
             
             return  shannon;
         }
         
         /// <summary>
         /// 用给定的方式切分出数据子集
         /// </summary>
         /// <param name="dataSet"></param>
         /// <param name="splitIndex"></param>
         /// <param name="value"></param>
         /// <returns></returns>
         private  List<DataVector< string >> SplitDataSet(List<DataVector< string >> dataSet,  int  splitIndex,  string  value)
         {
             var  newDataSet =  new  List<DataVector< string >>();
             
             foreach ( var  item  in  dataSet)
             {
                 //只保留指定维度上符合给定值的项
                 if (item.Data[splitIndex] == value)
                 {
                     var  newItem =  new  DataVector< string >(item.Dimension - 1);
                     newItem.Label = item.Label;
                     Array.Copy(item.Data, 0, newItem.Data, 0, splitIndex - 0);
                     Array.Copy(item.Data, splitIndex + 1, newItem.Data, splitIndex, item.Dimension - splitIndex - 1);
                     newDataSet.Add(newItem);
                 }
             }
             
             return  newDataSet;
         }
 
         /// <summary>
         /// 在给定的数据集上选择一个最好的切分方式
         /// </summary>
         /// <param name="dataSet"></param>
         /// <returns></returns>
         private  int  ChooseBestFeature(List<DataVector< string >> dataSet)
         {
             int  bestFeature = 0;
             double  bestInfoGain = 0.0;
             double  baseShannon = ComputeShannon(dataSet);
             
             //遍历每一个维度来寻找
             for ( int  i = 0;i < dataSet[0].Dimension;++i)
             {
                 var  uniqueValues = GetUniqueValues(dataSet, i);
                 double  newShannon = 0.0;
 
                 //遍历此维度下的每一个可能值,切分数据集并计算熵
                 foreach ( var  value  in  uniqueValues)
                 {
                     var  subSet = SplitDataSet(dataSet, i, value);
                     double  prob = subSet.Count * 1.0 / dataSet.Count;
                     newShannon += prob * ComputeShannon(subSet);
                 }
 
                 //计算信息增益,保留最佳切分方式
                 double  infoGain = baseShannon - newShannon;
                 if (infoGain > bestInfoGain)
                 {
                     bestInfoGain = infoGain;
                     bestFeature = i;
                 }
             }
             
             return  bestFeature;
         }
 
         /// <summary>
         /// 数据去重
         /// </summary>
         /// <param name="dataSet"></param>
         /// <param name="index"></param>
         /// <returns></returns>
         private  List< string > GetUniqueValues(List<DataVector< string >> dataSet,  int  index)
         {
             var  dict =  new  Dictionary< string int >();
             foreach ( var  item  in  dataSet)
             {
                 dict[item.Data[index]] = 0;
             }
             return  dict.Keys.ToList< string >();
         }
 
         /// <summary>
         /// 取多数标签
         /// </summary>
         /// <param name="dataSet"></param>
         /// <returns></returns>
         private  string  GetMajorLabel(List<DataVector< string >> dataSet)
         {
             var  dict =  new  Dictionary< string int >();
             foreach ( var  item  in  dataSet)
             {
                 if (!dict.ContainsKey(item.Label))
                     dict[item.Label] = 0;
                 dict[item.Label]++;
             }
 
             string  label =  string .Empty;
             int  count = -1;
             foreach ( var  key  in  dict.Keys)
             {
                 if (dict[key] > count)
                 {
                     label = key;
                     count = dict[key];
                 }
             }
             
             return  label;
         }
     }
}



拿个例子实际检验一下,还是以毒蘑菇的识别为例,从这里找了点数据,http://archive.ics.uci.edu/ml/machine-learning-databases/mushroom/agaricus-lepiota.data ,它整理了8000多个样本,每个样本描述了蘑菇的22个属性,比如形状、气味等等,然后给出了这个蘑菇是否可食用。


比如一行数据:p,x,s,n,t,p,f,c,n,k,e,e,s,s,w,w,p,w,o,p,k,s,u

第0个元素p表示poisonous(有毒),其它22个元素分别是蘑菇的属性,可以参见agaricus-lepiota.names的描述,但实际上根本不用关心具体含义。以此构建样本并测试错误率:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
public  void  TestDecisionTree()
{
     var  trainingSet =  new  List<DataVector< string >>();     //训练数据集
     var  testSet =  new  List<DataVector< string >>();         //测试数据集
     
     //读取数据
     var  file =  new  StreamReader( "agaricus-lepiota.data" , Encoding.Default);
     string  line =  string .Empty;
     int  count = 0;
     while ((line = file.ReadLine()) !=  null )
     {
         var  parts = line.Split( ',' );
         
         var  p =  new  DataVector< string >(22);
         p.Label = parts[0];
         for ( int  i = 0;i < p.Dimension;++i)
             p.Data[i] = parts[i + 1];
             
         //前7000作为训练样本,其余作为测试样本
         if (++count <= 7000)
             trainingSet.Add(p);
         else
             testSet.Add(p);
     }
     file.Close();
 
     //检验
     var  dt =  new  DecisionTree();
     dt.Train(trainingSet);
     int  error = 0;
     foreach ( var  in  testSet)
     {
         //做猜测分类,并与实际结果比较
         var  label = dt.Classify(p);
         if (label != p.Label)
             ++error;
     }
 
     Console.WriteLine( "Error = {0}/{1}, {2}%" , error, testSet.Count, (error * 100.0 / testSet.Count));
}


使用7000个样本做训练,1124个样本做测试,只有4个猜测出错,错误率仅有0.35%,相当不错的结果。


生成的决策树是这样的:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
{
     "FeatureIndex" : 4,               //按第4个特征划分
     "Child" : {
         "p" : { "Label" "p" },         //如果第4个特征是p,则分类为p
         "a" : { "Label" "e" },         //如果第4个特征是a,则分类是e
         "l" : { "Label" "e" },
         "n" : {
             "FeatureIndex" : 19,             //如果第4个特征是n,要继续按第19个特征划分
             "Child" : {
                 "n" : { "Label" "e" },
                 "k" : { "Label" "e" },
                 "w" : {
                     "FeatureIndex" : 21,
                     "Child" : {
                         "w" : { "Label" "e" },
                         "l" : {
                             "FeatureIndex" : 2,
                             "Child" : {
                                 "c" : { "Label" "e" },
                                 "n" : { "Label" "e" },
                                 "w" : { "Label" "p" },
                                 "y" : { "Label" "p" }
                             }
                         },
                         "d" : {
                             "FeatureIndex" : 1,
                             "Child" : {
                                 "y" : { "Label" "p" },
                                 "f" : { "Label" "p" },
                                 "s" : { "Label" "e" }
                             }
                         },
                         "g" : { "Label" "e" },
                         "p" : { "Label" "e" }
                     }
                 },
                 "h" : { "Label" "e" },
                 "r" : { "Label" "p" },
                 "o" : { "Label" "e" },
                 "y" : { "Label" "e" },
                 "b" : { "Label" "e" }
             }
         },
         "f" : { "Label" "p" },
         "c" : { "Label" "p" },
         "y" : { "Label" "p" },
         "s" : { "Label" "p" },
         "m" : { "Label" "p" }
     }
}

可以看到,实际只使用了其中的5个特征,就能做出比较精确的判断了。


决策树还有一个很棒的优点就是能告诉我们多个特征中哪个对判别最有用,比如上面的树,根节点是特征4,参考agaricus-lepiota.names得知这个特征是指气味(odor),只要有气味,就可以直接得出结论,如果是无味的(n=none),下一个重要特征是19,即孢子印的颜色(spore-print-color)。