# Python implementation of frequent itemset mining and association rules (Apriori algorithm)

# -*- coding: utf-8 -*-
#Parameter settings
data_file = 'F:\\user_match_stat\\itemset.txt'
#Input file format: csv, e.g.: item1,item2,item3
#One transaction per line
frequent_itemsets_save_file = 'F:\\user_match_stat\\frequent_itemsets.txt'
rules_readable_file_dest = 'F:\\user_match_stat\\rules_readable.txt'
rules_csv_file_dest = 'F:\\user_match_stat\\rules_csv.txt'
rules_ranked_desc_by_liftrate = 'F:\\user_match_stat\\rules_liftrate_desc.txt'
#Rule file format: itemset A,itemset B,support,confidence,liftrate
#Items within an itemset are separated by |
minsup = 0.01   #minimum support: every rule's support must be >= this value
minconf = 0.000001  #minimum confidence
 
 
#Statistics computed while reading the data
transaction_cnt = 0  #total number of transactions (input lines)
min_sup_cnt = 0     #minimum support count, int(transaction_cnt * minsup)
transaction_cnt_distinct = 0 #number of distinct transactions
 
        
#Global data structures
transaction_cnt_dict = {} # {tranidx (int) : multiplicity of that distinct transaction}
frequent_itemsets_verified = {}  #{itemset tuple : support count} confirmed frequent at the current level
frequent_itemsets_candidate = {}  #{itemset tuple : set(items)} candidates awaiting support counting
frequent_itemsets = {} #{ itemset length (int) : {itemset tuple : support count} } all frequent itemsets found
closed_frequent_itemsets = {} #closed frequent itemsets -- NOTE(review): declared but never populated
distinct_item_in_candidate_itemsets = set() #items appearing in at least one current candidate
distinct_item_in_transaction_cnt_dict = set() #NOTE(review): declared but apparently never used
item_transaction_list_dict = {} # {item : set(tranidx in transaction_cnt_dict)}
hitted_transaction_set = set() #tranidx values covered by at least one frequent itemset in the last pass
 
#Read the transaction set
def prepare_data():
    """Read the transaction file and build the vertical data layout.

    Fills transaction_cnt_dict ({tranidx: multiplicity}) and
    item_transaction_list_dict ({item: set(tranidx)}), and computes the
    global counters transaction_cnt, transaction_cnt_distinct and
    min_sup_cnt.
    """
    global transaction_cnt
    global min_sup_cnt
    global transaction_cnt_distinct

    print('Reading data from ' + data_file + '...')

    pre_transaction_cnt_dict = {}  # {sorted item tuple: occurrence count}
    n = 0
    # 'with' guarantees the file is closed even if a line raises
    # (the original left the handle open on error, and shadowed builtin `file`)
    with open(data_file) as infile:
        for line in infile:
            line = line.strip()  # drop the trailing '\n'
            if line == '':
                continue
            n = n + 1
            item_list = line.split(',')
            item_list.sort()  # canonical order so identical transactions collapse
            tp = tuple(item_list)
            pre_transaction_cnt_dict[tp] = pre_transaction_cnt_dict.get(tp, 0) + 1

    # total number of transactions (lines)
    transaction_cnt = n
    print('Totally read ' + str(n) + ' lines.')

    # Assign each distinct transaction an index, and record for every item
    # the set of transaction indexes containing it (vertical layout).
    tranidx = 1
    for tp in pre_transaction_cnt_dict:
        transaction_cnt_dict[tranidx] = pre_transaction_cnt_dict[tp]
        for item in tp:
            if item in item_transaction_list_dict:
                item_transaction_list_dict[item].add(tranidx)
            else:
                item_transaction_list_dict[item] = set((tranidx,))
        tranidx = tranidx + 1

    del pre_transaction_cnt_dict

    transaction_cnt_distinct = len(transaction_cnt_dict)
    min_sup_cnt = int(transaction_cnt * minsup)
    print('The number of total transactions is ' + str(transaction_cnt) + '.')
    print('The number of distinct transactions is ' + str(transaction_cnt_distinct) + '.')
    print('The min support count is ' + str(min_sup_cnt) + '.')
    print('Function prepare_data done.')
    return
    
    
