本篇决策树算法是依据ID3算法来的,所以在看之前建议先了解ID3算法:https://blog.csdn.net/qq_27396861/article/details/88226296

文章目录

图一:
机器学习实战-决策树算法-LMLPHP
图二:
机器学习实战-决策树算法-LMLPHP

第一步:构建决策树

实例:

# coding: utf-8

from math import log
import operator

def createDataSet():
	"""Return the toy fish data set and its feature names.

	Each row is [can-survive-without-surfacing, has-flippers, class-label].
	"""
	samples = [
		[1, 1, 'yes'],
		[1, 1, 'yes'],
		[1, 0, 'no'],
		[0, 1, 'no'],
		[0, 1, 'no'],
	]
	featureNames = ['no surfacing', 'flippers']
	return samples, featureNames

def calcShannonEnt(dataSet):
	"""Compute the Shannon entropy of a data set.

	Each entry is a list whose LAST element is the class label.
	Entropy is -sum(p * log2(p)) over the label distribution; an
	empty data set yields 0.0.
	"""
	numEntries = len(dataSet)  # total number of samples
	labelCounts = {}

	# Tally occurrences of each class label.
	for featVec in dataSet:
		currentLabel = featVec[-1]  # class label is the last column
		# dict.get avoids the redundant `in labelCounts.keys()` membership test
		labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) + 1

	shannonEnt = 0.0
	for count in labelCounts.values():
		# Probability of this class among all samples.
		prob = float(count) / numEntries
		# Accumulate -p*log2(p).
		shannonEnt -= prob * log(prob, 2)

	return shannonEnt

def splitDataSet(dataSet, axis, value):
	"""Return the entries of dataSet whose feature at index `axis`
	equals `value`, with that feature column removed from each entry.
	"""
	matched = []
	for row in dataSet:
		if row[axis] != value:
			continue
		# Concatenate the slices on either side of the split column.
		matched.append(row[:axis] + row[axis + 1:])
	return matched

def chooseBestFeatureToSplit(dataSet):
	"""Pick the index of the feature whose split yields the highest
	information gain; return -1 when no split improves on zero gain.
	"""
	featureCount = len(dataSet[0]) - 1  # last column is the class label
	baseEntropy = calcShannonEnt(dataSet)
	total = float(len(dataSet))

	bestGain = 0.0
	bestIndex = -1

	for idx in range(featureCount):
		# Distinct values this feature takes across the data set.
		values = set(row[idx] for row in dataSet)

		# Conditional entropy after splitting on feature `idx`.
		conditional = 0.0
		for val in values:
			subset = splitDataSet(dataSet, idx, val)
			weight = len(subset) / total
			conditional += weight * calcShannonEnt(subset)

		gain = baseEntropy - conditional  # information gain
		if gain > bestGain:
			bestGain = gain
			bestIndex = idx

	return bestIndex

def majorityCnt(classList):
	"""Return the most common class label in classList.

	Fixes the Python-2-only `dict.iteritems()` call, which raises
	AttributeError on Python 3; a keyed `max` over the tally dict
	replaces the full descending sort since only the top entry is used.
	"""
	classCount = {}
	for vote in classList:
		classCount[vote] = classCount.get(vote, 0) + 1

	# Highest count wins; ties resolve to the first-seen label, matching
	# the stable descending sort of the original implementation.
	return max(classCount, key=classCount.get)

def createTree(dataSet, labels):
	"""Recursively build an ID3 decision tree.

	dataSet: list of feature rows; the last column is the class label.
	labels:  feature names aligned with the feature columns.
	Returns a nested dict {featureName: {featureValue: subtree-or-label}},
	or a bare class label for a leaf.

	Fixes two defects in the original: the Python-2-only debug
	`print` statement is removed, and the caller's `labels` list is no
	longer mutated (the original `del labels[bestFeat]` destroyed it).
	"""
	classList = [example[-1] for example in dataSet]

	# All samples share one class: this branch is a pure leaf.
	if classList.count(classList[0]) == len(classList):
		return classList[0]

	# No features left to split on: fall back to the majority class.
	if len(dataSet[0]) == 1:
		return majorityCnt(classList)

	bestFeat = chooseBestFeatureToSplit(dataSet)
	bestFeatLabel = labels[bestFeat]
	myTree = {bestFeatLabel: {}}

	# Build the label list for subtrees WITHOUT mutating the argument.
	subLabels = labels[:bestFeat] + labels[bestFeat + 1:]

	# One subtree per distinct value of the chosen feature.
	uniqueVals = set(example[bestFeat] for example in dataSet)
	for value in uniqueVals:
		myTree[bestFeatLabel][value] = createTree(
			splitDataSet(dataSet, bestFeat, value), subLabels[:])

	return myTree

def main():
	"""Build the demo tree from the toy data set and print it."""
	dataSet, labels = createDataSet()
	myTree = createTree(dataSet, labels)
	# Single-argument print(...) is valid and identical on Python 2 and 3;
	# the original `print myTree` statement is a SyntaxError on Python 3.
	print(myTree)

if __name__ == "__main__":
	main()

结果与上述图二一致:

{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
03-07 13:44