拓展:


对sklearn.datasets中的鸢尾花(Iris)数据集,按训练集:测试集=7:3构建决策树模型并对模型进行评估。

#导入模块
import pandas as pd
from sklearn.metrics import classification_report
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_iris
#from sklearn.tree import export_graphviz
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei']    #指定默认字体
plt.rcParams['axes.unicode_minus'] = False      #用来正常显示负号


#加载数据
iris = load_iris()
irisdf = pd.DataFrame(iris.data,columns=iris.feature_names)
irisdf.head(5)
#划分数据集
from sklearn import model_selection
x_train,x_test,y_train,y_test = model_selection.train_test_split(iris.data,
                                                                 iris.target,
                                                                 test_size=0.3,
                                                                 random_state=1)

#训练模型
dct = DecisionTreeClassifier()
fm = dct.fit(x_train,y_train)
pred = dct.predict(x_test)

#输出精确度、召回率和F1分数等信息
print(classification_report(y_test,pred,target_names=iris.target_names))

#可视化决策树
from sklearn import tree
tree.plot_tree(fm,filled=True,
               feature_names=iris.feature_names,
               class_names=iris.target_names)
'''
filled=True:填充颜色;
feature_names:特征变量名称
class_names:类别名称
'''

#报告模型结果 函数
def reprt_model(model,feature_name,class_name):
    '''
    model:模型;feature_name:特征变量名称;class_name:类别名称
    '''

    model_preds = model.predict(x_test)
    print(classification_report(y_test,model_preds,
                                target_names=iris.target_names))
    print('\n')
    plt.figure(figsize=(12,8),dpi=150)
    tree.plot_tree(model,filled=True,
                   feature_names=feature_name,
                   class_names=class_name)

#输出 报告模型结果
reprt_model(dct,iris.feature_names,iris.target_names)

#列联表
cross_table = pd.crosstab(y_test, pred)
print(cross_table)

输出结果:

决策树 #数据挖掘 #Python-LMLPHP

  列联表:从列联表可以看出,在测试集的45个样本中错误分类的只有2个。1个将1类误分类到2类中,一个将2类误分类到1类中。

模型评估:

 决策树 #数据挖掘 #Python-LMLPHP

06-17 14:38