前言

下面使用决策树算法对数据集进行分类,并给出结果。

代码

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import numpy as np
# 导入决策树分类器
from sklearn.tree import DecisionTreeClassifier
# 导入训练集划分
from sklearn.model_selection import train_test_split

'''
version:1.0
Method:DecisionTree
'''

# Path of the Excel workbook to load. A raw string keeps the Windows
# backslashes literal instead of relying on unrecognised escape sequences
# (which newer Python versions warn about).
readFileName = r"python\python Data analysis and mining\class\dataset\german.xls"
# Load the workbook into a DataFrame.
df = pd.read_excel(readFileName)
# x: every column except the last one (the feature columns).
# `.ix` no longer exists in pandas; `.iloc` is the positional indexer.
x = df.iloc[:, :-1]
# y: the last column (the class label).
y = df.iloc[:, -1]
# Column names of the features, used to label the importance plot.
names = x.columns

# Split into training and test subsets. random_state=0 makes the split
# reproducible; drop it to get a different split on every run.
x_train, x_test, y_train, y_test = train_test_split(x, y, random_state=0)

# Mean of training and test accuracy for each candidate depth, in the same
# order as `depth`.
list_average_accuracy = []
depth = range(1, 30)
for current_depth in depth:
    # Cap the tree at the candidate depth; random_state=0 keeps the fitted
    # tree reproducible (remove it to allow randomness).
    tree = DecisionTreeClassifier(max_depth=current_depth, random_state=0)
    tree.fit(x_train, y_train)
    accuracy_training = tree.score(x_train, y_train)
    accuracy_test = tree.score(x_test, y_test)
    average_accuracy = (accuracy_training + accuracy_test) / 2.0
    # Label the line with the depth actually used to fit this tree.
    print("depth %d average_accuracy:" % current_depth, average_accuracy)
    list_average_accuracy.append(average_accuracy)

# Highest averaged score; ties resolve to the shallowest depth because
# list.index returns the first match.
max_value = max(list_average_accuracy)
# Map the winning position back through `depth` rather than assuming the
# range starts at 1.
best_depth = depth[list_average_accuracy.index(max_value)]
print("best_depth:", best_depth)


# Refit a single tree at the depth selected by the search above and report
# its accuracy on both subsets.
best_tree = DecisionTreeClassifier(max_depth=best_depth, random_state=0)
best_tree.fit(x_train, y_train)
# Score once and reuse the values for the report below.
accuracy_training = best_tree.score(x_train, y_train)
accuracy_test = best_tree.score(x_test, y_test)
print("decision tree:")
print("accuracy on the training subset:{:.3f}".format(accuracy_training))
print("accuracy on the test subset:{:.3f}".format(accuracy_test))


# Horizontal bar chart of the fitted tree's feature importances, one bar per
# feature column.
n_features = x.shape[1]  # number of feature columns
# Shared positions so the bars and their tick labels are guaranteed to align.
feature_positions = np.arange(n_features)
plt.barh(feature_positions, best_tree.feature_importances_, align='center')
plt.yticks(feature_positions, names)
plt.title("Decision Tree:")
plt.xlabel('Feature Importance')
plt.ylabel('Feature')
plt.show()

12-25 23:41