熵,一个神奇的工具,用来衡量数据集信息量的不确定性。

  首先,我们先来了解一个指标,信息量。对于任意一个随机变量X,样本空间为{X1,X2,...,Xn},样本空间可以这么理解,也就是随机变量X所有的可能取值。如果在ML领域内,我们可以把Xi当做X所属的某一个类。对于任意的样本Xi(类Xi),样本Xi的信息量也就是l(Xi) = -log(p(Xi))。由于p(Xi)是为样本Xi的概率,也可以说是类Xi的概率,那么l(Xi)的取值范围为(+∞,0]。也就是X的的概率越小,其含有的信息量越大,反之亦然。这也不难理解,Xi的发生的概率越大,我们对他的了解越深,那么他所含有的信息量就越小。如果随机变量X是常量,那么我们从任意一个Xi都可以获取样本空间的值,那么我们可以认为X没有任何信息量,他的信息量为0。如果说,我们要把随机变量X样本空间都了解完,才能获得X的信息,那么我们可以认为X的信息量“无穷大”,其取值为log(2,n)。

  紧接着,我们就提出了随机变量X的信息熵,也就是信息量的期望,H(X) = -∑ni=1p(Xi)log2(p(Xi))=∑xp(x)log(p(x)),从离散的角度得出的公式。他有三个特性:

  • 单调性,即发生概率越高的事件,其所携带的信息熵越低。极端案例就是“太阳从东方升起”,因为为确定事件,所以不携带任何信息量。从信息论的角度,认为这句话没有消除任何不确定性。也就是样本空间的p(Xi)均为1。
  • 非负性,即信息熵不能为负。这个很好理解,因为负的信息,即你得知了某个信息后,却增加了不确定性是不合逻辑的。
  • 累加性,即多随机事件同时发生存在的总不确定性的量度是可以表示为各事件不确定性的量度的和。

  有了熵这个基础,那么我们就要考虑决策树是怎么生成的了。对于随机变量X的样本个数为n的空间,每个样本都有若干个相同的特征,假设有k个。对于任意一个特征,我们可以对其进行划分,假设含有性别变量,那么切分后,性别特征消失,变为确定的值,那么随机变量X信息的不确定性减少。以此类推,直到达到我们想要的结果结束,这样就生成了若干棵树,每棵树的不确定性降低。如果我们在此过程中进行限制,每次减少的不确定性最大,那么这样一步一步下来,我们得到的树,会最快的把不确定性降低到最小。每颗树的分支,都可以确定一个类别,包含的信息量极少,确定性极大,这种类别,是可以进行预测的,不确定性小,稳定,可以用于预测。

        有了以上的知识储备,那么我们要想生成一颗决策树,只需要每次把特征的信息量最大的那个找出来进行划分即可。也就是不确定性最大的那个分支,我们要优先划分。这就会得出另外一个定义,条件信息熵H(Y|X)。

 

根据以上的推导,我们得出信息增益,H(Y)-H(Y|X)。可以看做是特征X的信息量,根据这个的最大值,依次得到每个特征,就是我们需要的决策树。利用Python完成代如下,打包到一个类下面:

from math import log
import operator

