# Decision Tree

Project address: https://github.com/TheOneAC/ML.git

Dataset: `ML/ML_ation/tree`

### 决策树

  • 优点:计算复杂度不高,输出结果易于理解,对中间值的缺失不敏感,可以处理不相关的特征数据
  • 缺点:可能产生过度匹配(过拟合)问题
  • 适用数据类型:数值型和标称型

决策树构建伪代码 createBranch:

检测数据集中的所有子项是否都属于同一分类
if so: return 类标签 (class_tag)
else:
    寻找划分数据集的最佳特征
    划分数据集
    创建分支节点
    for 每个划分后的子集:
        递归调用 createBranch,并将返回结果加入分支节点
    return 分支节点

递归结束条件:所有可用于划分的属性都已遍历完,或者当前数据集中的所有实例都属于同一分类。

香农熵(Shannon entropy)用于度量数据集的无序程度:$H = -\sum_{i=1}^{n} p(x_i)\log_2 p(x_i)$,其中 $p(x_i)$ 是第 $i$ 个类别在数据集中所占的比例。
def calcShannonEnt(dataSet):
    """Return the Shannon entropy (in bits) of the class labels in dataSet.

    The class label of each row is its last element. An empty dataSet
    yields 0.0 because neither loop body runs.
    """
    total = len(dataSet)
    # Tally how many rows carry each class label.
    tally = {}
    for row in dataSet:
        label = row[-1]
        tally[label] = tally.get(label, 0) + 1
    # H = -sum(p * log2(p)) over the observed labels.
    entropy = 0.0
    for label in tally:
        share = float(tally[label]) / total
        entropy -= share * log(share, 2)
    return entropy

数据集划分与最优划分特征选择:按每个特征分别划分数据集,选择信息增益(划分前后熵的减少量)最大、即划分后熵最小的特征。
def splitDataSet(dataSet, axis, value):
    """Return the rows of dataSet whose column `axis` equals `value`.

    Each returned row is a new list with column `axis` removed; the rows of
    dataSet themselves are not modified.
    """
    return [row[:axis] + row[axis + 1:] for row in dataSet if row[axis] == value]
def chooseBestFeatureToSplit(dataSet):
    """Return the index of the feature whose split gives the largest information gain.

    Args:
        dataSet: non-empty list of rows; each row holds the feature values
            followed by the class label as its last element.

    Returns:
        The column index of the best feature, or -1 when no feature has a
        strictly positive gain (for example when all rows share the same
        feature values). Callers must handle -1 rather than index with it.
    """
    numFeatures = len(dataSet[0]) - 1
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        uniqueVals = set(example[i] for example in dataSet)
        # Entropy after splitting on feature i: size-weighted average over subsets.
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            # Raise the running maximum so later features must beat the best so far.
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature

所有特征都已用完、但剩余样本的类标签仍不唯一时:采用多数表决决定该叶子节点的分类
def majorityCnt(classList):
    """Return the most frequent class label in classList (majority vote).

    Used by createTree when no features are left but the remaining rows
    still have mixed labels. On equal counts the label counted first wins
    (on Python 3.7+, where dicts keep insertion order). Raises IndexError
    if classList is empty.
    """
    classCount = {}
    for vote in classList:
        classCount[vote] = classCount.get(vote, 0) + 1
    # items() works on both Python 2 and 3 (iteritems() is Python 2 only).
    # sorted() is stable, so labels with equal counts keep their counting order.
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

创建树(递归构建)
def createTree(dataSet, labels):
    """Recursively build an information-gain decision tree as nested dicts.

    Args:
        dataSet: non-empty list of rows; each row holds the feature values
            followed by the class label as its last element.
        labels: feature names, one per feature column of dataSet. The
            caller's list is left unchanged.

    Returns:
        A class label (leaf) when no further split is useful, otherwise
        {feature_label: {feature_value: subtree_or_class_label, ...}}.
    """
    classList = [example[-1] for example in dataSet]
    # Stop: every row has the same class.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Stop: only the class column is left, so fall back to a majority vote.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    # Stop: no feature reduces entropy (chooseBestFeatureToSplit returns -1).
    # Indexing with -1 would otherwise split on the class column itself.
    if bestFeat < 0:
        return majorityCnt(classList)
    bestFeatureLabel = labels[bestFeat]
    myTree = {bestFeatureLabel: {}}
    # Build the reduced label list without `del`, so the caller's list is not mutated.
    subLabels = labels[:bestFeat] + labels[bestFeat + 1:]
    for value in set(example[bestFeat] for example in dataSet):
        myTree[bestFeatureLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels[:])
    return myTree