#Build the frequent 1-itemsets, counted directly from item_transaction_list_dict
def get_frequent_itemsets_1():
    """Fill frequent_itemsets[1] with every single item whose support count
    reaches min_sup_cnt, and record the transactions those items hit."""
    hitted_transaction_set.clear()
    frequent_itemsets[1] = {}
    for item, tran_set in item_transaction_list_dict.items():
        # The support count is the sum of the multiplicities of the distinct
        # transactions containing the item -- NOT len(tran_set).
        support = sum(transaction_cnt_dict[idx] for idx in tran_set)
        if support >= min_sup_cnt:
            frequent_itemsets[1][(item,)] = support
            # Remember which transactions were hit so later passes can
            # discard the others.
            hitted_transaction_set.update(tran_set)

    print('Function get_frequent_itemsets_1 done')
    return
 
        
#Generate candidates: fill frequent_itemsets_candidate from
#frequent_itemsets_verified, then clear frequent_itemsets_verified.
#Uses the F(k-1) x F(k-1) join.
#Returns -1 when no new k-itemset can be produced.
def get_candidates(k):
    """F(k-1) x F(k-1) candidate generation followed by Apriori-property
    pruning.  On success (0) the verified set is cleared for the next pass."""
    frequent_itemsets_candidate.clear()

    # Enumerate every unordered pair of verified (k-1)-itemsets.
    verified = list(frequent_itemsets_verified)
    total = len(verified)
    for i in range(total):
        tp_a = verified[i]
        for j in range(i + 1, total):
            tp_b = verified[j]
            if k == 2:
                # 1-itemsets join unconditionally; keep the candidate sorted.
                if tp_a[0] > tp_b[0]:
                    cand = tp_b + tp_a
                else:
                    cand = tp_a + tp_b
                frequent_itemsets_candidate[cand] = set(cand)
            elif tp_a[:-1] == tp_b[:-1]:
                # The first k-2 items agree: merge, larger last item at the end
                # to keep the candidate tuple sorted.
                if tp_a[-1] > tp_b[-1]:
                    cand = tp_a[:-1] + tp_b[-1:] + tp_a[-1:]
                else:
                    cand = tp_a[:-1] + tp_a[-1:] + tp_b[-1:]
                frequent_itemsets_candidate[cand] = set(cand)

    if len(frequent_itemsets_candidate) == 0:
        return -1

    # Apriori pruning: every (k-1)-subset of a candidate must itself be a
    # verified frequent itemset.  Very effective -- cuts 60%+ of candidates.
    if k != 2:
        del_list = [tp for tp in frequent_itemsets_candidate
                    if any(tp[:i] + tp[i + 1:] not in frequent_itemsets_verified
                           for i in range(len(tp)))]

        print('-------------------------------------------------')
        print('........Total ' + str(len(frequent_itemsets_candidate)) + ' candidates before cut.')
        print('........Cut ' + str(len(del_list)) + ' candidates.')
        print('-------------------------------------------------')

        for tp in del_list:
            del frequent_itemsets_candidate[tp]

    if len(frequent_itemsets_candidate) == 0:
        return -1

    frequent_itemsets_verified.clear()
    return 0
 
 
#Count the support of frequent_itemsets_candidate; move the qualifying ones
#into frequent_itemsets_verified and clear frequent_itemsets_candidate.
#Returns -1 when every candidate turned out to be infrequent.
def check_candidates_1():
    """Support counting via intersection of per-item transaction-id sets."""
    print('Start check candidates.')
    total_candidates = len(frequent_itemsets_candidate)
    print('Total ' + str(total_candidates) + ' candidates need to check.')

    frequent_itemsets_verified.clear()
    hitted_transaction_set.clear()
    done = 0
    pct = 0
    for tp in frequent_itemsets_candidate:
        # progress indicator
        done = done + 1
        new_pct = done * 100 / total_candidates
        if new_pct != pct:
            pct = new_pct
            print(str(pct) + '%')

        # Intersect the transaction-id sets of all items in the candidate:
        # the result is exactly the transactions containing the whole itemset.
        tmp_set = None
        for item in tp:
            if item not in item_transaction_list_dict:
                print('Error!!!')
            elif tmp_set is None:
                tmp_set = item_transaction_list_dict[item]
            else:
                tmp_set = tmp_set & item_transaction_list_dict[item]

        support_cnt = 0
        if len(tmp_set) != 0:
            for ele in tmp_set:
                support_cnt = support_cnt + transaction_cnt_dict[ele]

        if support_cnt >= min_sup_cnt:
            frequent_itemsets_verified[tp] = support_cnt
            # Record the hit transactions; filter_data later drops the rest.
            for ele in tmp_set:
                hitted_transaction_set.add(ele)

    frequent_itemsets_candidate.clear()
    if len(frequent_itemsets_verified) == 0:
        return -1
    print('Finish check candidates.')
    return 0
    