# 计算香农熵
class Tree:
    def __init__(self):
        super()
    def calcShannonEnt(self, dataSet):
        num = len(dataSet)
        labelCounts = {}
        for fVec in dataSet:
            currentLabel = fVec[-1]
            if currentLabel not in labelCounts.keys():
                labelCounts[currentLabel] = 0
            labelCounts[currentLabel] += 1
        shannonEnt = 0.0
        for key in labelCounts:
            prob = float(labelCounts[key]) / num
            shannonEnt -= prob * log(prob, 2)
        return shannonEnt

    #按照特征划分数据集,特征的位置为index
    def splitDataSet(self, dataSet, index, value):
        retDataSet = []
        for featVec in dataSet:
            if featVec[index] == value:
                reducedFeatVec = featVec[:index]
                reducedFeatVec.extend(featVec[index+1:])
                retDataSet.append(reducedFeatVec)
        return retDataSet

    #寻找信息增益最大的特征
    def chooseBestFeatureToSplit(self, dataSet):
        numFeatures = len(dataSet[0]) - 1
        baseEntropy = self.calcShannonEnt(dataSet)
        bestInfoGain, bestFeature = 0.0, -1
        for i in range(numFeatures):
            featList = [example[i] for example in dataSet]
            uniqueVals = set(featList)
            newEntropy = 0.0
            for value in uniqueVals:
                subDataSet = self.splitDataSet(dataSet, i, value)
                prob = len(subDataSet) / float(len(dataSet))
                newEntropy +=prob * self.calcShannonEnt(subDataSet)
            infoGain = baseEntropy - newEntropy
            if (infoGain >= bestInfoGain):#这里注意,取等号,只有1个特征为时,可能无信息增加。
                bestInfoGain = infoGain
                bestFeature = i
        return bestFeature

    # 如果分类不唯一,采用多数表决方法,决定叶子的分类
    def majorityCnt(self, classList):
        classCount = {}
        for vote in classList:
            if vote not in classCount.keys():
                classCount[vote] = 0
            classCount[vote] += 1
        SortedClassCount = sorted(classCount.items(), key = operator.itemgetter(1), reverse = True)
        return SortedClassCount[0][0]

    # 创建决策树代码
    def createTree(self, dataSet, labels):
        classList = [example[-1] for example in dataSet]
        if classList.count(classList[0] == len(classList)):#类别完全相同,无需划分,一类
            return classList[0]
        if len(dataSet[0]) == 1: #处理了所有特征,依旧没有完全划分,返回多数表决结果
            return self.majorityCnt(classList)
        bestFeat = self.chooseBestFeatureToSplit(dataSet)
        bestFeatLabel = labels[bestFeat]
        myTree = {bestFeatLabel:{}}
        del labels[bestFeat]
        featValues = [example[bestFeat] for example in dataSet]
        uniqueVals = set(featValues)
        for value in uniqueVals:#利用递归构建决策树
            subLabels = labels[:]
            myTree[bestFeatLabel][value] = self.createTree(self.splitDataSet(dataSet, bestFeat, value), subLabels)
        return myTree

    def createDataSet(self):
        dataSet = [
            [1,1,"yes"],
            [1,0,"no"],
            [0,1,"no"],
            [0,1,"no"]
        ]
        labels =["no surfacing", "flippers"]
        return dataSet, labels
    def decisiontreeclassify(self, inputTree, featLabels, testVec):
        firstStr = list(inputTree.keys())[0]
        secondDict = inputTree[firstStr]
        featIndex = featLabels.index(firstStr)
        for key in secondDict.keys():
            if testVec[featIndex] == key:
                if type(secondDict[key]).__name__ == 'dict':
                    classLabel = self.decisiontreeclassify(secondDict[key],featLabels,testVec)
                else:
                    classLabel = secondDict[key]
        return classLabel
if __name__ == "__main__":
    tree = Tree()
    myDat, myLabels =tree.createDataSet()
    inputTree = tree.createTree(myDat, myLabels)
    featLabels = ['no surfacing','flippers']
    print(inputTree)
    print(tree.decisiontreeclassify( inputTree, featLabels, [1,0]))
    print(tree.decisiontreeclassify( inputTree, featLabels, [1,1]))

下面的代码,是画出决策树,便于查看,没有封装。

import matplotlib.pyplot as plt
# boxstyle是文本框类型 fc是边框粗细 sawtooth是锯齿形
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")
# annotate 注释的意思
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)
def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            numLeafs +=getNumLeafs(secondDict[key])
        else:
            numLeafs +=1
    return numLeafs

def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            thisDepth = getTreeDepth(secondDict[key]) +1
        else:
            thisDepth =1
        if thisDepth >maxDepth:
            maxDepth = thisDepth
    return maxDepth

def retrieveTree(i):#创建树
    listOfTrees =[{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                  {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
                  ]
    return listOfTrees[i]
def plotMidText(cntrPt, parentPt, txtString):#填充文本
    xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
    yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)

def plotTree(myTree, parentPt, nodeTxt):#i建数据集和标签
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD     #按比例减少全局变量plotTree.yOff
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            plotTree(secondDict[key],cntrPt,str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW

            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode) #绘制此节点带箭头的注解
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))  #绘制此节点带箭头的注解
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD    #按比例增加全局变量plotTree.yOff

def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)    #no ticks
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0;
    plotTree(inTree, (0.5,1.0), '')
    plt.show()
if __name__=="__main__":
    myTree = retrieveTree(0)
    list(myTree.keys())[0]
    createPlot(myTree)

 

 

内容来源于网络如有侵权请私信删除

文章来源: 博客园

原文链接: https://www.cnblogs.com/zhuangxp2008/p/13875738.html

你还没有登录,请先登录注册
  • 还没有人评论,欢迎说说您的想法!