2019-12-03 15:31:13
数据集类型为:
vhigh,vhigh,2,2,small,low,unacc
vhigh,vhigh,2,2,small,med,unacc
vhigh,vhigh,2,2,small,high,unacc
vhigh,vhigh,2,2,med,low,unacc
vhigh,vhigh,2,2,med,med,unacc
vhigh,vhigh,2,2,med,high,unacc
vhigh,vhigh,2,2,big,low,unacc
vhigh,vhigh,2,2,big,med,unacc
vhigh,vhigh,2,2,big,high,unacc
vhigh,vhigh,2,4,small,low,unacc
vhigh,vhigh,2,4,small,med,unacc
vhigh,vhigh,2,4,small,high,unacc
vhigh,vhigh,2,4,med,low,unacc
具体的就不一一列出了,需要原数据集的可以评论
参考了https://www.cnblogs.com/wsine/p/5180310.html
剪枝前
from math import log
import operator

import pandas as pd
import numpy as np


def calcShannonEnt(dataSet):
    """Return the Shannon entropy of *dataSet*.

    Each record is a list whose LAST element is the class label; entropy is
    computed over the distribution of those labels, in bits (log base 2).
    Returns 0.0 for an empty or single-class data set.
    """
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]
        # dict.get avoids the explicit "key not in dict" pre-check.
        labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) + 1
    shannonEnt = 0.0
    for count in labelCounts.values():
        prob = count / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt


def splitDataSet(dataSet, axis, value):
    """Return the subset of *dataSet* whose feature at *axis* equals *value*.

    The matching feature column is removed from each returned record
    (it has been "used up" by the split). The input is not mutated.
    """
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            # Rebuild the record without the split column.
            reducedFeatVec = featVec[:axis] + featVec[axis + 1:]
            retDataSet.append(reducedFeatVec)
    return retDataSet


def chooseBestFeatureToSplit(dataSet):
    """Return the index of the feature with the highest information gain.

    Returns -1 when no split improves on the base entropy (e.g. all
    features are constant). The last column is the class label and is
    never considered a feature.
    """
    numFeatures = len(dataSet[0]) - 1
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        uniqueVals = {example[i] for example in dataSet}
        # Weighted entropy of the partition induced by feature i.
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / len(dataSet)
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature


def majorityCnt(classList):
    """Return the most frequent class label in *classList* (majority vote).

    Used when all features are exhausted but the labels are still mixed.
    Ties are broken by first-encountered insertion order, matching the
    original dict-based counting.
    """
    classCount = {}
    for vote in classList:
        classCount[vote] = classCount.get(vote, 0) + 1
    # BUGFIX: original used Python-2-only .iteritems() and the misspelled
    # keyword reversed=True; Python 3 spells these .items() / reverse=True.
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]


def createTree(dataSet, labels):
    """Recursively build an ID3 decision tree.

    Parameters:
        dataSet: list of records, class label in the last position.
        labels:  feature names aligned with the record columns.
                 NOTE: mutated in place (the chosen feature is deleted);
                 callers should pass a copy if they need the list later.

    Returns either a leaf (a class label) or a nested dict of the form
    {featureName: {featureValue: subtree_or_leaf, ...}}.
    """
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):
        # All records share one class: pure leaf, stop splitting.
        return classList[0]
    if len(dataSet[0]) == 1:
        # Only the label column remains: fall back to majority vote.
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    del labels[bestFeat]
    # One branch per distinct value of the chosen feature.
    for value in {example[bestFeat] for example in dataSet}:
        subLabels = labels[:]  # copy: each branch must not see siblings' deletions
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree


def classify(inputTree, featLabels, testVec):
    """Classify one test vector with a tree built by createTree().

    Parameters:
        inputTree:  nested dict produced by createTree().
        featLabels: the ORIGINAL (unmutated) feature-name list, used to map
                    tree node names back to column indices of testVec.
        testVec:    feature values, same column order as the training data.

    Returns the predicted class label, or None when testVec carries a
    feature value the tree never saw during training (the original code
    raised UnboundLocalError in that case).
    """
    firstStr = list(inputTree.keys())[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    classLabel = None  # BUGFIX: defined even if no branch matches
    for key in secondDict:
        if testVec[featIndex] == key:
            subtree = secondDict[key]
            if isinstance(subtree, dict):
                classLabel = classify(subtree, featLabels, testVec)
            else:
                classLabel = subtree
    return classLabel


def classifyAll(inputTree, featLabels, testDataSet):
    """Classify every vector in *testDataSet*; returns a list of labels."""
    return [classify(inputTree, featLabels, testVec)
            for testVec in testDataSet]


def storeTree(inputTree, filename):
    """Pickle *inputTree* to *filename* (binary)."""
    import pickle
    # BUGFIX: context manager guarantees the handle is closed on error.
    with open(filename, 'wb') as fw:
        pickle.dump(inputTree, fw)


def grabTree(filename):
    """Load a tree previously saved with storeTree().

    SECURITY NOTE: pickle.load executes arbitrary code from the file —
    only load files this program wrote itself.
    """
    import pickle
    with open(filename, 'rb') as fr:
        return pickle.load(fr)


def createDataSet():
    """Load the car-evaluation CSV and return (dataSet, labels).

    Replaces the literal '5more' with 6 so the column stays orderable,
    exactly as the original did. Assumes car.csv has the 6 feature
    columns below plus a trailing class column — confirm against the file.
    """
    data = pd.read_csv("car.csv")
    train_data1 = (data.replace('5more', 6)).values
    train_data = np.array(train_data1)  # np.ndarray
    dataSet = train_data.tolist()       # list of lists
    print(dataSet)

    labels = ['buying', 'maint', 'doors', 'persons', 'lug_boot', 'safety']
    return dataSet, labels


def main():
    """Train on car.csv, print the tree, and plot it."""
    # Imported here (not at module top) so the core ID3 functions remain
    # importable without the local plotting helper being installed.
    import treeplotter

    dataSet, labels = createDataSet()
    labels_tmp = labels[:]  # copy: createTree mutates its labels argument
    desicionTree = createTree(dataSet, labels_tmp)
    # storeTree(desicionTree, 'classifierStorage.txt')
    # desicionTree = grabTree('classifierStorage.txt')
    print('desicionTree:\n', desicionTree)
    treeplotter.createPlot(desicionTree)


if __name__ == '__main__':
    main()