匈牙利算法指派问题的python实现 & 使用python计算聚类精度

匈牙利算法的python实现

了解匈牙利算法的内容和其对偶问题的理解:匈牙利算法与对偶观点

简单描述匈牙利算法

具体描述见匈牙利算法与对偶观点
一个原始的指派问题:有n个工人,和n个需要作业的地点。需要为每个工人安排一个工作的地点,记变量 x i j = 0   o r   1 x_{ij}=0\ or\ 1 xij=0 or 1,代表派第i个工人去第j个地点的情况(0代表不指派,1代表指派)。同时将第i个工人派去第j个地点存在一定的开销,记录为 c i j c_{ij} cij。公司要求最小化开销。
c i j c_{ij} cij写成矩阵 C C C

  1. C C C的每一行减去该行的最小值。
  2. 第一步做差得到的矩阵再减去每列的最小值,得到矩阵 R R R
  3. 找到横线或者竖线用来覆盖 R R R中的0元素,并且使得用的线的数量最少
  4. 如果使用的线条数量刚好是n条(例子中n=4),那么已经求解得到一个最优解。只要在 R R R中选n个0元素,并且这n个元素互相不同行不同列,再使这n个元素对应 x i j x_{ij} xij为1,其余 x i j x_ij xij为0。
  5. 若线条的数量小于n条,那么就选择线条没覆盖的数字中的最小值 δ \delta δ,使没被覆盖的数字所在行中的 R R R的元素全部减去 δ \delta δ,再使被线条覆盖的列元素全部加上 δ \delta δ。然后返回第三步。

Python实现匈牙利算法

  • python实现中首先实现第1,2步骤。第3步到第5步进入迭代过程,每次迭代都需要找到最小的划线数量用于覆盖0元素,直到找到满足条件的n条线可以覆盖所有0元素。
  • 采用贪婪算法寻找最小的划线数量,统计每行每列的0元素个数,优先划去含0元素最多的行和列,并更新其余行和列剩余0元素的个数。迭代直到所有0元素被划去。
  • 找到n条划线方法之后需要从表上挑出n个0元素构成最终的指派解。通过深度递归搜索实现,第i层递归搜索在矩阵的第i行选择一个0元素。
  • 具体见代码注释。
import numpy as np
import copy
def  hungarian(w):
    """"
    传入矩阵w,w_ij为将第i人指派到第j个工作的代价
    返回使得代价最小的分配方法
    """"
    
    assert w.shape[0]==w.shape[1]
    m = w.shape[0]
    w = w - w.min(axis=1).reshape(-1,1)#每行减去最小值
    w = w - w.min(axis=0).reshape(1,-1)#每列减去最小值
    while True:
        picked_axis0 = []#记录划横线的位置
        picked_axis1 = []#记录划竖线的位置
        zerocnt = np.concatenate([(w==0).sum(axis=1),(w==0).sum(axis=0)],axis=0)
        #记录每行每列的0元素个数
        
        while zerocnt.max()>0:#如果所有0都被划去,终止循环
            
            maxindex = zerocnt.argmax()#找出拥有0最多的行或者列
            if maxindex<m:#0最多的是某一行
                picked_axis0.append(maxindex)#记录该行
                zerocnt[np.argwhere(w[maxindex,:]==0).squeeze(1)+m]= \
                   np.maximum(zerocnt[np.argwhere(w[maxindex,:]==0).squeeze(1)+m]-1,0)
                #更新其余列剩余0元素的个数
                zerocnt[maxindex]=0#该行被划去不剩余0元素
            else:#0最多的是某一列
                picked_axis1.append(maxindex-m)#记录该列
                zerocnt[np.argwhere(w[:,maxindex-m]==0).squeeze(1)]= \
                    np.maximum(zerocnt[np.argwhere(w[:,maxindex-m]==0).squeeze(1)]-1,0)
                #更新其余行剩余0元素的个数
                zerocnt[maxindex]=0#该列被划去不剩余0元素
        if len(picked_axis0)+len(picked_axis1)<m:#如果划线数量不足,更新w矩阵,进入下一步循环
            left_axis0 = list(set(list(range(m)))-set(list(picked_axis0)))
            left_axis1 = list(set(list(range(m)))-set(list(picked_axis1)))
            delta = w[left_axis0,:][:,left_axis1].min()
            w[left_axis0,:]-=delta
            w[:,picked_axis1]+=delta
        else:#划线数量满足条件,跳出循环
            break
    #找出合适的0元素的位置
    pos = []#按行记录每行哪些列是0元素
    for i in range(m):
        pos.append(list(np.argwhere(w[i,:]==0).squeeze(1)))
   
    #深度递归搜索函数,找到一个0的分配方案,让每行每列仅选出一个0
    def search(layer,path):
        if len(path) == m:
            return path
        else:
            for i in pos[layer]:
                if i not in path:
                    newpath = copy.deepcopy(path)
                    newpath.append(i)
                    ans = search(layer+1,newpath)
                    if ans is not None:
                        return ans
            return None
    #调用深度递归搜索  
    path = search(0,[])
    
    return list(zip(range(m),path))

