引入依赖库并读取数据集
import pandas as pd
import numpy as np
import math
import random
# Load the dataset stored in Adult.csv (path is relative to the working directory).
df = pd.read_csv('Adult.csv')
# Bare expression: only has a visible effect when run as a notebook cell.
df
age | workclass | fnlwgt | education | education.num | marital.status | occupation | relationship | race | sex | capital.gain | capital.loss | hours.per.week | native.country | income | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 90 | ? | 77053 | HS-grad | 9 | Widowed | ? | Not-in-family | White | Female | 0 | 4356 | 40 | United-States | <=50K |
1 | 82 | Private | 132870 | HS-grad | 9 | Widowed | Exec-managerial | Not-in-family | White | Female | 0 | 4356 | 18 | United-States | <=50K |
2 | 66 | ? | 186061 | Some-college | 10 | Widowed | ? | Unmarried | Black | Female | 0 | 4356 | 40 | United-States | <=50K |
3 | 54 | Private | 140359 | 7th-8th | 4 | Divorced | Machine-op-inspct | Unmarried | White | Female | 0 | 3900 | 40 | United-States | <=50K |
4 | 41 | Private | 264663 | Some-college | 10 | Separated | Prof-specialty | Own-child | White | Female | 0 | 3900 | 40 | United-States | <=50K |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
32556 | 22 | Private | 310152 | Some-college | 10 | Never-married | Protective-serv | Not-in-family | White | Male | 0 | 0 | 40 | United-States | <=50K |
32557 | 27 | Private | 257302 | Assoc-acdm | 12 | Married-civ-spouse | Tech-support | Wife | White | Female | 0 | 0 | 38 | United-States | <=50K |
32558 | 40 | Private | 154374 | HS-grad | 9 | Married-civ-spouse | Machine-op-inspct | Husband | White | Male | 0 | 0 | 40 | United-States | >50K |
32559 | 58 | Private | 151910 | HS-grad | 9 | Widowed | Adm-clerical | Unmarried | White | Female | 0 | 0 | 40 | United-States | <=50K |
32560 | 22 | Private | 201490 | HS-grad | 9 | Never-married | Adm-clerical | Own-child | White | Male | 0 | 0 | 20 | United-States | <=50K |
32561 rows × 15 columns
处理缺失值
# '?' is the missing-value marker used in these columns; keep only the rows
# where workclass, occupation and native.country are all present.
# .copy() makes df1 an independent frame, so the in-place column deletions
# applied to df1 later cannot be mistaken for edits to a slice of df.
df1 = df[(df['workclass']!='?')&(df['occupation']!='?')&(df['native.country']!='?')].copy()
# Bare expression: only has a visible effect when run as a notebook cell.
df1
age | workclass | fnlwgt | education | education.num | marital.status | occupation | relationship | race | sex | capital.gain | capital.loss | hours.per.week | native.country | income | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
1 | 82 | Private | 132870 | HS-grad | 9 | Widowed | Exec-managerial | Not-in-family | White | Female | 0 | 4356 | 18 | United-States | <=50K |
3 | 54 | Private | 140359 | 7th-8th | 4 | Divorced | Machine-op-inspct | Unmarried | White | Female | 0 | 3900 | 40 | United-States | <=50K |
4 | 41 | Private | 264663 | Some-college | 10 | Separated | Prof-specialty | Own-child | White | Female | 0 | 3900 | 40 | United-States | <=50K |
5 | 34 | Private | 216864 | HS-grad | 9 | Divorced | Other-service | Unmarried | White | Female | 0 | 3770 | 45 | United-States | <=50K |
6 | 38 | Private | 150601 | 10th | 6 | Separated | Adm-clerical | Unmarried | White | Male | 0 | 3770 | 40 | United-States | <=50K |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
32556 | 22 | Private | 310152 | Some-college | 10 | Never-married | Protective-serv | Not-in-family | White | Male | 0 | 0 | 40 | United-States | <=50K |
32557 | 27 | Private | 257302 | Assoc-acdm | 12 | Married-civ-spouse | Tech-support | Wife | White | Female | 0 | 0 | 38 | United-States | <=50K |
32558 | 40 | Private | 154374 | HS-grad | 9 | Married-civ-spouse | Machine-op-inspct | Husband | White | Male | 0 | 0 | 40 | United-States | >50K |
32559 | 58 | Private | 151910 | HS-grad | 9 | Widowed | Adm-clerical | Unmarried | White | Female | 0 | 0 | 40 | United-States | <=50K |
32560 | 22 | Private | 201490 | HS-grad | 9 | Never-married | Adm-clerical | Own-child | White | Male | 0 | 0 | 20 | United-States | <=50K |
30162 rows × 15 columns
获取数据集中一列的所有取值,例如sex中的male 和 female
def get_all_category(line):
    """Return the distinct values of *line* in first-seen order.

    Args:
        line: Any iterable of values, e.g. one DataFrame column.

    Returns:
        A list holding each distinct value once, ordered by first occurrence.
    """
    values = list(line)
    try:
        # dict keys keep insertion order, so this is an O(n) ordered dedupe
        # instead of the O(n^2) "scan the result list for every value".
        return list(dict.fromkeys(values))
    except TypeError:
        # Unhashable values (e.g. nested lists) cannot be dict keys; fall back
        # to the linear membership scan for them.
        unique_values = []
        for value in values:
            if value not in unique_values:
                unique_values.append(value)
        return unique_values