#Append the contents of frequent_itemsets_verified into frequent_itemsets
def save_frequent_itemsets(k):
    """Merge the newly verified k-itemsets into frequent_itemsets[k], first
    dropping (k-1)-itemsets that are not closed (i.e. some frequent
    k-superset has the exact same support count)."""
    non_closed = []
    if k - 1 >= 1:
        for tp in frequent_itemsets[k - 1]:
            sup = frequent_itemsets[k - 1][tp]
            tp_set = set(tp)
            for tp_in, sup_in in frequent_itemsets_verified.items():
                # NOTE: strict equality; this condition could arguably be relaxed.
                if sup == sup_in and tp_set.issubset(set(tp_in)):
                    non_closed.append(tp)
                    break

    print('-------------------------------------------------')
    print('...Cutting unclosed frequent itemsets in k = ' + str(k - 1) + '.')
    print('........Total ' + str(len(frequent_itemsets[k - 1])) + ' itemsets before cut.')
    print('........Cut ' + str(len(non_closed)) + ' itemsets for not closed.')
    print('-------------------------------------------------')

    for tp_del in non_closed:
        del frequent_itemsets[k - 1][tp_del]

    frequent_itemsets.setdefault(k, {}).update(frequent_itemsets_verified)
    return
    
#Collect the distinct items used by the current candidates
def get_distinct_item_in_candidate_itemsets():
    """Rebuild distinct_item_in_candidate_itemsets as the union of all items
    appearing in any candidate itemset."""
    distinct_item_in_candidate_itemsets.clear()
    for tp in frequent_itemsets_candidate:
        distinct_item_in_candidate_itemsets.update(tp)
    return
    
##Filter out data that can no longer contribute, to shrink the working set
def filter_data(k):
    """Rebuild transaction_cnt_dict and item_transaction_list_dict, keeping
    only items that occur in the current candidates and transactions that
    were hit in the previous pass; transactions that become identical after
    the cut are merged and their counts summed."""
    print('Function filter_data begin.')

    # Statistics before the cut.
    print('---------------------------------------------------')
    print('...Stat data before cut data.')
    item_num = len(item_transaction_list_dict)
    print('......Total ' + str(item_num) + ' items in item_transaction_list_dict.')
    tmp_num = 0
    for tran_set in item_transaction_list_dict.values():
        tmp_num = tmp_num + len(tran_set)
    print('......The average length of transaction set for each item is ' + str(round(tmp_num / item_num)) + '.')
    print('......The number of total transactions is ' + str(len(transaction_cnt_dict)) + '.')
    print('---------------------------------------------------')

    # Rebuild the data directly.
    cut_transactions = set()  # reporting only
    cut_items = set()         # reporting only
    tranidx_itemset_dict = {}  # {tranidx : set(item)}
    for item in item_transaction_list_dict:
        if item not in distinct_item_in_candidate_itemsets:
            # The item never appears in a candidate: drop it entirely.
            cut_items.add(item)
            continue
        for tranidx in item_transaction_list_dict[item]:
            if tranidx not in hitted_transaction_set:
                # Transaction missed every frequent itemset last round.
                cut_transactions.add(tranidx)
            elif tranidx in tranidx_itemset_dict:
                tranidx_itemset_dict[tranidx].add(item)
            else:
                tranidx_itemset_dict[tranidx] = set((item,))
    print('...' + str(len(cut_items)) + ' items were cut for no appearence in candidates.')
    print('...' + str(len(cut_transactions)) + ' transactions were cut for no match in k-1 level.')

    new_itemset_cnt_dict = {}  # {tuple(items) : merged multiplicity}
    merge_cnt = 0
    lt_k_cnt = 0
    for tranidx in tranidx_itemset_dict:
        itemset = tranidx_itemset_dict[tranidx]
        if len(itemset) < k:
            # Fewer than k candidate items left: cannot support any k-itemset.
            lt_k_cnt = lt_k_cnt + 1
            continue
        tp = tuple(itemset)
        if tp in new_itemset_cnt_dict:
            merge_cnt = merge_cnt + 1
            new_itemset_cnt_dict[tp] = new_itemset_cnt_dict[tp] + transaction_cnt_dict[tranidx]
        else:
            new_itemset_cnt_dict[tp] = transaction_cnt_dict[tranidx]
    del tranidx_itemset_dict  # no longer needed
    print('...' + str(lt_k_cnt) + ' transactions were cut for item number less than k.')
    print('...' + str(merge_cnt) + ' transactions were cut for merge.')

    transaction_cnt_dict.clear()
    item_transaction_list_dict.clear()
    tranidx = 1
    for tp in new_itemset_cnt_dict:
        transaction_cnt_dict[tranidx] = new_itemset_cnt_dict[tp]
        for item in tp:
            if item in item_transaction_list_dict:
                item_transaction_list_dict[item].add(tranidx)
            else:
                item_transaction_list_dict[item] = set((tranidx,))
        tranidx = tranidx + 1
    del new_itemset_cnt_dict

    # Statistics after the cut.
    print('---------------------------------------------------')
    print('...Stat data after cut data.')
    item_num = len(item_transaction_list_dict)
    print('......Total ' + str(item_num) + ' items in item_transaction_list_dict.')
    tmp_num = 0
    for tran_set in item_transaction_list_dict.values():
        tmp_num = tmp_num + len(tran_set)
    print('......The average length of transaction set for each item is ' + str(round(tmp_num / item_num)) + '.')
    print('......The length of transaction_cnt_dict is ' + str(len(transaction_cnt_dict)) + '.')
    print('---------------------------------------------------')

    print('Function filter_data done.')
    return
 
 