测试:使用决策树执行分类
def classify(inputTree, featLabels, testVec):
    """Return the class label the decision tree assigns to testVec.

    Args:
        inputTree: nested dict {feature_label: {feature_value: subtree_or_label}},
            as produced by createTree or retrieveTree.
        featLabels: feature names, in the same column order as testVec.
        testVec: feature values of the sample to classify.

    Raises:
        ValueError: if a tested feature label is not in featLabels, or testVec
            holds a value for which the tree has no branch.
    """
    # next(iter(...)) works on Python 2 and 3; dict.keys()[0] fails on Python 3.
    firstStr = next(iter(inputTree))
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    featValue = testVec[featIndex]
    if featValue not in secondDict:
        # Previously this fell through to an UnboundLocalError on `classLabel`.
        raise ValueError("no branch for %s = %r in decision tree" % (firstStr, featValue))
    branch = secondDict[featValue]
    if isinstance(branch, dict):
        return classify(branch, featLabels, testVec)
    return branch
在 Python 交互环境中测试:

>>> import trees
>>> import treePlotter
>>> myDat,labels=trees.createDataSet()
>>> labels
['no surfacing', 'flippers']
>>> myTree=treePlotter.retrieveTree(0)
>>> myTree
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
>>> trees.classify(myTree,labels,[1,0])
'no'
>>> trees.classify(myTree,labels,[1,1])
'yes'

### 存储与重载(使用 pickle 序列化决策树)
def storeTree(inputTree, filename):
    """Serialize inputTree to the file `filename` with pickle, overwriting it."""
    import pickle
    # pickle produces bytes, so the file must be opened in binary mode;
    # text mode ('w') raises TypeError on Python 3. `with` closes the file
    # even if dump() fails.
    with open(filename, 'wb') as fw:
        pickle.dump(inputTree, fw)
def grabTree(filename):
    """Load and return a decision tree previously saved with storeTree.

    Security note: pickle.load can execute arbitrary code, so only load
    files from a trusted source.
    """
    import pickle
    # Binary mode is required for pickle on Python 3; `with` closes the handle.
    with open(filename, 'rb') as fr:
        return pickle.load(fr)

### 测试存储与重载
#!/usr/bin/python
# Round-trip check for tree persistence: build the sample tree, pickle it to
# disk with trees.storeTree, then load it back with trees.grabTree and print it.
import trees

sampleData, featureLabels = trees.createDataSet()
builtTree = trees.createTree(sampleData, featureLabels)
storagePath = 'classifierStorage.txt'
trees.storeTree(builtTree, storagePath)
print(trees.grabTree(storagePath))

