假设我们有一份CSV文件(以部分为例):car_rf.csv
要用随机森林对其进行分类,其中最后一列视为标签,其余列视为特征
# coding: utf-8
#
# Train a small random forest on a CSV file and export every fitted tree
# as a PDF.  The empty-string column names below are placeholders: fill in
# the real feature/label column names before running the script.

import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from IPython.display import Image, display
from sklearn import tree
import pydotplus


def read_dataset(fname=u"/car_rf.csv"):
    """Load the CSV at *fname* and integer-encode the configured columns.

    The file is read with every cell as a string and the first column used
    as the index.  Missing cells are replaced with ``0``.  Each column named
    in ``temp_col_list`` is then replaced by integer codes, where a value's
    code is the position of its first appearance in that column.

    Args:
        fname: Path of the CSV file to read.

    Returns:
        The loaded ``pandas.DataFrame`` with the listed columns encoded.

    Raises:
        KeyError: If a name in ``temp_col_list`` is not a column of the file
            (this includes the unfilled ``""`` placeholders).
    """
    data = pd.read_csv(fname, index_col=0, encoding="utf-8", dtype=str)
    data = data.fillna(0)
    temp_col_list = ["", ""]  # fill in the feature column names inside the quotes
    for col in temp_col_list:
        # Build the value -> code table once per column instead of doing a
        # linear list.index() scan for every row.
        labels = data[col].unique().tolist()
        codes = {value: code for code, value in enumerate(labels)}
        data[col] = data[col].map(codes)
    return data


train = read_dataset()
# Fill in the label column name inside the quotes.
y = train[""].values
X = train.drop([""], axis=1).values
rf = RandomForestClassifier(n_estimators=4, max_depth=2)
rf = rf.fit(X, y)

Estimators = rf.estimators_
for index, model in enumerate(Estimators):
    filename = str(index) + '.pdf'
    dot_data = tree.export_graphviz(model, out_file=None)
    graph = pydotplus.graph_from_dot_data(dot_data)
    # A bare Image(...) expression inside a loop is discarded; display() is
    # needed for the rendered tree to actually show up in a notebook.
    display(Image(graph.create_png()))
    graph.write_pdf(filename)