def get_rules():
    """Derive association rules from frequent_itemsets and write the three
    output files: human-readable, csv, and one ranked by lift DESCENDING.

    For every frequent itemset L and every frequent proper subset S with
    support(L)/support(S) >= minconf, the rule S => (L - S) is emitted with
    its support, confidence and lift.
    """
    rule_list = []  # [(liftrate, readable rule line)]

    # Number of levels actually present.
    layer_cnt = len(frequent_itemsets)

    with open(rules_readable_file_dest, 'w') as file_rule_readable, \
         open(rules_csv_file_dest, 'w') as file_rule_csv:
        # From the top level down to level 2.
        for kk in range(layer_cnt, 1, -1):
            for itemset_kk in frequent_itemsets[kk]:
                set_itemset_kk = set(itemset_kk)
                # Support of the whole itemset L -- this is the rule support.
                sup_kk = round((frequent_itemsets[kk][itemset_kk] + 0.0) / transaction_cnt, 4)
                # Scan every lower level for proper subsets S of L.
                for kkk in range(kk - 1, 0, -1):
                    for itemset_kkk in frequent_itemsets[kkk]:
                        set_itemset_kkk = set(itemset_kkk)
                        if not set_itemset_kk.issuperset(set_itemset_kkk):
                            continue
                        # confidence = support(L) / support(S)
                        tmp_conf = round((frequent_itemsets[kk][itemset_kk] + 0.0) / frequent_itemsets[kkk][itemset_kkk], 4)
                        # Only keep rules reaching the minimum confidence.
                        if tmp_conf < minconf:
                            continue
                        # Lift = P(A and B) / (P(A) * P(B)) = confidence / support(B).
                        # Range 0..inf: < 1 means the two sides inhibit each
                        # other, > 1 means they reinforce each other.
                        tp_dest = tuple(sorted(set_itemset_kk - set_itemset_kkk))
                        tmp_length = len(tp_dest)
                        if tp_dest in frequent_itemsets[tmp_length]:
                            tmp_sup = round((frequent_itemsets[tmp_length][tp_dest] + 0.0) / transaction_cnt, 4)
                            tmp_lift = round(tmp_conf / tmp_sup, 4)
                        else:
                            # Consequent was pruned (non-closed): lift unknown.
                            tmp_lift = None
                            tmp_sup = None

                        # Readable rule line.
                        tmp_str = str(itemset_kkk) + '  >>>>  ' + str(tp_dest) + '   with support: ' + str(sup_kk) + \
                        '   confidence: ' + str(tmp_conf) + '   liftrate: ' + str(tmp_lift) + '   origin: ' + str(tmp_sup) + '\n'
                        file_rule_readable.write(tmp_str)

                        rule_list.append((tmp_lift, tmp_str))

                        # csv line
                        csv_line = '|'.join(itemset_kkk) + ',' + '|'.join(tp_dest) + ',' + str(sup_kk) + \
                        ',' + str(tmp_conf) + ',' + str(tmp_lift) + '\n'
                        file_rule_csv.write(csv_line)

    # BUG FIX: the ranked file is meant to be descending by lift, but the
    # original sorted a list of single-entry {lift: line} dicts ascending.
    # Sort (lift, line) tuples descending, pushing unknown (None) lifts last.
    rule_list.sort(key=lambda pair: (pair[0] is not None, pair[0]), reverse=True)
    with open(rules_ranked_desc_by_liftrate, 'w') as file_rule_liftrate_desc:
        for _lift, line in rule_list:
            file_rule_liftrate_desc.write(line)
    return
 
