在我的例子中,对于连续变量的属性值二分点选择,是先做百分位切割(切成一百份),然后循环地测试这些切割点的 Gini 指数。
那么一个变量就要大约执行100次测试,普通的python代码执行都是单进程的,因此会阻塞等待,浪费大量的时间和算力。
将所有的切割点测试改为矩阵计算,那么每次选优的计算估计也就是原来5%不到的时间。
在实现上大约分为三步:
def cbcut(data=None, pstart=0.1, pend=0.9):
    """Build 0/1 split indicators for a continuous variable at percentile cuts.

    Candidate thresholds are the quantiles of ``data`` taken in 1% steps from
    ``pstart`` to ``pend`` (both inclusive), with duplicate values removed.

    Args:
        data: pandas Series holding the continuous variable.
        pstart: lowest quantile level to try (default 0.1).
        pend: highest quantile level to try (default 0.9).

    Returns:
        dict with two entries:
            'data_list': list with one Series per threshold, holding 1 where
                the value is strictly below the threshold and 0 otherwise.
            'qtiles': array of the distinct thresholds, aligned with
                'data_list'.
    """
    # round() before int(): the float product can land just below the
    # intended integer (0.29 * 100 == 28.999999999999996), and plain
    # truncation would then silently drop one quantile level.
    bins = int(round((pend - pstart) * 100)) + 1
    qlist = np.linspace(pstart, pend, bins)
    # Neighbouring quantile levels can map to the same value; keep each
    # threshold once. unique() preserves first-seen order, which is already
    # ascending here because the quantile levels are ascending.
    qtiles = data.quantile(qlist).unique()
    # One vectorised comparison per threshold instead of a Python-level
    # call per element.
    res_list = [(data < q).astype('int64') for q in qtiles]
    return {'data_list': res_list, 'qtiles': qtiles}
def cal_min_gini_mat(mat=None, y=0):
    """Find the candidate split with the lowest weighted Gini impurity.

    All candidate splits are evaluated at once with array arithmetic.

    Args:
        mat: 2-D 0/1 array of shape (n_samples, n_splits). Column ``j`` is 1
            for the samples that fall on the "right" side of candidate
            split ``j`` and 0 for those on the "left" side.
        y: 1-D sequence of ``n_samples`` binary (0/1) labels.

    Returns:
        Tuple ``(min_gini, min_pos)``: the smallest weighted Gini impurity
        over all candidate splits and the column index where it occurs.
    """
    import numpy as np  # local import keeps the function self-contained

    # Work on plain float arrays so the matrix product below follows NumPy
    # shape rules no matter whether the caller passed arrays or pandas objects.
    indicator = np.asarray(mat, dtype=float)
    labels = np.asarray(y, dtype=float)
    totals = len(labels)
    positives = labels.sum()

    def _weighted_gini(side_total, side_positive):
        # An empty side has weight 0 and must contribute nothing. Skipping
        # the division there avoids 0/0 -> NaN, which argmin would otherwise
        # pick as the "minimum".
        p1 = np.divide(
            side_positive,
            side_total,
            out=np.zeros_like(side_total),
            where=side_total > 0,
        )
        p0 = 1 - p1
        return side_total / totals * (1 - p1 ** 2 - p0 ** 2)

    # right side: samples flagged with 1
    right_total = indicator.sum(axis=0)
    # Contract over the sample axis: (n_samples,) @ (n_samples, n_splits).
    right_positive = labels @ indicator

    # left side: the complement of the right side
    left_total = totals - right_total
    left_positive = positives - right_positive

    gini = (_weighted_gini(left_total, left_positive)
            + _weighted_gini(right_total, right_positive))
    # Read the value at the argmin position so value and index always agree.
    min_pos = gini.argmin()
    return gini[min_pos], min_pos
def Ctype_Search(x=None, y=None, pstart=0.1, pend=0.9):
    """Search the best binary split threshold for a continuous variable.

    Candidate thresholds are the distinct quantiles of ``x`` taken in 1%
    steps from ``pstart`` to ``pend``; all of them are scored in one call to
    ``cal_min_gini_mat``.

    Args:
        x: pandas Series holding the continuous variable.
        y: labels aligned with ``x``, passed through to ``cal_min_gini_mat``.
        pstart: lowest quantile level to try (default 0.1).
        pend: highest quantile level to try (default 0.9).

    Returns:
        Tuple ``(min_gini, threshold)``: the value returned by
        ``cal_min_gini_mat`` for the best split, and the threshold at that
        position. Samples with ``x < threshold`` are the ones flagged 1 in
        the indicator matrix.
    """
    # round() before int(): the float product can land just below the
    # intended integer (0.29 * 100 == 28.999999999999996), and plain
    # truncation would then silently drop one quantile level.
    bins = int(round((pend - pstart) * 100)) + 1
    qlist = np.linspace(pstart, pend, bins)
    qtiles = x.quantile(qlist).unique()
    # Broadcast (n_samples, 1) against (n_splits,) to get the
    # (n_samples, n_splits) 0/1 indicator matrix directly.
    mat = (x.values[:, np.newaxis] < qtiles) * 1
    min_gini, min_pos = cal_min_gini_mat(mat=mat, y=y)
    return min_gini, qtiles[min_pos]
下载数据
# 数据
In [156]: res_df.head()
Out[156]:
x y
157500.0 0
67500.0 0
90000.0 0
405000.0 1
90000.0 0
。。。
# >>>>>>>>>>>>>>>>>>>>>> before
import time

# Time the original loop-based search over the percentile cut points.
st = time.time()
iTree.find_min_gini(x=x, y=y, vartype='C', varname='AMT_INCOME_TOTAL')
# The subtraction must be parenthesised: `%` binds tighter than `-`, so
# without the parentheses the string is formatted first and `st` is then
# subtracted from a str, raising TypeError.
print('It takes %.3f seconds' % (time.time() - st))
In [13]: import time
...: st = time.time()
...: iTree.find_min_gini(x=x, y=y, vartype='C', varname='AMT_INCOME_TOTAL')
...: print('It takes %.3f seconds' % (time.time() - st))
It takes 1.665 seconds
In [158]: tree_res['AMT_INCOME_TOTAL'].keys()
Out[158]: dict_keys(['gini', 'condition_left', 'condition_right'])
In [159]: tree_res['AMT_INCOME_TOTAL']['gini']
Out[159]: 0.150674339064104
In [160]: tree_res['AMT_INCOME_TOTAL']['condition_left']
Out[160]: '<247500.0'
In [161]: tree_res['AMT_INCOME_TOTAL']['condition_right']
Out[161]: '>=247500.0'
# >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> after
In [162]: import time
...: st = time.time()
...: tree_res = Ctype_Search(x,y)
...: print('It takes %.3f seconds' % (time.time() - st))
It takes 0.032 seconds
In [163]: tree_res
Out[163]: (0.150674339064104, 247500.0)
快了50倍有木有?