语义分割中的mIoU计算函数解读

代码来自于CVPR2018的一篇文章Context Encoding for Semantic Segmentation。
github工程地址为:https://github.com/zhanghang1989/PyTorch-Encoding
很棒的工作。
我是新入门,对代码解读如下,有不对的地方请高手们批评指正。
我的标签图像序号是从0开始,依次编号,0代表背景类

def batch_intersection_union(predict, target, nclass):
    """Batch Intersection of Union
    Args:
        predict: input 4D tensor 具体地为B(batch大小)*C(通道)*H*W
        target: label 3D tensor 具体地为B(batch大小)*H*W
        nclass: number of categories (int)
    """
    #在通道维上取最大值,注意predict为第二个返回值,因此是索引,从0开始
    #此时predict和target一样,维度均为B*H*W,且值均为0,1,2.........
    _, predict = torch.max(predict, 1) 
    mini = 1
    maxi = nclass
    nbins = nclass
    #将predict和target放入cpu并转换为numpy数组,同时+1
    #此时predict和target的值为1,2,......
    predict = predict.cpu().numpy() + 1
    target = target.cpu().numpy() + 1

   #假如我除去背景类只有1类目标,则nbins为2。
   #此句似乎没有实际意义?predict值不变,因为target值均大于0
   # (target > 0)返回的均为true
    predict = predict * (target > 0).astype(predict.dtype)
    #求交集intersection,维度B*H*W,包含背景类0的交集
    #intersection交集处>0,非交集处为0(其内部像素值包括0,1,2)
    intersection = predict * (predict == target)
    # areas of intersection and union
    #绘制直方图,nbins个区间,range=(mini, maxi)左闭右开
    #假如我除去背景类只有1类目标,则nbins为2,range=(1,2),则表示将数组均匀地分为2个区间:[1,1.5],[1.5,2]
    #第一个bins代表背景,第二个bins代表目标
    area_inter, _ = np.histogram(intersection, bins=nbins, range=(mini, maxi))
    area_pred, _ = np.histogram(predict, bins=nbins, range=(mini, maxi))
    area_lab, _ = np.histogram(target, bins=nbins, range=(mini, maxi))
    area_union = area_pred + area_lab - area_inter
    #交集一定小于并集
    assert (area_inter <= area_union).all(), \
        "Intersection area should be smaller than Union area"
    return area_inter, area_union

关于histogram函数,做如下详细说明,参考链接:https://blog.csdn.net/pipisorry/article/details/48770785
https://blog.csdn.net/hyqsong/article/details/40514879
histogram(a,bins=10,range=None,normed=False,weights=None)
其中,a是保存待统计数据的数组,bins指定统计的区间个数,即对统计范围的等分数。 range是一个长度为2的元组,表示统计范围的最小值和最大值,默认值为None,表示范围由 数据的范围决定,即(a.min(), a.max())。
ranges 对于均匀直方图(即nuiform为true),ranges是一个由dims个float数对构成的数组,数对表示对应维的bin的范围. eg.某维有N==2个bins,在ranges中对应的数对为{0,10},均匀的意思是讲,将该维的bin均匀的分为2个区间:[0,5]和[5,10],这是程序自动划分的,只需提供给他数对表示最值范围即可。

当normed参数为False时,函数返回数组a中的数据在每个区间的个数,否则对个数进行正规化处理,使它等于每个区间的概宇密度。weights参数类似。

NumPy中histogram函数应用到一个数组返回一对变量:直方图数组hist和箱式向量,即两个一维数组–hist和bin_edges,第一个数组是每个区间(一个区间代表一个bins)的统计结果, 第二个数组长度为len(hist)+1,每两个相邻的数值构成一个统计区间。
对于bin_edges : array of dtype float,bin edges 的长度要是 hist 的长度加1,bin edges (length(hist)+1),也即 (bin_edges[0], bin_edges[1]) ⇒ hist[0],….,(bin_edges[-2], bin_edges[-1]) ⇒ hist[-1],bin_edges 参数值与输入参数的(bins+1) 保持一致

你可能感兴趣的:(语义分割,mIoU)