使用 Matplotlib 注解图形化显示树结构
#!/usr/bin/python
import matplotlib.pyplot as plt
# Box styles passed as `bbox` to annotate(): a sawtooth box for decision
# (internal) nodes and a round4 box for leaves, both with fc (face colour)
# grey level "0.8". arrow_args styles the edge drawn to the parent node.
decisionNode = {"boxstyle": "sawtooth", "fc": "0.8"}
leafNode = {"boxstyle": "round4", "fc": "0.8"}
arrow_args = {"arrowstyle": "<-"}
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Draw a boxed node labelled nodeTxt at centerPt, with an arrow to parentPt.

    Both points are in axes-fraction coordinates. nodeType is the bbox style
    dict (decisionNode or leafNode). Drawing goes to the axes stored in
    createPlot.ax1, so createPlot must have set it first.
    """
    axes = createPlot.ax1
    axes.annotate(
        nodeTxt,
        xy=parentPt,
        xycoords="axes fraction",
        xytext=centerPt,
        textcoords="axes fraction",
        va="center",
        ha="center",
        bbox=nodeType,
        arrowprops=arrow_args,
    )

绘制示例节点
def createPlot():
    """Demo: draw one decision node and one leaf node in matplotlib figure 1.

    If this module also defines the later createPlot(inTree), that
    definition replaces this one at import time.
    """
    figure = plt.figure(1, facecolor="white")
    figure.clf()  # reuse figure 1, wiping anything drawn on it before
    # plotNode draws on whatever axes is stored in this function attribute.
    createPlot.ax1 = plt.subplot(111, frameon=False)
    demoNodes = [
        ("a decision node", (0.5, 0.1), (0.1, 0.5), decisionNode),
        ("a leaf node", (0.8, 0.1), (0.3, 0.8), leafNode),
    ]
    for text, center, parent, style in demoNodes:
        plotNode(text, center, parent, style)
    plt.show()

Run it from the Python interactive shell like this (at this stage treePlotter.py contains only the demo `createPlot()` above):

import treePlotter
treePlotter.createPlot()

  • The result looks like this:
    图片标题

    Next, build the full annotated tree: count the leaves and the depth of the tree, then plot it recursively.
    def getNumLeafs(myTree):
        """Return the number of leaf nodes in a nested-dict decision tree.

        myTree has the form {feature_label: {feature_value: subtree_or_leaf}}.
        Only its first key is read; trees built by createTree have one root key.
        Any non-dict value counts as one leaf.
        """
        numLeafs = 0
        # next(iter(...)) works on Python 2 and 3; keys()[0] fails on Python 3.
        firstStr = next(iter(myTree))
        secondDict = myTree[firstStr]
        for subtree in secondDict.values():
            if isinstance(subtree, dict):
                # Recursive call uses the correctly spelled name (was getNumleafs).
                numLeafs += getNumLeafs(subtree)
            else:
                numLeafs += 1
        return numLeafs
    def getTreeDepth(myTree):
        """Return the number of decision levels in a nested-dict decision tree.

        A tree whose root splits directly into leaves has depth 1; an empty
        branch dict gives 0. Only the first key of myTree is read.
        """
        maxDepth = 0
        # next(iter(...)) works on Python 2 and 3; keys()[0] fails on Python 3.
        firstStr = next(iter(myTree))
        secondDict = myTree[firstStr]
        for subtree in secondDict.values():
            if isinstance(subtree, dict):
                thisDepth = 1 + getTreeDepth(subtree)
            else:
                thisDepth = 1
            if thisDepth > maxDepth:
                maxDepth = thisDepth
        return maxDepth
    def retrieveTree(i):
        """Return hard-coded sample tree number i, for testing without building one.

        Index 0 is a two-level 'no surfacing'/'flippers' tree; index 1 adds a
        third 'head' level. Fresh dict objects are built on every call.
        """
        shallow = {'no surfacing': {0: 'no',
                                    1: {'flippers': {0: 'no', 1: 'yes'}}}}
        deeper = {'no surfacing': {0: 'no',
                                   1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}},
                                                    1: 'no'}}}}
        samples = [shallow, deeper]
        return samples[i]
    def plotMidText(cntrPt, parentPt, txtString):
        """Write txtString at the midpoint of the edge from cntrPt to parentPt.

        Used to label each edge with the feature value it represents. Draws
        on the axes stored in createPlot.ax1.
        """
        midpoint = [(parentPt[k] - cntrPt[k]) / 2.0 + cntrPt[k] for k in (0, 1)]
        createPlot.ax1.text(midpoint[0], midpoint[1], txtString)
    def plotTree(myTree, parentPt, nodeTxt):
        """Recursively draw the subtree myTree below the point parentPt.

        Layout state is shared through function attributes set by createPlot:
        plotTree.totalW (total leaf count), plotTree.totalD (tree depth),
        plotTree.xOff (x of the most recently placed leaf) and plotTree.yOff
        (y of the current level). All coordinates are axes fractions.

        Args:
            myTree: dict {feature_label: {feature_value: subtree_or_leaf}}.
            parentPt: (x, y) position of the parent node.
            nodeTxt: text written on the edge from the parent (the feature value).
        """
        numLeafs = getNumLeafs(myTree)
        # next(iter(...)) works on Python 2 and 3; keys()[0] fails on Python 3.
        firstStr = next(iter(myTree))
        # Centre this decision node over the numLeafs leaf slots it will span.
        cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
                  plotTree.yOff)
        plotMidText(cntrPt, parentPt, nodeTxt)
        plotNode(firstStr, cntrPt, parentPt, decisionNode)
        secondDict = myTree[firstStr]
        # Move down one level for the children...
        plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
        for key, child in secondDict.items():
            if isinstance(child, dict):
                plotTree(child, cntrPt, str(key))
            else:
                # Leaves are laid out left to right, one 1/totalW slot each.
                plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
                plotNode(child, (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
                plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
        # ...and restore the level before returning to the parent.
        plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
    def createPlot(inTree):
        """Render inTree as an annotated tree diagram in matplotlib figure 1."""
        fig = plt.figure(1, facecolor='white')
        fig.clf()
        # Hide tick marks: node positions are axes fractions, not data values.
        createPlot.ax1 = plt.subplot(111, frameon=False, xticks=[], yticks=[])
        # plotTree reads its layout state from these function attributes.
        plotTree.totalW = float(getNumLeafs(inTree))
        plotTree.totalD = float(getTreeDepth(inTree))
        # plotTree advances xOff by 1/totalW before placing each leaf, so starting
        # at -0.5/totalW centres the first leaf in its slot.
        plotTree.xOff = -0.5 / plotTree.totalW
        plotTree.yOff = 1.0
        plotTree(inTree, (0.5, 1.0), '')
        plt.show()

图片标题

扩展测试:使用决策树预测隐形眼镜类型 (lens.py)

Project address: `https://github.com/TheOneAC/ML.git`
Dataset: `lenses.txt` in `ML/ML_ation/tree`
#!/usr/bin/python
# Build a decision tree that predicts contact-lens type from lenses.txt
# (one tab-separated sample per line, class label in the last column),
# print the nested-dict tree, and plot it.
import trees
import treePlotter

# `with` closes the data file. Blank lines (e.g. a trailing newline at EOF)
# are skipped; otherwise they would become a bogus [''] sample.
with open("lenses.txt") as fr:
    lenses = [inst.strip().split('\t') for inst in fr if inst.strip()]
lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
lensesTree = trees.createTree(lenses, lensesLabels)
print(lensesTree)
treePlotter.createPlot(lensesTree)

图片标题