调用举例

w = np.array(
    [[12,24,5],
    [23,21,15],
    [17,19,13]]
)
print(hungarian(w))

输出结果是:

[(0, 2), (1, 1), (2, 0)]

使用匈牙利算法计算聚类精度

有n个数据属于C个类别,某聚类算法将数据分为C类,求解该聚类算法的聚类精度

聚类精度=聚类正确的样本数量/总样本数量

例如12个样本的真实类别是:
truth = [0,0,0,1,1,1,2,2,2,3,3,3]
某聚类算法给出的聚类结果是:
pred=[1,1,1,1,3,3,2,2,2,2,0,0]

为了计算聚类精度,不能直接将pred和truth进行比较,必须先匹配聚类的类别和样本的类别,当聚类的第1类对应真实的第0类,聚类的第3类匹配真实的第1类…,可以达到二者的最大匹配,此时计算到聚类精度为10/12=0.833

对上述匈牙利算法稍作修改,即可使用python计算聚类精度:

truth = np.array([0,0,0,1,1,1,2,2,2,3,3,3])
pred=np.array([1,1,1,1,3,3,2,2,2,2,0,0])
print(hungarian_cluster_acc(truth,pred))

输出是

0.8333333333333334

计算聚类精度代码如下:

def  hungarian_cluster_acc(x,y):
    assert x.shape==y.shape
    assert x.min()==0
    assert y.min()==0
    
    m = 1 + max(x.max(),y.max())
    n = len(x)
    total = np.zeros([m,m])
    for i in range(n):
        total[x[i],int(y[i])]+=1
    w = total.max()-total
    w = w - w.min(axis=1).reshape(-1,1)
    w = w - w.min(axis=0).reshape(1,-1)
    while True:
        picked_axis0 = []
        picked_axis1 = []
        zerocnt = np.concatenate([(w==0).sum(axis=1),(w==0).sum(axis=0)],axis=0)
        
        while zerocnt.max()>0:
            
            maxindex = zerocnt.argmax()
            if maxindex<m:
                picked_axis0.append(maxindex)
                zerocnt[np.argwhere(w[maxindex,:]==0).squeeze(1)+m]= \
                   np.maximum(zerocnt[np.argwhere(w[maxindex,:]==0).squeeze(1)+m]-1,0)
                zerocnt[maxindex]=0
            else:
                picked_axis1.append(maxindex-m)
                zerocnt[np.argwhere(w[:,maxindex-m]==0).squeeze(1)]= \
                    np.maximum(zerocnt[np.argwhere(w[:,maxindex-m]==0).squeeze(1)]-1,0)
                zerocnt[maxindex]=0
        if len(picked_axis0)+len(picked_axis1)<m:
            left_axis0 = list(set(list(range(m)))-set(list(picked_axis0)))
            left_axis1 = list(set(list(range(m)))-set(list(picked_axis1)))
            delta = w[left_axis0,:][:,left_axis1].min()
            w[left_axis0,:]-=delta
            w[:,picked_axis1]+=delta
        else:
            break
    pos = []
    for i in range(m):
        pos.append(list(np.argwhere(w[i,:]==0).squeeze(1)))
    
    def search(layer,path):
        if len(path) == m:
            return path
        else:
            for i in pos[layer]:
                if i not in path:
                    newpath = copy.deepcopy(path)
                    newpath.append(i)
                    ans = search(layer+1,newpath)
                    if ans is not None:
                        return ans
            return None
        
    path = search(0,[])
    totalcorrect = 0
    for i,j in enumerate(path):
        totalcorrect += total[i,j]
    return totalcorrect/n
        

你可能感兴趣的:(算法,python,聚类)