The decision tree in this post is built on the ID3 algorithm, so before reading it is worth reviewing ID3 first: https://blog.csdn.net/qq_27396861/article/details/88226296
Figure 1: (image)
Figure 2: (image: the resulting decision tree)
Step 1: Build the decision tree
Example:
# coding: utf-8
from math import log
import operator

def createDataSet():
    ''' Build the small sample set: two binary features plus a class label '''
    dataSet = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels
def calcShannonEnt(dataSet):
    ''' Compute the Shannon entropy of the whole dataset '''
    numEntries = len(dataSet)  # number of samples in the set
    labelCounts = {}
    # build a count dictionary over every class that occurs
    for featVec in dataSet:
        currentLabel = featVec[-1]  # the last element is the class label
        if currentLabel not in labelCounts:
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        # divide the count of the current class by the total number of samples
        prob = float(labelCounts[key]) / numEntries
        # take the base-2 logarithm and accumulate
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt
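
# Sanity check (worked by hand, not from the original post): the class column
# of the sample set is ['yes', 'yes', 'no', 'no', 'no'], so
#   H = -(2/5)*log2(2/5) - (3/5)*log2(3/5) ≈ 0.9710
# and calcShannonEnt(createDataSet()[0]) should return roughly 0.9710.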
def splitDataSet(dataSet, axis, value):
    ''' Split the dataset on a given feature.
        dataSet: the dataset to split
        axis: index of the feature to split on
        value: keep only the samples whose feature at axis equals this value
        The returned vectors drop column axis: if axis is 0 the last two
        columns remain; if axis is 1, the first and third remain. '''
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis + 1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet
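
# Example traced by hand on the sample set:
#   splitDataSet(dataSet, 0, 1) -> [[1, 'yes'], [1, 'yes'], [0, 'no']]
# i.e. the three samples with feature 0 == 1, with that column removed.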
def chooseBestFeatureToSplit(dataSet):
    ''' Pick the feature whose split yields the highest information gain '''
    numFeatures = len(dataSet[0]) - 1  # features per sample (the last column is the label)
    baseEntropy = calcShannonEnt(dataSet)  # Shannon entropy of the whole set
    bestInfoGain = 0.0  # best information gain so far
    bestFeature = -1    # index of the best feature so far
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]
        uniqueVals = set(featList)
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))  # matching samples / total samples
            # conditional entropy: weighted sum of the subsets' entropies
            newEntropy += prob * calcShannonEnt(subDataSet)
        # information gain
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    # for the sample set, 0 means 'no surfacing' and 1 means 'flippers'
    return bestFeature
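
# Worked numbers for the sample set (computed by hand):
#   feature 0 ('no surfacing'):
#     gain = 0.9710 - (3/5)*H(['yes','yes','no']) - (2/5)*H(['no','no'])
#          = 0.9710 - 0.6 * 0.9183 - 0 ≈ 0.4200
#   feature 1 ('flippers'):
#     gain = 0.9710 - (4/5)*H(['yes','yes','no','no']) - (1/5)*H(['no'])
#          = 0.9710 - 0.8 * 1.0 - 0 ≈ 0.1710
# so chooseBestFeatureToSplit returns 0.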
def majorityCnt(classList):
    ''' Return the class that occurs most often '''
    classCount = {}  # class -> number of occurrences
    for vote in classList:
        if vote not in classCount:
            classCount[vote] = 0
        classCount[vote] += 1
    # sort by the count (the second item of each pair), descending
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]
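
# Example traced by hand: majorityCnt(['yes', 'no', 'no']) -> 'no'.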
def createTree(dataSet, labels):
    ''' Build the decision tree recursively '''
    classList = [example[-1] for example in dataSet]  # all class labels
    print("classList =", classList)
    # case 1: every sample in the set belongs to the same class
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # case 2: no features left to split on (only the label column remains)
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)  # fall back to the majority class
    bestFeat = chooseBestFeatureToSplit(dataSet)  # index of the best feature
    bestFeatLabel = labels[bestFeat]  # name of the best feature
    myTree = {bestFeatLabel: {}}
    # build the subtrees recursively
    del labels[bestFeat]
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)  # every value the best feature takes
    for value in uniqueVals:
        subLabels = labels[:]  # the labels that remain after removing the chosen one
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree
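
# Trace for the sample set (follows the worked gains above): the first call
# splits on feature 0 ('no surfacing'); the value-0 branch is all 'no' and
# becomes a leaf, while the value-1 branch recurses and then splits on
# 'flippers'.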
def main():
    dataSet, labels = createDataSet()
    myTree = createTree(dataSet, labels)
    print(myTree)

if __name__ == "__main__":
    main()
The result matches Figure 2 above:
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
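Once the tree is built, classifying a new sample is just a walk down the nested dicts until a leaf (a string label) is reached. The helper below is a minimal sketch of that walk, not part of the original post; the name classify and the featLabels lookup are my own additions.

def classify(inputTree, featLabels, testVec):
    ''' Hypothetical helper: walk the nested-dict tree for one sample '''
    firstStr = list(inputTree.keys())[0]          # feature name at this node
    secondDict = inputTree[firstStr]              # its branches
    featIndex = featLabels.index(firstStr)        # where that feature sits in testVec
    valueOfFeat = secondDict[testVec[featIndex]]  # follow the matching branch
    if isinstance(valueOfFeat, dict):             # internal node: keep walking
        return classify(valueOfFeat, featLabels, testVec)
    return valueOfFeat                            # leaf: the class label

# e.g. classify(myTree, ['no surfacing', 'flippers'], [1, 1]) -> 'yes'
# (pass a fresh label list here, since createTree mutates the one it is given)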