2019-12-03 15:31:13

数据集类型为:

vhigh,vhigh,2,2,small,low,unacc
vhigh,vhigh,2,2,small,med,unacc
vhigh,vhigh,2,2,small,high,unacc
vhigh,vhigh,2,2,med,low,unacc
vhigh,vhigh,2,2,med,med,unacc
vhigh,vhigh,2,2,med,high,unacc
vhigh,vhigh,2,2,big,low,unacc
vhigh,vhigh,2,2,big,med,unacc
vhigh,vhigh,2,2,big,high,unacc
vhigh,vhigh,2,4,small,low,unacc
vhigh,vhigh,2,4,small,med,unacc
vhigh,vhigh,2,4,small,high,unacc
vhigh,vhigh,2,4,med,low,unacc

具体的就不一一列出了,需要原数据集的可以评论

参考了https://www.cnblogs.com/wsine/p/5180310.html

剪枝前

  5 from math import log
  6 import operator
  7 import treeplotter
  8 import pandas as pd
  9 import numpy as np
 10
 11 def calcShannonEnt(dataSet):
 12     """
 13     输入:数据集
 14     输出:数据集的香农熵
 15     描述:计算给定数据集的香农熵
 16     """
 17     numEntries = len(dataSet)
 18     labelCounts = {}
 19     for featVec in dataSet:
 20         currentLabel = featVec[-1]
 21         if currentLabel not in labelCounts.keys():
 22             labelCounts[currentLabel] = 0
 23         labelCounts[currentLabel] += 1
 24     shannonEnt = 0.0
 25     for key in labelCounts:
 26         prob = float(labelCounts[key])/numEntries
 27         shannonEnt -= prob * log(prob, 2)
 28     return shannonEnt
 29
 30 def splitDataSet(dataSet, axis, value):
 31     """
 32     输入:数据集,选择维度,选择值
 33     输出:划分数据集
 34     描述:按照给定特征划分数据集;去除选择维度中等于选择值的项
 35     """
 36     retDataSet = []
 37     for featVec in dataSet:
 38         if featVec[axis] == value:
 39             reduceFeatVec = featVec[:axis]
 40             reduceFeatVec.extend(featVec[axis+1:])
 41             retDataSet.append(reduceFeatVec)
 42     return retDataSet
 43
 44 def chooseBestFeatureToSplit(dataSet):
 45     """
 46     输入:数据集
 47     输出:最好的划分维度
 48     描述:选择最好的数据集划分维度
 49     """
 50     numFeatures = len(dataSet[0]) - 1
 51     baseEntropy = calcShannonEnt(dataSet)
 52     bestInfoGain = 0.0
 53     bestFeature = -1
 54     for i in range(numFeatures):
 55         featList = [example[i] for example in dataSet]
 56         uniqueVals = set(featList)
 57         newEntropy = 0.0
 58         for value in uniqueVals:
 59             subDataSet = splitDataSet(dataSet, i, value)
 60             prob = len(subDataSet)/float(len(dataSet))
 61             newEntropy += prob * calcShannonEnt(subDataSet)
 62         infoGain = baseEntropy - newEntropy
 63         if (infoGain > bestInfoGain):
 64             bestInfoGain = infoGain
 65             bestFeature = i
 66     return bestFeature
 67
 68 def majorityCnt(classList):
 69     """
 70     输入:分类类别列表
 71     输出:子节点的分类
 72     描述:数据集已经处理了所有属性,但是类标签依然不是唯一的,
 73           采用多数判决的方法决定该子节点的分类
 74     """
 75     classCount = {}
 76     for vote in classList:
 77         if vote not in classCount.keys():
 78             classCount[vote] = 0
 79         classCount[vote] += 1
 80     sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reversed=True)
 81     return sortedClassCount[0][0]
 82
 83 def createTree(dataSet, labels):
 84     """
 85     输入:数据集,特征标签
 86     输出:决策树
 87     描述:递归构建决策树,利用上述的函数
 88     """
 89     classList = [example[-1] for example in dataSet]
 90     if classList.count(classList[0]) == len(classList):
 91         # 类别完全相同,停止划分
 92         return classList[0]
 93     if len(dataSet[0]) == 1:
 94         # 遍历完所有特征时返回出现次数最多的
 95         return majorityCnt(classList)
 96     bestFeat = chooseBestFeatureToSplit(dataSet)
 97     bestFeatLabel = labels[bestFeat]
 98     myTree = {bestFeatLabel:{}}
 99     del(labels[bestFeat])
100     # 得到列表包括节点所有的属性值
101     featValues = [example[bestFeat] for example in dataSet]
102     uniqueVals = set(featValues)
103     for value in uniqueVals:
104         subLabels = labels[:]
105         myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
106     return myTree
107
108 def classify(inputTree, featLabels, testVec):
109     """
110     输入:决策树,分类标签,测试数据
111     输出:决策结果
112     描述:跑决策树
113     """
114     firstStr = list(inputTree.keys())[0]
115     secondDict = inputTree[firstStr]
116     featIndex = featLabels.index(firstStr)
117     for key in secondDict.keys():
118         if testVec[featIndex] == key:
119             if type(secondDict[key]).__name__ == 'dict':
120                 classLabel = classify(secondDict[key], featLabels, testVec)
121             else:
122                 classLabel = secondDict[key]
123     return classLabel
124
125 def classifyAll(inputTree, featLabels, testDataSet):
126     """
127     输入:决策树,分类标签,测试数据集
128     输出:决策结果
129     描述:跑决策树
130     """
131     classLabelAll = []
132     for testVec in testDataSet:
133         classLabelAll.append(classify(inputTree, featLabels, testVec))
134     return classLabelAll
135
136 def storeTree(inputTree, filename):
137     """
138     输入:决策树,保存文件路径
139     输出:
140     描述:保存决策树到文件
141     """
142     import pickle
143     fw = open(filename, 'wb')
144     pickle.dump(inputTree, fw)
145     fw.close()
146
147 def grabTree(filename):
148     """
149     输入:文件路径名
150     输出:决策树
151     描述:从文件读取决策树
152     """
153     import pickle
154     fr = open(filename, 'rb')
155     return pickle.load(fr)
156
157 def createDataSet():
158     data = pd.read_csv("car.csv")
159     train_data1=(data.replace('5more',6)).values
160     train_data = np.array(train_data1)  # np.ndarray()
161     dataSet = train_data.tolist()  # list
162     print(dataSet)
163
164     labels = ['buying', 'maint', 'doors', 'persons', 'lug_boot', 'safety']
165     return dataSet, labels
166
167
168 def main():
169     dataSet, labels = createDataSet()
170     labels_tmp = labels[:] # 拷贝,createTree会改变labels
171     desicionTree = createTree(dataSet, labels_tmp)
172     #storeTree(desicionTree, 'classifierStorage.txt')
173     #desicionTree = grabTree('classifierStorage.txt')
174     print('desicionTree:\n', desicionTree)
175     treeplotter.createPlot(desicionTree)
176
177
178 if __name__ == '__main__':
179     main()
12-25 08:18