# 本文为第一次数据挖掘实战:数据预处理部分较乱,源码略去,直接给出处理后的数据;所用算法为 ID3。
# -*- coding: utf-8 -*-
"""
Created on Sun Mar 4 17:08:06 2018
@author: CSH
"""
import math
import operator
from numpy import*
import pandas as pd
import csv
def calcShannonEnt(dataSet):
    """Return the Shannon entropy (base 2) of the class labels in *dataSet*.

    Args:
        dataSet: sequence of rows; the last element of each row is read as
            the class label. Labels must be hashable.

    Returns:
        The entropy as a float. An empty data set yields 0.0 because there
        are no labels to sum over.
    """
    numEntries = len(dataSet)
    # Count how many rows carry each class label.
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]
        labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) + 1
    # H = -sum(p * log2(p)) over the observed labels.
    shannonEnt = 0.0
    for count in labelCounts.values():
        prob = float(count) / numEntries
        shannonEnt -= prob * math.log(prob, 2)
    return shannonEnt
# Read the CSV file (method 1).
# The `with` block closes the file even if parsing raises, and newline=""
# is what the csv module asks for so that line breaks inside quoted fields
# are handled by the parser itself.
with open("qudiaonanrowTrain1.csv", "r", newline="") as csvFile:
    reader = csv.reader(csvFile)  # iterator yielding one list of strings per row
    # Skip the header row; the default keeps an empty file from raising.
    next(reader, None)
    # Keep every column except the first one.
    # NOTE(review): the first column is presumably a row index/id - confirm.
    myDat = [item[1:] for item in reader]
# Alternative loader (disabled):
#myDat=pd.read_csv('qudiaonanrowTrain1.tsv',encoding='ANSI',sep='\t')
# Column names handed to createTree() when the tree is built.
# NOTE(review): the last name looks like the class column rather than a
# feature - confirm against the data file.
mylabels=['AGENT','IS_LOCAL','VORK_PROVINC','EDU_LEVEL','MARRY_STATUS','SALARY','HAS_FUND']
def splitDataSet(dataSet, axis, value):
    """Select the rows whose column *axis* equals *value*, dropping that column.

    Args:
        dataSet: sequence of rows (each an indexable, sliceable sequence).
        axis: non-negative index of the column to test and remove.
        value: value the column must equal for a row to be kept.

    Returns:
        A new list of new lists; neither *dataSet* nor its rows are modified.
    """
    return [
        # Rebuild the row without the column that was split on.
        list(featVec[:axis]) + list(featVec[axis + 1:])
        for featVec in dataSet
        if featVec[axis] == value
    ]
#a=splitDataSet(myDat,0,1)
#print(a)
def chooseBestFeatureToSplit(dataSet):
    """Return the index of the feature column with the largest information gain.

    Every column except the last one is treated as a candidate feature; the
    last column is the class label that calcShannonEnt() reads.

    Args:
        dataSet: non-empty sequence of equally long rows with hashable
            feature values.

    Returns:
        The column index of the best feature, or -1 when no feature gives a
        strictly positive information gain. Ties go to the lowest index.
    """
    numFeatures = len(dataSet[0]) - 1
    totalRows = float(len(dataSet))
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        # Partition the rows by their value in column i in a single pass,
        # instead of re-scanning the whole data set once per distinct value.
        # calcShannonEnt() only looks at the last column, so the rows can be
        # kept whole.
        partitions = {}
        for example in dataSet:
            partitions.setdefault(example[i], []).append(example)
        # Weighted entropy after the split. Iterating the dict follows
        # first-seen order, so the float sum is reproducible between runs.
        newEntropy = 0.0
        for subDataSet in partitions.values():
            prob = len(subDataSet) / totalRows
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature
#a=chooseBestFeatureToSplit(myDat)
#print(a)
def majorityCnt(classList):
    """Return the most frequent label in *classList*.

    Args:
        classList: non-empty iterable of hashable class labels.

    Returns:
        The label with the highest count. When several labels share the
        highest count, the one that appears first in *classList* wins.

    Raises:
        ValueError: if *classList* is empty.
    """
    classCount = {}
    for vote in classList:
        classCount[vote] = classCount.get(vote, 0) + 1
    # max() keeps the first maximal key in insertion order, which is the same
    # winner a stable descending sort by count would put first.
    return max(classCount, key=classCount.get)