判断两个列表是否相同
def equeal_list(list1,list2):
    """Return True when both sequences have equal length and equal elements.

    Elements are compared pairwise, position by position.
    """
    if len(list1) != len(list2):
        return False
    # Equal lengths are guaranteed here, so zip() pairs up every position.
    return not any(left != right for left, right in zip(list1, list2))
def equeal_list_not_in_list(list1,list2):
    """Return True when no entry of *list2* is element-wise equal to *list1*.

    Args:
        list1: The sequence to look for.
        list2: A sequence of sequences to search.

    Returns:
        False as soon as one entry of list2 matches list1 (per equeal_list),
        True otherwise, including when list2 is empty.
    """
    # any() over an empty list2 is False, which covers the empty case too.
    return not any(equeal_list(list1, candidate) for candidate in list2)
$$Ent(D) = -\sum_{k=1}^{|y|} p_k \log_2 p_k$$
def Ent(data):
    """Return the base-2 Shannon entropy of the values in *data*.

    Computes -sum_k p_k * log2(p_k), where p_k is the relative frequency of
    the k-th distinct value.

    Args:
        data: A pandas Series (or any sequence pandas can wrap in one).

    Returns:
        The entropy as a float; 0.0 when all values are identical.
    """
    # Series.value_counts replaces the deprecated top-level pd.value_counts.
    prob = pd.Series(data).value_counts() / len(data)
    # Every term p * log2(p) is <= 0; subtracting from 0.0 flips the sign and
    # avoids returning a negative zero for a single-valued column.
    return 0.0 - (prob * np.log2(prob)).sum()
$$Gain(D,a) = Ent(D) - \sum_{v=1}^{V} \frac{|D^v|}{|D|} Ent(D^v)$$
def Gain(data,str1,str2):
    """Return the information gain of splitting *data* on column *str1*.

    Computes Ent(D) - sum_v |D^v|/|D| * Ent(D^v), where D^v is the subset of
    rows whose *str1* value is v and the entropy is taken over column *str2*.

    Args:
        data: DataFrame containing both columns.
        str1: Name of the candidate split column.
        str2: Name of the label column.

    Returns:
        The information gain as a float.
    """
    labels = data[str2]
    # Entropy of the labels inside each group of equal str1 values. Grouping
    # the label Series directly avoids applying a function to whole
    # sub-frames (which include the grouping column itself).
    group_entropy = labels.groupby(data[str1]).apply(Ent)
    # Share of rows in each group; Series.value_counts replaces the
    # deprecated top-level pd.value_counts.
    group_weight = data[str1].value_counts() / len(data)
    # The product aligns both Series on the str1 values before summing.
    conditional_entropy = (group_entropy * group_weight).sum()
    return Ent(labels) - conditional_entropy
