【Python】计算 mIoU（平均交并比）

详细代码

import numpy as np

def compute_miou(pred, target, nclass):
    """Return the mean intersection-over-union (mIoU) of two label maps.

    Labels ``1..nclass`` are treated as classes; any other value (including
    ``0``) falls outside the histogram range and is ignored. The IoU is
    computed per class and averaged over the classes that occur in ``pred``
    or ``target``; classes absent from both do not contribute to the mean.

    Args:
        pred: Array-like of predicted labels.
        target: Array-like of ground-truth labels, compared element-wise
            with ``pred``.
        nclass: Number of classes (positive integer).

    Returns:
        The mean IoU rounded to 4 decimal places, or ``nan`` when none of
        the classes ``1..nclass`` appears in either input.

    Raises:
        ValueError: If ``nclass`` is smaller than 1.
    """
    n_bins = int(nclass)
    if n_bins < 1:
        raise ValueError("nclass must be a positive integer")

    pred = np.asarray(pred)
    target = np.asarray(target)
    mini = 1

    # Keep the predicted label where it matches the target, 0 elsewhere.
    intersection = pred * (pred == target)

    # One bin per class, each centred on its integer label, so that every
    # label in 1..nclass lands in its own bin regardless of nclass.
    hist_range = (mini - 0.5, n_bins + 0.5)
    area_inter, _ = np.histogram(intersection, bins=n_bins, range=hist_range)
    area_pred, _ = np.histogram(pred, bins=n_bins, range=hist_range)
    area_target, _ = np.histogram(target, bins=n_bins, range=hist_range)
    area_union = area_pred + area_target - area_inter

    # Sanity check: an intersection can never exceed its union.
    assert (area_inter <= area_union).all(), "Intersection area should be smaller than Union area"

    # Skip classes with an empty union to avoid dividing 0 by 0.
    present = area_union > 0
    if not present.any():
        return float("nan")

    iou_per_class = area_inter[present] / area_union[present]
    return round(float(iou_per_class.mean()), 4)

if __name__ == '__main__':
    # Demo with a single class (label 1) on a 200x200 canvas.
    nclass = 1
    canvas = (200, 200)

    # Ground truth: a 100x100 square in the top-left corner.
    target = np.zeros(shape=canvas)
    target[:100, :100] = 1

    # Prediction: the same square shifted by 10 pixels along both axes.
    pred = np.zeros(shape=canvas)
    pred[10:110, 10:110] = 1

    # Compute the score and print it.
    rate = compute_miou(pred, target, nclass)
    print(rate)

你可能感兴趣的:(pytorch,python)