def createTree(dataSet, labels):
    """Recursively build an ID3 decision tree.

    Args:
        dataSet: non-empty sequence of rows; the last element of each row is
            the class label, the preceding elements are feature values.
        labels: names of the columns of *dataSet*, indexed like the rows.
            The list is NOT modified (a copy without the chosen column is
            passed down to the sub-trees).

    Returns:
        Either a class label (leaf) or a nested dict of the form
        ``{feature_name: {feature_value: subtree_or_label, ...}}``.
    """
    classList = [example[-1] for example in dataSet]
    # All rows share one class: nothing left to decide.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Only the class column is left: fall back to a majority vote.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    # chooseBestFeatureToSplit() returns -1 when no feature yields a positive
    # information gain. Using -1 as an index would select the last column,
    # i.e. the class label itself, and build a node that simply reads the
    # answer from the row being classified. Vote instead.
    if bestFeat < 0:
        return majorityCnt(classList)
    bestFeatLabel = labels[bestFeat]
    # Column names for the reduced rows produced by splitDataSet(); built on
    # a copy so the caller's list stays intact.
    subLabels = list(labels)
    del subLabels[bestFeat]
    branches = {}
    for value in set(example[bestFeat] for example in dataSet):
        branches[value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels
        )
    return {bestFeatLabel: branches}
# Build the decision tree from the first 5000 rows of myDat; the rows from
# index 5000 onwards are passed to testing() at the end of the script.
myTree=createTree(myDat[:5000],mylabels)
#print(myTree)
#classLabel=-1
def classify(inputTree, featLabels, testVec):
    """Walk the decision tree and return the class label predicted for *testVec*.

    Args:
        inputTree: a tree of nested dicts shaped like
            ``{feature_name: {feature_value: subtree_or_label}}``, or a bare
            label (which is returned unchanged).
        featLabels: column names; ``featLabels.index(name)`` gives the
            position of that feature inside *testVec*.
        testVec: the row to classify.

    Returns:
        The label stored at the leaf that *testVec* reaches.

    Raises:
        KeyError: if *testVec* holds a feature value for which the current
            node has no branch.
        ValueError: if a node's feature name is missing from *featLabels*.
    """
    node = inputTree
    # Descend until a non-dict value (a leaf label) is reached.
    while isinstance(node, dict):
        featName = next(iter(node))  # each node dict has one feature name as key
        branches = node[featName]
        featValue = testVec[featLabels.index(featName)]
        if featValue not in branches:
            # Fail with a descriptive error instead of an unbound local.
            raise KeyError(
                'no branch for value %r of feature %r' % (featValue, featName)
            )
        node = branches[featValue]
    return node
#labels=createDataSet()[1]
# Complete list of column names, passed to testing() below so that classify()
# can map a feature name stored in the tree to its position in a row. It is a
# separate list object from `mylabels`, the one handed to createTree().
myLabels=['AGENT','IS_LOCAL','VORK_PROVINC','EDU_LEVEL','MARRY_STATUS','SALARY','HAS_FUND']
#a=classify(myTree,myLabels,['APP','本地籍','320000.0','本科','已婚','3.0','0.0'])
#a=classify(myTree,myLabels,['wechat', '本地籍', '420000.0', '专科及以下', '未婚', '4.0', '1.0'])
#print(a)
# =============================================================================
# def storeTree(inputTree,filename):
# import pickle
# fw=open(filename,'wb')
# pickle.dump(inputTree,fw)
# fw.close()
# def grabTree(filename):
# import pickle
# fr=open(filename,'rb')
# return pickle.load(fr)
#
# storeTree(myTree,'classifierStorage.txt')
# b=grabTree('classifierStorage.txt')
# #print(b)
# =============================================================================
#del(myDat[2])
def testing(myTree, data_test, labels):
    """Classify every row of *data_test* and report the error rate.

    Prints ``errorRate <rate> <skipped>``, where *rate* is the number of
    misclassified rows divided by the total number of rows and *skipped* is
    the number of rows that could not be classified.

    Args:
        myTree: decision tree accepted by classify().
        data_test: sequence of rows; the last element of each row is the
            expected class label.
        labels: column names passed through to classify().

    Returns:
        The number of misclassified rows as a float (not the rate).
    """
    error = 0.0
    skipped = 0
    testCount = len(data_test)
    for row in data_test:
        try:
            mismatch = classify(myTree, labels, row) != row[-1]
        except Exception:
            # classify() can fail on rows whose feature value has no matching
            # branch in the tree; such rows are counted separately instead of
            # aborting the run. Catching Exception (rather than everything)
            # lets KeyboardInterrupt/SystemExit through.
            skipped += 1
            continue
        if mismatch:
            error += 1.0
    # Guard against an empty test set.
    errorRate = error / testCount if testCount else 0.0
    # %f, not %d: %d would truncate the rate to an integer.
    print('errorRate %f' % errorRate, skipped)
    return float(error)
# =============================================================================
# # 读取csv文件方式2
# with open("mostnan.csv", "r") as csvfile:
# reader2 = csv.reader(csvfile) # 读取csv文件,返回的是迭代类型
# data = []
# for item2 in reader2:
# # print(item2)
# data.append(item[1:])
# del(data[0])
# csvFile.close()
#
# =============================================================================
# Evaluate the tree on the rows that were not used to build it (index 5000 onwards).
testing(myTree,myDat[5000:],myLabels)
# 附:链接:https://pan.baidu.com/s/1LkwdWC67uBhB1kuWCikF-Q 密码:xbm3
# 相关信用评分模型博客:https://blog.csdn.net/lll1528238733/article/category/7072659