决策树(Decision Tree)的核心思想是:根据训练样本构建这样一棵树,使得其叶节点是分类标签,非叶节点是判断条件,这样对于一个未知样本,能在树上找到一条路径到达叶节点,就得到了它的分类。
举个简单的例子,如何识别有毒的蘑菇?如果能够得到一棵这样的决策树,那么对于一个未知的蘑菇就很容易判断出它是否有毒了。
1
2
3
4
5
6
7
8
9
10
|
它是什么颜色的?
|
-------鲜艳---------浅色----
| |
有毒 有什么气味?
|
-----刺激性--------无味-----
| |
有毒 安全
|
构建决策树有很多算法,常用的有ID3、C4.5等。本篇以ID3为研究算法。
构建决策树的关键在于每一次分支时选择哪个特征作为分界条件。这里的原则是:选择最能把数据变得有序的特征作为分界条件。所谓有序,是指划分后,每一个分支集合的分类尽可能一致。用信息论的方式表述,就是选择信息增益最大的方式划分集合。
所谓信息增益(information gain),是指变化前后熵(entropy)的增加量。为了计算熵,需要计算所有类别所有可能值包含的信息期望值,通过下面的公式得到:
其中H为熵,n为分类数目,p(xi)是选择该分类的概率。
根据公式,计算一个集合熵的方式为:
1
2
3
4
5
6
7
8
|
计算每个分类出现的次数
foreach
(每一个分类)
{
计算出现概率
根据概率计算熵
累加熵
}
return
累加结果
|
判断如何划分集合,方式为:
1
2
3
4
5
6
7
|
foreach
(每一个特征)
{
计算按此特征切分时的熵
计算与切分前相比的信息增益
保留能产生最大增益的特征为切分方式
}
return
选定的特征
|
构建树节点的方法为:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
|
if
(集合没有特征可用了)
{
按多数原则决定此节点的分类
}
else
if
(集合中所有样本的分类都一致)
{
此标签就是节点分类
}
else
{
以最佳方式切分集合
每一种可能形成当前节点的一个分支
递归
}
|
OK,上C#版代码,DataVector和上篇文章一样,不放了,只放核心算法:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
|
using
System;
using
System.Collections.Generic;
namespace
MachineLearning
{
/// <summary>
/// 决策树节点
/// </summary>
public
class
DecisionNode
{
/// <summary>
/// 此节点的分类标签,为空表示此节点不是叶节点
/// </summary>
public
string
Label {
get
;
set
; }
/// <summary>
/// 此节点的划分特征,为-1表示此节点是叶节点
/// </summary>
public
int
FeatureIndex {
get
;
set
; }
/// <summary>
/// 分支
/// </summary>
public
Dictionary<
string
, DecisionNode> Child {
get
;
set
; }
public
DecisionNode()
{
this
.FeatureIndex = -1;
this
.Child =
new
Dictionary<
string
, DecisionNode>();
}
}
}
|
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
|
using
System;
using
System.Collections.Generic;
using
System.Linq;
namespace
MachineLearning
{
/// <summary>
/// 决策树(ID3算法)
/// </summary>
public
class
DecisionTree
{
private
DecisionNode m_Tree;
/// <summary>
/// 训练
/// </summary>
/// <param name="trainingSet"></param>
public
void
Train(List<DataVector<
string
>> trainingSet)
{
var
features =
new
List<
int
>(trainingSet[0].Dimension);
for
(
int
i = 0;i < trainingSet[0].Dimension;++i)
features.Add(i);
//生成决策树
m_Tree = CreateTree(trainingSet, features);
}
/// <summary>
/// 分类
/// </summary>
/// <param name="vector"></param>
/// <returns></returns>
public
string
Classify(DataVector<
string
> vector)
{
return
Classify(vector, m_Tree);
}
/// <summary>
/// 分类
/// </summary>
/// <param name="vector"></param>
/// <param name="node"></param>
/// <returns></returns>
private
string
Classify(DataVector<
string
> vector, DecisionNode node)
{
var
label =
string
.Empty;
if
(!
string
.IsNullOrEmpty(node.Label))
{
//是叶节点,直接返回结果
label = node.Label;
}
else
{
//取需要分类的字段,继续深入
var
key = vector.Data[node.FeatureIndex];
if
(node.Child.ContainsKey(key))
label = Classify(vector, node.Child[key]);
else
label =
"[UNKNOWN]"
;
}
return
label;
}
/// <summary>
/// 创建决策树
/// </summary>
/// <param name="dataSet"></param>
/// <param name="features"></param>
/// <returns></returns>
private
DecisionNode CreateTree(List<DataVector<
string
>> dataSet, List<
int
> features)
{
var
node =
new
DecisionNode();
if
(dataSet[0].Dimension == 0)
{
//所有字段已用完,按多数原则决定Label,结束分类
node.Label = GetMajorLabel(dataSet);
}
else
if
(dataSet.Count == dataSet.Count(d =>
string
.Equals(d.Label, dataSet[0].Label)))
{
//如果数据集中的Label相同,结束分类
node.Label = dataSet[0].Label;
}
else
{
//挑选一个最佳分类,分割集合,递归
int
featureIndex = ChooseBestFeature(dataSet);
node.FeatureIndex = features[featureIndex];
var
uniqueValues = GetUniqueValues(dataSet, featureIndex);
features.RemoveAt(featureIndex);
foreach
(
var
value
in
uniqueValues)
{
node.Child[value.ToString()] = CreateTree(SplitDataSet(dataSet, featureIndex, value),
new
List<
int
>(features));
}
}
return
node;
}
/// <summary>
/// 计算给定集合的香农熵
/// </summary>
/// <param name="dataSet"></param>
/// <returns></returns>
private
double
ComputeShannon(List<DataVector<
string
>> dataSet)
{
double
shannon = 0.0;
var
dict =
new
Dictionary<
string
,
int
>();
foreach
(
var
item
in
dataSet)
{
if
(!dict.ContainsKey(item.Label))
dict[item.Label] = 0;
dict[item.Label] += 1;
}
foreach
(
var
label
in
dict.Keys)
{
double
prob = dict[label] * 1.0 / dataSet.Count;
shannon -= prob * Math.Log(prob, 2);
}
return
shannon;
}
/// <summary>
/// 用给定的方式切分出数据子集
/// </summary>
/// <param name="dataSet"></param>
/// <param name="splitIndex"></param>
/// <param name="value"></param>
/// <returns></returns>
private
List<DataVector<
string
>> SplitDataSet(List<DataVector<
string
>> dataSet,
int
splitIndex,
string
value)
{
var
newDataSet =
new
List<DataVector<
string
>>();
foreach
(
var
item
in
dataSet)
{
//只保留指定维度上符合给定值的项
if
(item.Data[splitIndex] == value)
{
var
newItem =
new
DataVector<
string
>(item.Dimension - 1);
newItem.Label = item.Label;
Array.Copy(item.Data, 0, newItem.Data, 0, splitIndex - 0);
Array.Copy(item.Data, splitIndex + 1, newItem.Data, splitIndex, item.Dimension - splitIndex - 1);
newDataSet.Add(newItem);
}
}
return
newDataSet;
}
/// <summary>
/// 在给定的数据集上选择一个最好的切分方式
/// </summary>
/// <param name="dataSet"></param>
/// <returns></returns>
private
int
ChooseBestFeature(List<DataVector<
string
>> dataSet)
{
int
bestFeature = 0;
double
bestInfoGain = 0.0;
double
baseShannon = ComputeShannon(dataSet);
//遍历每一个维度来寻找
for
(
int
i = 0;i < dataSet[0].Dimension;++i)
{
var
uniqueValues = GetUniqueValues(dataSet, i);
double
newShannon = 0.0;
//遍历此维度下的每一个可能值,切分数据集并计算熵
foreach
(
var
value
in
uniqueValues)
{
var
subSet = SplitDataSet(dataSet, i, value);
double
prob = subSet.Count * 1.0 / dataSet.Count;
newShannon += prob * ComputeShannon(subSet);
}
//计算信息增益,保留最佳切分方式
double
infoGain = baseShannon - newShannon;
if
(infoGain > bestInfoGain)
{
bestInfoGain = infoGain;
bestFeature = i;
}
}
return
bestFeature;
}
/// <summary>
/// 数据去重
/// </summary>
/// <param name="dataSet"></param>
/// <param name="index"></param>
/// <returns></returns>
private
List<
string
> GetUniqueValues(List<DataVector<
string
>> dataSet,
int
index)
{
var
dict =
new
Dictionary<
string
,
int
>();
foreach
(
var
item
in
dataSet)
{
dict[item.Data[index]] = 0;
}
return
dict.Keys.ToList<
string
>();
}
/// <summary>
/// 取多数标签
/// </summary>
/// <param name="dataSet"></param>
/// <returns></returns>
private
string
GetMajorLabel(List<DataVector<
string
>> dataSet)
{
var
dict =
new
Dictionary<
string
,
int
>();
foreach
(
var
item
in
dataSet)
{
if
(!dict.ContainsKey(item.Label))
dict[item.Label] = 0;
dict[item.Label]++;
}
string
label =
string
.Empty;
int
count = -1;
foreach
(
var
key
in
dict.Keys)
{
if
(dict[key] > count)
{
label = key;
count = dict[key];
}
}
return
label;
}
}
}
|
拿个例子实际检验一下,还是以毒蘑菇的识别为例,从这里找了点数据,http://archive.ics.uci.edu/ml/machine-learning-databases/mushroom/agaricus-lepiota.data ,它整理了8000多个样本,每个样本描述了蘑菇的22个属性,比如形状、气味等等,然后给出了这个蘑菇是否可食用。
比如一行数据:p,x,s,n,t,p,f,c,n,k,e,e,s,s,w,w,p,w,o,p,k,s,u
第0个元素p表示poisonous(有毒),其它22个元素分别是蘑菇的属性,可以参见agaricus-lepiota.names的描述,但实际上根本不用关心具体含义。以此构建样本并测试错误率:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
|
public
void
TestDecisionTree()
{
var
trainingSet =
new
List<DataVector<
string
>>();
//训练数据集
var
testSet =
new
List<DataVector<
string
>>();
//测试数据集
//读取数据
var
file =
new
StreamReader(
"agaricus-lepiota.data"
, Encoding.Default);
string
line =
string
.Empty;
int
count = 0;
while
((line = file.ReadLine()) !=
null
)
{
var
parts = line.Split(
','
);
var
p =
new
DataVector<
string
>(22);
p.Label = parts[0];
for
(
int
i = 0;i < p.Dimension;++i)
p.Data[i] = parts[i + 1];
//前7000作为训练样本,其余作为测试样本
if
(++count <= 7000)
trainingSet.Add(p);
else
testSet.Add(p);
}
file.Close();
//检验
var
dt =
new
DecisionTree();
dt.Train(trainingSet);
int
error = 0;
foreach
(
var
p
in
testSet)
{
//做猜测分类,并与实际结果比较
var
label = dt.Classify(p);
if
(label != p.Label)
++error;
}
Console.WriteLine(
"Error = {0}/{1}, {2}%"
, error, testSet.Count, (error * 100.0 / testSet.Count));
}
|
使用7000个样本做训练,1124个样本做测试,只有4个猜测出错,错误率仅有0.35%,相当不错的结果。
生成的决策树是这样的:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
|
{
"FeatureIndex"
: 4,
//按第4个特征划分
"Child"
: {
"p"
: {
"Label"
:
"p"
},
//如果第4个特征是p,则分类为p
"a"
: {
"Label"
:
"e"
},
//如果第4个特征是a,则分类是e
"l"
: {
"Label"
:
"e"
},
"n"
: {
"FeatureIndex"
: 19,
//如果第4个特征是n,要继续按第19个特征划分
"Child"
: {
"n"
: {
"Label"
:
"e"
},
"k"
: {
"Label"
:
"e"
},
"w"
: {
"FeatureIndex"
: 21,
"Child"
: {
"w"
: {
"Label"
:
"e"
},
"l"
: {
"FeatureIndex"
: 2,
"Child"
: {
"c"
: {
"Label"
:
"e"
},
"n"
: {
"Label"
:
"e"
},
"w"
: {
"Label"
:
"p"
},
"y"
: {
"Label"
:
"p"
}
}
},
"d"
: {
"FeatureIndex"
: 1,
"Child"
: {
"y"
: {
"Label"
:
"p"
},
"f"
: {
"Label"
:
"p"
},
"s"
: {
"Label"
:
"e"
}
}
},
"g"
: {
"Label"
:
"e"
},
"p"
: {
"Label"
:
"e"
}
}
},
"h"
: {
"Label"
:
"e"
},
"r"
: {
"Label"
:
"p"
},
"o"
: {
"Label"
:
"e"
},
"y"
: {
"Label"
:
"e"
},
"b"
: {
"Label"
:
"e"
}
}
},
"f"
: {
"Label"
:
"p"
},
"c"
: {
"Label"
:
"p"
},
"y"
: {
"Label"
:
"p"
},
"s"
: {
"Label"
:
"p"
},
"m"
: {
"Label"
:
"p"
}
}
}
|
可以看到,实际只使用了其中的5个特征,就能做出比较精确的判断了。
决策树还有一个很棒的优点就是能告诉我们多个特征中哪个对判别最有用,比如上面的树,根节点是特征4,参考agaricus-lepiota.names得知这个特征是指气味(odor),只要有气味,就可以直接得出结论,如果是无味的(n=none),下一个重要特征是19,即孢子印的颜色(spore-print-color)。
附件:http://down.51cto.com/data/2365013
本文转自 BoyTNT 51CTO博客,原文链接:http://blog.51cto.com/boytnt/1569763,如需转载请自行联系原作者