#pragma once
#include "TreeNode.h"
#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <utility>
#include <vector>
// Windows.h after the standard headers so its min/max macros cannot interfere with them.
#include <Windows.h>
using namespace std;
// Number of distinct values each data column can take, indexed by column.
// Columns 1-6 are the discrete attributes and column 9 is the good/bad label.
// -1 marks columns that are not discrete: column 0 and the continuous
// columns 7 and 8.
int TreeNode::attribute_size[10] = {
	-1,             // column 0
	3, 3, 3, 3, 3,  // columns 1-5
	2,              // column 6
	-1, -1,         // columns 7-8 (continuous)
	2               // column 9 (label)
};
bool TreeNode::stop(double temp_data[18][10], int size)
{
	// Splitting stops once the first `size` rows all carry the same label in
	// column 9: either none of them is a good melon (== 1) or all of them are.
	// An empty set (size <= 0) therefore also stops.
	int goodCount = 0;
	for (int row = 0; row < size; ++row)
	{
		if (temp_data[row][9] == 1)
			++goodCount;
	}
	return goodCount == 0 || goodCount == size;
}
double TreeNode::log2x(double x)
{
	// Base-2 logarithm used by the entropy computations. std::log2 computes it
	// directly instead of the ratio log10(x) / log10(2), which rounds twice.
	// x == 0 yields -inf and x < 0 yields NaN; calcEntropy is written to avoid
	// passing a zero proportion.
	return std::log2(x);
}
double TreeNode::calcEntropy(int total_size, int true_size)
{
	// Binary (good / bad) entropy, in bits, of a set of `total_size` rows of which
	// `true_size` are good melons.
	// An empty set has entropy 0. Without this guard, 0/0 produces NaN, and the
	// NaN propagates into every gain that weights this set (even by a weight of 0).
	if (total_size <= 0)
		return 0;

	const double Tproportion = true_size * 1.0 / total_size;                 // share of good melons
	const double Fproportion = (total_size - true_size) * 1.0 / total_size;  // share of bad melons

	// By convention 0 * log2(0) = 0, so zero shares contribute nothing and are
	// skipped (log2x(0) would be -inf).
	double entropy = 0;
	if (Tproportion > 0)
		entropy -= Tproportion * log2x(Tproportion);
	if (Fproportion > 0)
		entropy -= Fproportion * log2x(Fproportion);
	return entropy;
}
double TreeNode::calc_gain(int kind)
{
	// Information gain (in bits) of splitting this node's rows
	// rem_wat[0 .. sizeOfWatermelons) on column `kind`, relative to entD.
	// Side effects:
	// - isContinuous and breakingpoint are reset here.
	// - For a continuous column (7 or 8), the chosen threshold is stored in
	//   Continuouspoint[kind - 7] and copied into breakingpoint.
	this->isContinuous = false;
	this->breakingpoint = 0;
	const int n = sizeOfWatermelons;

	if (kind >= 1 && kind <= 6)// discrete attribute
	{
		// Bucket rows by value (0, 1, anything else -> 2) and count the good
		// melons in each bucket.
		int bucketSize[3] = { 0, 0, 0 };
		int bucketGood[3] = { 0, 0, 0 };
		for (int i = 0; i < n; i++)
		{
			const double value = rem_wat[i][kind];
			const int b = (value == 0) ? 0 : ((value == 1) ? 1 : 2);
			bucketSize[b]++;
			if (rem_wat[i][9] == 1)// good melon
				bucketGood[b]++;
		}

		// gain = entD - sum over observed values v of |D_v| / |D| * Ent(D_v).
		// Each non-empty bucket is weighted on its own, so values 1 and 2 are
		// not lumped together when value 0 is absent.
		int observedValues = 0;
		double gain = entD;
		for (int b = 0; b < 3; b++)
		{
			if (bucketSize[b] == 0)
				continue;
			observedValues++;
			gain -= bucketSize[b] * calcEntropy(bucketSize[b], bucketGood[b]) / n;
		}
		// A single observed value does not split the rows at all, so the
		// attribute carries no information: its gain is 0, not entD.
		return (observedValues > 1) ? gain : 0;
	}

	if (kind != 7 && kind != 8)
		return 0;// not a continuous attribute column; Continuouspoint[kind - 7] would be out of range
	if (n < 2)
		return 0;// a threshold needs at least two rows

	// Continuous attribute: (value, is good melon) pairs sorted by value.
	std::vector<std::pair<double, bool> > rows;
	rows.reserve(n);
	int totalGood = 0;
	for (int i = 0; i < n; i++)
	{
		const bool good = (rem_wat[i][9] == 1);
		rows.push_back(std::make_pair(rem_wat[i][kind], good));
		if (good)
			totalGood++;
	}
	std::sort(rows.begin(), rows.end());

	// Evaluate a cut after every position where the value changes. The threshold
	// is the midpoint of the two neighbouring values. Cuts between equal values
	// are skipped: the `<= threshold` split used when building children could not
	// reproduce them. On ties the first (smallest) threshold is kept.
	int bestCut = -1;
	double bestGain = 0;
	int leftGood = 0;
	for (int i = 0; i + 1 < n; i++)
	{
		if (rows[i].second)
			leftGood++;
		if (!(rows[i].first < rows[i + 1].first))
			continue;
		const int leftSize = i + 1;
		const int rightSize = n - leftSize;
		const double g = entD
			- leftSize * calcEntropy(leftSize, leftGood) / n
			- rightSize * calcEntropy(rightSize, totalGood - leftGood) / n;
		if (bestCut < 0 || g > bestGain)
		{
			bestCut = i;
			bestGain = g;
		}
	}

	if (bestCut < 0)
	{
		// Every row has the same value, so no threshold separates them.
		this->Continuouspoint[kind - 7] = rows[0].first;
		this->breakingpoint = this->Continuouspoint[kind - 7];
		return 0;
	}
	this->Continuouspoint[kind - 7] = (rows[bestCut].first + rows[bestCut + 1].first) / 2;
	this->breakingpoint = this->Continuouspoint[kind - 7];
	return bestGain;
}
void TreeNode::generate_childNode(double data[18][10],int size)
{
if ((this->stop(data, size)) || (size == 0) || this->isLeafNode)
{
this->isLeafNode = true;
this->isGood = bool(data[0][9]);
// cout <<bool( isGood);
return;
}
else
{
//从A中选择最优划分属性a*
this->isLeafNode = false;
this->indexOfAttribute = 0;
double maxgain = 0;
for (int i = 0; i < attSize; i++)
{
if (this->calc_gain(rem_att[i]) > maxgain)
{
indexOfAttribute = rem_att[i];
maxgain = this->calc_gain(rem_att[i]);
}
}
// cout << indexOfAttribute;
this->deleteatt(indexOfAttribute);
// cout << endl;
//生成分支
if ((indexOfAttribute >= 1) && (indexOfAttribute <= 6))
{
isContinuous = 0;
double temp_data[18][10];
int sizeOfTemp_data = 0;
for (int i = 0; i < attribute_size[indexOfAttribute]; i++)
{
sizeOfTemp_data = 0;
for (int j = 0; j < sizeOfWatermelons; j++)
{
if (data[j][indexOfAttribute] == i)//复制西瓜,归类
{
for (int k = 0; k < 10; k++)
temp_data[sizeOfTemp_data][k] = data[j][k];
sizeOfTemp_data++;
}
}
TreeNode tempTree(temp_data, sizeOfTemp_data,rem_att,attSize);
if (sizeOfTemp_data == 0)
{
tempTree.isLeafNode = true;
tempTree.isGood = bool(data[0][10]);
return;
}
tempTree.generate_childNode(temp_data,sizeOfTemp_data);
childTree.push_back(tempTree);
}
return;
}
if (indexOfAttribute == 7 || indexOfAttribute == 8)
{
isContinuous = 1;
double temp_data1[18][10];
double temp_data2[18][10];
int sizeOfTemp_data = 0;
breakingpoint = Continuouspoint[indexOfAttribute - 7];
// cout <<"breakingpoint="<< this->breakingpoint<<endl;
for (int i = 0; i < sizeOfWatermelons; i++)
{
if (data[i][indexOfAttribute] <= breakingpoint)
{
for (int k = 0; k < 10; k++)
temp_data1[sizeOfTemp_data][k] = data[i][k];
sizeOfTemp_data++;
}
}
TreeNode tempTree1(temp_data1, sizeOfTemp_data, rem_att, attSize);
if (sizeOfTemp_data == 0)
{
tempTree1.isLeafNode = true;
tempTree1.isGood = bool(data[0][10]);
return;
}
tempTree1.generate_childNode(temp_data1, sizeOfTemp_data);
childTree.push_back(tempTree1);
for (int i = 0; i < sizeOfWatermelons; i++)
if (data[i][indexOfAttribute] > breakingpoint)
{
for (int k = 0; k < 10; k++)
temp_data2[sizeOfTemp_data][k] = false;
sizeOfTemp_data++;
}
TreeNode tempTree2(temp_data2, sizeOfTemp_data, rem_att, attSize);
if (sizeOfTemp_data == 0)
{
tempTree2.isLeafNode = true;
tempTree2.isGood = bool(data[0][10]);
return;
}
tempTree2.generate_childNode(temp_data2, sizeOfTemp_data);
childTree.push_back(tempTree2);
}
}
return;
}
bool TreeNode::judgeByTree(double data[10])
{
	// Predicts whether the melon described by `data` is good by walking the tree.

	// A leaf stores the predicted label directly.
	if (isLeafNode)
		return isGood;

	const int index = indexOfAttribute;// column this node splits on
	int i = 0;// which child to descend into

	if ((index >= 1) && (index <= 6))
	{
		// Discrete attribute: child i was built from the rows whose value equals i.
		// If no value matches, i ends at attribute_size[index].
		for (i = 0; i < attribute_size[index]; i++)
		{
			if (data[index] == i)
				break;
		}
	}

	if ((index == 7) || (index == 8))
	{
		// Continuous attribute: child 0 holds the rows <= breakingpoint and child 1
		// the rest. This uses the same comparison as when the children were built.
		i = (data[index] <= breakingpoint) ? 0 : 1;
	}

	// at() throws std::out_of_range if no such child exists, instead of
	// reading past the end of childTree.
	return childTree.at(i).judgeByTree(data);
}
void TreeNode::WriteGain()
{
	// Computes the information gain of attributes 1-8 on this node's data.
	// It writes them to d:\a\Gain.txt as eight raw binary doubles and echoes them
	// to stdout, space separated. If the file cannot be opened, the process exits
	// with status 1.
	// Note: each calc_gain call also updates breakingpoint / Continuouspoint.
	double gain[9] = {};// gain[0] is unused, so the index matches the attribute column
	for (int i = 1; i < 9; i++)
		gain[i] = calc_gain(i);

	// Opening for output without in/app truncates any existing file. The stream
	// closes itself when it goes out of scope.
	std::ofstream binfile("d:\\a\\Gain.txt", std::ios::binary | std::ios::out | std::ios::ate);
	if (!binfile)
	{
		std::cerr << "Gain.txt open or create error!" << std::endl;
		std::exit(1);
	}
	// gain[1..8] are contiguous, so one write emits the same bytes as eight separate writes.
	binfile.write(reinterpret_cast<const char*>(&gain[1]), 8 * sizeof(double));

	for (int i = 1; i < 9; i++)
		std::cout << gain[i] << " ";
	std::cout << std::endl;
}