if __name__ == '__main__' :
    prepare_data()
    get_frequent_itemsets_1()
    
    # Prepare for the main loop
    #-------------------------------------------------------------------------
    # Initial number of distinct items in the data
    distinct_item_cnt_now = len(item_transaction_list_dict)
    # k starts at 1 (the frequent 1-itemsets are already computed)
    k = 1
    # Seed frequent_itemsets_verified with the frequent 1-itemsets
    for item in frequent_itemsets[1] :
        frequent_itemsets_verified[item] = frequent_itemsets[1][item]
    print ''
    print ''
    #-------------------------------------------------------------------------
    
    #start loop
    #-------------------------------------------------------------------------
    print 'Enter loop.'
    while True :
        # Report what the previous pass produced
        print 'Totally ' + str(len(frequent_itemsets[k])) + ' candidates are verified.'
        print '##########################################'
        print ''
        
        # Move on to the next itemset size
        k = k + 1
        
        print ''
        print '##########################################'
        print 'k = ' + str(k)
        
        # Generate candidate k-itemsets (as few as possible); stop when none
        if get_candidates(k) == -1 :
            print 'Get no candidates.'
            break
        print 'Totally get ' + str(len(frequent_itemsets_candidate)) + ' candidates.'
        
        # Filter the data against the candidates so transaction_cnt_dict stays
        # as small as possible
        # (originally only when the distinct-item count dropped 30%+, see below)
        get_distinct_item_in_candidate_itemsets()
        distinct_item_cnt_new = len(distinct_item_in_candidate_itemsets)
        print 'There are ' + str(distinct_item_cnt_now) + ' distinct items in transaction_cnt_dict.'
        print 'There are ' + str(distinct_item_cnt_new) + ' distinct items in candidates.'
        
        #if distinct_item_cnt_new < 0.7 * distinct_item_cnt_now :
        if True :
            # Drop entries of transaction_cnt_dict that can no longer matter
            filter_data(k)
            distinct_item_cnt_now = distinct_item_cnt_new
            
        # Count support of the candidates; stop when none is frequent
        if check_candidates_1() == -1 :
            print 'No candidates is frequent.'
            break
        
        # Persist frequent_itemsets_verified into frequent_itemsets
        save_frequent_itemsets(k)  
        #end loop
    print 'Exit loop.'
    #-------------------------------------------------------------------------
        
    # Dump all frequent itemsets to file as: k,idx,item1|item2|...,support_cnt
    file = open(frequent_itemsets_save_file , 'w' )
    for tmp_k in frequent_itemsets :
        idx = 0
        for itemset in frequent_itemsets[tmp_k] :
            idx = idx + 1
            tmp_str = str(tmp_k) + ',' + str(idx) + ',' + '|'.join(itemset) + ',' + \
            str(frequent_itemsets[tmp_k][itemset]) + '\n'
            file.write(tmp_str)
    file.close()
    print 'All the frequent itemsets were saved in ' + frequent_itemsets_save_file + '.'
    
    get_rules()


# (blog footer) Related topics: data analysis & mining, algorithm practice