def get_best_spilt(data,str2):
    """Pick the split feature for one node from a random feature subset.

    A random sample of ceil(log2(d)) of the d candidate features (but at
    least one) is drawn, and the sampled feature with the highest
    information gain is returned.

    Args:
        data: DataFrame with the samples of the current branch.
        str2: Name of the label column; it is never used as a split feature.

    Returns:
        The name of the chosen feature column.

    Raises:
        ValueError: If *data* has no column other than *str2*.
    """
    candidates = [column for column in data.columns if column != str2]
    if not candidates:
        raise ValueError('no feature columns left to split on')
    # ceil(log2(1)) is 0, which would sample nothing and make max() fail on an
    # empty list, so always evaluate at least one feature.
    sample_size = max(1, math.ceil(math.log2(len(candidates))))
    sampled = random.sample(candidates, sample_size)
    gains = [Gain(data, feature, str2) for feature in sampled]
    # The gains belong to the sampled subset, so the winner must be looked up
    # there; indexing the full candidate list returns an unrelated feature.
    return sampled[gains.index(max(gains))]
class Tree:
    """One node of a decision tree, together with the subtrees below it.

    Attributes:
        node: Label name of this node (str).
        depth: Depth of this node (int).
        is_leaf: Whether this node is a leaf (int, 0 or 1).
        spilt: Best split feature chosen at this node (str).
        node_df: The dataset contained in this node.
        label: Binary classification result of this node; meant to be 2 when
            the node is not a leaf.
        tree_list: All subtrees directly under this node.
        all_node: Features that have already been used.
    """

    def __init__(self, node, depth, is_leaf, spilt, node_df, label, tree_list, all_node):
        # Identity and position of the node.
        self.node = node
        self.depth = depth
        # Split/leaf state.
        self.is_leaf = is_leaf
        self.spilt = spilt
        self.label = label
        # Data and structure hanging off the node.
        self.node_df = node_df
        self.tree_list = tree_list
        self.all_node = all_node
满足以下任一条件时,当前结点为叶结点:
1、子节点中的样本属于同一类
2、子节点无样本
3、特征已经用完
def is_leaf_now(data,str2):
    """Return True when the node holding *data* should be split further.

    NOTE: despite the name, the result is inverted. False means the node IS
    a leaf, and True means it is not. built_tree relies on this, so the
    contract is kept as is.

    The node is a leaf (False is returned) when any of these holds:
      1. it contains no samples;
      2. all of its samples carry the same label;
      3. no feature column is left (only one column remains).

    Args:
        data: DataFrame with the samples of the node.
        str2: Name of the label column.

    Returns:
        bool: False for a leaf, True for a node that can still be split.
    """
    if len(data) == 0:
        return False
    counts = data[str2].value_counts()
    # value_counts sorts by frequency, so the first entry is the most common
    # label. .iloc is required: the index holds the label strings, and a plain
    # [0] lookup on it is a deprecated positional fallback.
    if counts.iloc[0] == len(data):
        return False
    if len(data.columns) == 1:
        return False
    return True
