# 详细代码 (full code listing)
import numpy as np
def compute_miou(pred, target, nclass):
    """Compute the mean Intersection-over-Union (mIoU) of two label maps.

    Label 0 is treated as background and ignored. The foreground classes are
    the integer labels ``1..nclass``. Any other value (e.g. an ignore label
    such as 255) falls outside every histogram bin and is not counted.
    For each class ``c``:

        IoU_c = |pred == c AND target == c| / |pred == c OR target == c|

    The result is the mean of ``IoU_c`` over the classes that occur in
    ``pred`` or ``target``. Classes absent from both are skipped, since
    their IoU is 0/0.

    Args:
        pred: Predicted label map (array-like).
        target: Ground-truth label map, same shape as ``pred``.
        nclass: Number of foreground classes (must be >= 1).

    Returns:
        The mean IoU rounded to 4 decimal places, or ``nan`` when no
        foreground class occurs in either map.

    Raises:
        ValueError: If ``nclass < 1`` or ``pred`` and ``target`` have
            different shapes.
    """
    pred = np.asarray(pred)
    target = np.asarray(target)
    if nclass < 1:
        raise ValueError(f"nclass must be >= 1, got {nclass}")
    if pred.shape != target.shape:
        raise ValueError(
            f"pred and target shapes differ: {pred.shape} vs {target.shape}"
        )

    # One unit-wide bin per class, centred on the integer labels 1..nclass.
    # Background (0) and out-of-range labels fall outside every bin.
    # A fixed bins=2 would merge distinct classes whenever nclass > 2.
    hist_range = (0.5, nclass + 0.5)

    # Where prediction and ground truth agree, the pixel keeps its label.
    # Where they disagree it becomes 0 (background) and is not counted.
    intersection = pred * (pred == target)
    area_inter, _ = np.histogram(intersection, bins=nclass, range=hist_range)
    area_pred, _ = np.histogram(pred, bins=nclass, range=hist_range)
    area_target, _ = np.histogram(target, bins=nclass, range=hist_range)
    area_union = area_pred + area_target - area_inter

    # Sanity check: the intersection can never exceed the union.
    assert (area_inter <= area_union).all(), "Intersection area should be smaller than Union area"

    # Skip classes missing from both maps instead of dividing 0 by 0.
    present = area_union > 0
    if not present.any():
        return float("nan")
    iou = area_inter[present] / area_union[present]
    return round(float(iou.mean()), 4)
if __name__ == '__main__':
    # Toy binary example: two 100x100 foreground squares on a 200x200
    # canvas. The prediction is offset by 10 pixels along both axes
    # relative to the ground truth.
    num_classes = 1
    side = 200

    gt_mask = np.zeros((side, side))
    gt_mask[:100, :100] = 1

    pred_mask = np.zeros((side, side))
    pred_mask[10:110, 10:110] = 1

    print(compute_miou(pred_mask, gt_mask, num_classes))