决策树进行树叶分类实战
1. 导入数据
import pandas as pd import matplotlib.pyplot as plt from sklearn.preprocessing import LabelEncoder from sklearn.model_selection import train_test_split from sklearn.tree import DecisionTreeClassifier from sklearn.model_selection import GridSearchCV
# Load the leaf-classification training data from Kaggle's train.csv.
data = pd.read_csv('train.csv')
data.head()
5 rows × 194 columns
数据说明:
species类别,64个margin边缘特征,64个shape形状特征,64个texture质感特征
一共有99个树叶类别
# Inspect the dataset dimensions (expected: 990 rows x 194 columns).
data.shape
(990, 194)
# Count the distinct leaf species — these are the classification targets.
# `nunique()` is the idiomatic pandas equivalent of len(unique()).
data.species.nunique()
99
2. 特征工程
# Encode the string species names as integer class labels.
lb = LabelEncoder().fit(data.species)
labels = lb.transform(data.species)
# Drop 'species' (now captured in `labels`) and 'id' (no predictive value):
# neither column is useful as a model input feature.
data = data.drop(['species', 'id'], axis=1)
data.head()
5 rows × 192 columns
# Peek at the first five encoded class labels.
labels[:5]
array([ 3, 49, 65, 94, 84], dtype=int64)
# Hold out 20% of the data for testing. Stratify on the labels so every
# species keeps the same proportion in both splits — important with 99
# classes and only 990 samples (10 per class). A fixed random_state makes
# the split (and the scores below) reproducible across runs.
x_train, x_test, y_train, y_test = train_test_split(
    data, labels, test_size=0.2, stratify=labels, random_state=42)
3. 构建决策树模型
# Baseline: a fully-grown decision tree with default hyperparameters
# (no depth limit, so it can memorize the training set).
tree = DecisionTreeClassifier()
tree.fit(x_train, y_train)
DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None, max_features=None, max_leaf_nodes=None, min_impurity_decrease=0.0, min_impurity_split=None, min_samples_leaf=1, min_samples_split=2, min_weight_fraction_leaf=0.0, presort=False, random_state=None, splitter='best')
# Accuracy on the held-out test set.
tree.score(x_test, y_test)
0.6767676767676768
# Accuracy on the training set — 1.0 here means the unconstrained
# tree has memorized the training data (overfitting).
tree.score(x_train, y_train)
1.0
结果表明该模型在训练集准确率为100%,而在测试集准确率仅有67%,存在过拟合现象,模型需要进一步优化。
4. 模型优化
# Hyperparameter grid to rein in the overfitting observed above:
#   max_depth         - maximum depth of the tree
#   min_samples_split - minimum samples required to split an internal node
#   min_samples_leaf  - minimum samples required at a leaf node
param_grid = {'max_depth': [10, 15, 20, 25, 30],
              'min_samples_split': [2, 3, 4, 5, 6, 7, 8],
              'min_samples_leaf': [1, 2, 3, 4, 5, 6, 7]}
# Exhaustive grid search with 3-fold cross-validation. GridSearchCV clones
# the estimator, so passing the already-fitted `tree` is safe, and after
# fit() it is refit on the full training set with the best parameters.
model = GridSearchCV(tree, param_grid, cv=3)
model.fit(x_train, y_train)
print(model.best_estimator_)
DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=30, max_features=None, max_leaf_nodes=None, min_impurity_decrease=0.0, min_impurity_split=None, min_samples_leaf=4, min_samples_split=5, min_weight_fraction_leaf=0.0, presort=False, random_state=None, splitter='best')
# Training accuracy of the tuned model (below 1.0 — the leaf/split
# constraints stop the tree from memorizing the training set).
model.score(x_train, y_train)
0.9444444444444444
# Test accuracy of the tuned model.
model.score(x_test, y_test)
0.6868686868686869