0. 前言
1. Apriori 算法寻找频繁项集
2. 从频繁项集中挖掘关联规则
3. 实战案例
3.1. apriori算法发现频繁项集和关联规则
学习完机器学习实战的Apriori,简单的做个笔记。文中部分描述属于个人消化后的理解,仅供参考。
所有代码和数据可以访问 我的 github
如果这篇文章对你有一点小小的帮助,请给个关注喔~我会非常开心的~
从大规模的数据集中,寻找不同特征或者物品之间的隐含关系,称为关联分析(association analysis),或者关联规则学习(association rule learning)。
给出以下例子:
交易号码 | 商品 |
---|---|
0 | 豆奶、莴苣 |
1 | 莴苣、尿布、葡萄酒、甜菜 |
2 | 豆奶、尿布、葡萄酒、橙汁 |
3 | 莴苣、豆奶、尿布、葡萄酒 |
4 | 莴苣、豆奶、尿布、橙汁 |
在关联分析中,有下列概念:
关联分析通常由两步组成,寻找数据中的频繁项集,然后从频繁项集中挖掘关联规则。
例如一个数据集中,共有 种物品,则可能组成的频繁项集,如下图所示(图源:机器学习实战):
种物品,则有
种可能的频繁项集,则需要遍历
次数据,
种物品,则有
种可能的频繁项集,则需要遍历
次数据。计算量是十分庞大的。
Apriori 算法给出的原理表示,如果某个项是频繁的,那么这个项的所有子集也是频繁的。
这条定理的逆否命题表示为,如果某个项是非频繁的,那么这个项的所有超集也是非频繁的。
如下图所示(图源:机器学习实战),如果 是非频繁的,那么
都是非频繁的:
通过 Apriori 算法,可以在计算出 的支持度的时候,不计算
,降低计算量。
Apriori 算法寻找频繁项集的流程可如下表示:
规则 ,
称为前件,
称为后件。
对于每一个频繁项,都可挖掘出许多对关联规则,如下图所示(图源:机器学习实战),频繁项 :
如果某条规则不满足最小可信度的要求,则该规则的子集都不满足最小可信度的要求,即包含该后件的规则,都不满足最小可信度要求。
例如规则 不满足最小可信度,则后件中包含
的均不满足最小可信度要求。
挖掘关联规则的流程如下表示:
以下将展示书中案例的代码段,所有代码和数据可以在github中下载:
# coding:utf-8
from numpy import *
"""
apriori算法发现频繁项集和关联规则
"""
# 加载数据集
def loadDataSet():
return [[1, 3, 4], [2, 3, 5], [1, 2, 3, 5], [2, 5]]
# 创建初始只包含单个元素的数据项集
def createC1(dataSet):
C1 = []
for transaction in dataSet:
for item in transaction:
if not [item] in C1:
C1.append([item])
C1.sort()
return list(map(frozenset, C1))
# 扫描数据集
# 计算当前数据项集中满足最小支持度的项集
def scanD(D, Ck, minSupport):
ssCnt = {}
# 遍历每一条数据
for tid in D:
# 遍历每一个项集
for can in Ck:
# 项集是数据的一个子集
if can.issubset(tid):
if can not in ssCnt:
ssCnt[can] = 1
else:
ssCnt[can] += 1
numItems = float(len(D))
retList = []
supportData = {}
# 遍历每一个项集,计算支持度
for key in ssCnt:
support = ssCnt[key] / numItems
if support >= minSupport:
retList.append(key)
supportData[key] = support
return retList, supportData
# 根据满足最小支持度的项集
# 计算项集的组合
def aprioriGen(Lk, k): # creates Ck
retList = []
lenLk = len(Lk)
for i in range(lenLk):
for j in range(i + 1, lenLk):
# 当0~k-2个项相同的时候
# 合并可以得到长度为k的项,且不会重复
L1 = list(Lk[i])[:k - 2]
L2 = list(Lk[j])[:k - 2]
L1.sort()
L2.sort()
if L1 == L2:
retList.append(Lk[i] | Lk[j])
return retList
# apriori算法,生成频繁项集
# 此处并没有不计算那些项集为非频繁的超集,依旧按照原始计算
def apriori(dataSet, minSupport=0.5):
# 长度为1的项集
C1 = createC1(dataSet)
# 数据集D
D = list(map(set, dataSet))
# 满足最小支持度的项集l1,支持度supportData
L1, supportData = scanD(D, C1, minSupport)
# 构建列表
L = [L1]
k = 2
# 当当前满足支持度的项集个数大于0时,继续计算
while (len(L[k - 2]) > 0):
# 计算当前满足支持度的项集的组合
Ck = aprioriGen(L[k - 2], k)
# 计算组合后,满足最小支持度的项集和支持度
Lk, supK = scanD(D, Ck, minSupport)
supportData.update(supK)
# 将满足最小支持度的项集添加进L
L.append(Lk)
k += 1
return L, supportData
# 规则分析
def generateRules(L, supportData, minConf=0.7):
bigRuleList = []
# 只判断项集中元素大于1的情况,因对其进行拆分
for i in range(1, len(L)):
# 当前每个项集的长度为i+1,遍历每个项集
for freqSet in L[i]:
# freqSet: frozenset({1, 3})
# H1: [frozenset({1}), frozenset({3})]
H1 = [frozenset([item]) for item in freqSet]
# 项集长度大于2
if (i > 1):
# 后件的长度为1,返回大于最小可信度的hmp1
# 运用大于最小可行度的后件,再进行组合,生成更长的后件,分级判断
#
######################################################
# 这部分与书中不同,本人认为书中有错误,缺少下面第一二行 #
# 按照书中,缺少判断 前件长度>1且后件长度=1的情况 #
# 希望可以得到指正 #
######################################################
#
Hmp1 = calcConf(freqSet, H1, supportData, bigRuleList, minConf)
if (len(Hmp1) > 1):
rulesFromConseq(freqSet, Hmp1, supportData, bigRuleList, minConf)
else:
# 项集长度只有2
calcConf(freqSet, H1, supportData, bigRuleList, minConf)
return bigRuleList
# 计算以H中的一个为前件,一个为后件的可信度
def calcConf(freqSet, H, supportData, brl, minConf=0.7):
prunedH = []
for conseq in H:
# 计算可信度
conf = supportData[freqSet] / supportData[freqSet - conseq]
if conf >= minConf:
print(freqSet - conseq, '-->', conseq, 'conf:', conf)
brl.append((freqSet - conseq, conseq, conf))
prunedH.append(conseq)
return prunedH
# 分级计算长度大于2的项集的规则
def rulesFromConseq(freqSet, H, supportData, brl, minConf=0.7):
# 单个频繁项集的长度
m = len(H[0])
if (len(freqSet) > (m + 1)):
# 原先长度为m,生成长度为m+1的
# H: [frozenset({2}), frozenset({3}), frozenset({5})]
# Hmp1: [frozenset({2, 3}), frozenset({2, 5}), frozenset({3, 5})]
Hmp1 = aprioriGen(H, m + 1)
# 将Hmp1中的每一个作为后件,计算可信度
# 返回大于最小可信度的后件,用作下一次调用此函数时,组合成新的后件
Hmp1 = calcConf(freqSet, Hmp1, supportData, brl, minConf)
if (len(Hmp1) > 1):
# 还可进一步扩大后件的长度
rulesFromConseq(freqSet, Hmp1, supportData, brl, minConf)
if __name__ == '__main__':
# dataSet = loadDataSet()
# L, suppData = apriori(dataSet)
# rules = generateRules(L, suppData)
mushDatSet = [line.split() for line in open('mushroom.dat').readlines()]
L, suppData = apriori(mushDatSet, minSupport=0.3)
for item in L[1] + L[2]:
if item.intersection('2'):
print(item)
如果这篇文章对你有一点小小的帮助,请给个关注喔~我会非常开心的~