def built_tree(data,str2,Tnode,k,all_node):
    """Recursively build the (sub)tree for *data* and return its root node.

    Args:
        data: DataFrame with the samples that reach this node.
        str2: Name of the label column.
        Tnode: Label of this node, i.e. the value of the parent's split
            feature that leads here (e.g. a value of the sex column).
        k: Depth of the parent; this node gets depth k + 1.
        all_node: List recording the split features used so far. It is shared
            with, and appended to by, every descendant, so after the call the
            root's list describes the whole tree.

    Returns:
        The Tree node for *data*. A leaf gets label -1 for '<=50K' and 1 for
        '>50K'; the label stays 2 for inner nodes and for leaves whose class
        cannot be determined.
    """
    now_tree = Tree(Tnode, k + 1, 0, '', data, 2, [], [])
    # is_leaf_now is inverted: True means "not a leaf, keep splitting".
    if is_leaf_now(data, str2):
        now_tree.spilt = get_best_spilt(data, str2)
        now_tree.all_node = all_node
        # One child per value the split feature takes in this node.
        for value in get_all_category(data[now_tree.spilt]):
            # Recorded once per child; the resulting list is what callers use
            # to tell two trees apart.
            all_node.append(now_tree.spilt)
            # Rows with this value, without the split column so that it is
            # not considered again further down.
            subset = data[data[now_tree.spilt] == value].drop(columns=now_tree.spilt)
            child = built_tree(subset, str2, value, now_tree.depth, all_node)
            now_tree.tree_list.append(child)
    else:
        now_tree.is_leaf = 1
        counts = data[str2].value_counts() if len(data) > 0 else None
        if counts is not None and len(counts) > 0:
            # Use the majority class of the leaf. For a pure leaf this is its
            # only class; for a leaf that ran out of features it is the most
            # frequent one rather than whatever the first row happens to be.
            majority = str(counts.idxmax())
            if '<=50K' in majority:
                now_tree.label = -1
            elif '>50K' in majority:
                now_tree.label = 1
    return now_tree
def random_forest(train,str2,n_trees=721,max_attempts=None):
    """Build a forest of trees that differ in their sequence of split features.

    Trees are generated one at a time with built_tree. A tree is kept only
    when the list of split features it recorded (its all_node) has not been
    produced by an earlier tree.

    Args:
        train: DataFrame with the training samples.
        str2: Name of the label column.
        n_trees: Number of distinct trees to collect. The default of 721
            (6*5*4*3*2 + 1) is the count the previously hard-coded loop
            condition produced.
        max_attempts: Optional cap on how many trees may be generated in
            total. None (the default) keeps trying until n_trees distinct
            trees exist, which never terminates if the data cannot yield
            that many.

    Returns:
        A list of at most n_trees Tree roots.
    """
    trees = []
    # Split-feature sequences already seen, stored as tuples so membership is
    # a hash lookup instead of a scan over every previous tree.
    seen = set()
    attempts = 0
    while len(trees) < n_trees:
        if max_attempts is not None and attempts >= max_attempts:
            break
        attempts += 1
        # Each tree gets its own fresh all_node list. Building inside the loop
        # also avoids generating one extra tree that is discarded on exit.
        tree = built_tree(train, str2, 'tree_root', 0, [])
        signature = tuple(tree.all_node)
        if signature not in seen:
            seen.add(signature)
            trees.append(tree)
    return trees
def is_no_var(now_spilt,data,thevar):
    """Return True when *thevar* does not occur in column *now_spilt* of *data*.

    Args:
        now_spilt: Name of the column to inspect.
        data: DataFrame holding that column.
        thevar: The value to look for.
    """
    known_values = get_all_category(data[now_spilt])
    return thevar not in known_values
def get_tree_var(thevar,tree):
    """Return the first child of *tree* whose node label equals *thevar*.

    Returns None when no child in tree.tree_list matches.
    """
    matches = (child for child in tree.tree_list if thevar == child.node)
    return next(matches, None)
def Tree_result(tree,age,workclass,fnlwgt,education,education_num,marital_status,occupation,relationship,race,sex,hours_per_week,native_country):
    """Classify one sample with a single decision tree.

    Starting at *tree*, the child whose node label equals the sample's value
    for the current split feature is followed until a leaf is reached.

    Args:
        tree: Root Tree node to start from.
        age, workclass, fnlwgt, education, education_num, marital_status,
        occupation, relationship, race, sex, hours_per_week, native_country:
            The sample's value for the correspondingly named column.

    Returns:
        The label of the leaf that is reached. 0 (no vote) is returned when
        the sample's value has no matching child, or when a node splits on a
        column that is not among the parameters above.
    """
    # Column name -> sample value. A lookup table replaces the long if/elif
    # chain, whose misspelt 'eduction.num' branch could never match.
    sample = {
        'age': age,
        'workclass': workclass,
        'fnlwgt': fnlwgt,
        'education': education,
        'education.num': education_num,
        'marital.status': marital_status,
        'occupation': occupation,
        'relationship': relationship,
        'race': race,
        'sex': sex,
        'hours.per.week': hours_per_week,
        'native.country': native_country,
    }
    now_tree = tree
    while now_tree.is_leaf == 0:
        if now_tree.spilt not in sample:
            # Without this exit an unhandled split column would leave
            # now_tree unchanged and loop forever.
            return 0
        value = sample[now_tree.spilt]
        # Children are looked up by node label; a missing child means the
        # value was not present in this node's data.
        child = next(
            (candidate for candidate in now_tree.tree_list if value == candidate.node),
            None,
        )
        if child is None:
            return 0
        now_tree = child
    return now_tree.label
def random_forest_result(trees,age,workclass,fnlwgt,education,education_num,marital_status,occupation,relationship,race,sex,hours_per_week,native_country):
    """Return the sum of the Tree_result values of all *trees* for one sample.

    Args:
        trees: Sequence of Tree roots.
        age ... native_country: The sample's feature values, forwarded
            unchanged to Tree_result for every tree.

    Returns:
        The summed per-tree results; 0 for an empty forest.
    """
    return sum(
        Tree_result(
            member, age, workclass, fnlwgt, education, education_num,
            marital_status, occupation, relationship, race, sex,
            hours_per_week, native_country,
        )
        for member in trees
    )
# Remove, in place, the columns that are not used from here on.
for _unused_column in (
    'capital.gain',
    'capital.loss',
    'education.num',
    'native.country',
    'race',
    'hours.per.week',
    'sex',
    'fnlwgt',
):
    del df1[_unused_column]
# Work with the first 100 remaining rows only.
data = df1.head(100)
# Bare expression: only has a visible effect when run as a notebook cell.
data
age | workclass | education | marital.status | occupation | relationship | income | |
---|---|---|---|---|---|---|---|
1 | 82 | Private | HS-grad | Widowed | Exec-managerial | Not-in-family | <=50K |
3 | 54 | Private | 7th-8th | Divorced | Machine-op-inspct | Unmarried | <=50K |
4 | 41 | Private | Some-college | Separated | Prof-specialty | Own-child | <=50K |
5 | 34 | Private | HS-grad | Divorced | Other-service | Unmarried | <=50K |
6 | 38 | Private | 10th | Separated | Adm-clerical | Unmarried | <=50K |
... | ... | ... | ... | ... | ... | ... | ... |
108 | 50 | Private | Bachelors | Married-civ-spouse | Exec-managerial | Wife | >50K |
109 | 47 | Private | Bachelors | Married-civ-spouse | Exec-managerial | Husband | >50K |
110 | 47 | Self-emp-inc | Prof-school | Married-civ-spouse | Prof-specialty | Husband | >50K |
111 | 67 | Private | Bachelors | Widowed | Exec-managerial | Not-in-family | >50K |
112 | 67 | Self-emp-inc | Bachelors | Married-civ-spouse | Exec-managerial | Husband | >50K |
100 rows × 7 columns
# Build the forest from the 100-row sample, with 'income' as the label column.
# (The former `trees=list()` was a dead store, overwritten on the next line.)
trees = random_forest(data, 'income')
print(len(trees))
721
# Sum the votes of all trees for one hand-written sample. The columns that
# were deleted from the training data only receive placeholder strings.
k = random_forest_result(
    trees,
    age=82,
    workclass='Self-emp-inc',
    fnlwgt='fnlwgt',
    education='Some-college',
    education_num='education_num',
    marital_status='Separated',
    occupation='Exec-managerial',
    relationship='Own-child',
    race='race',
    sex='sex',
    hours_per_week='hours_per_week',
    native_country='native_country',
)
print(k)
74