语义分割生成混淆矩阵

    def update(self, a, b):
        n = self.num_classes  
        if self.mat is None:
            self.mat = torch.zeros((n, n), dtype=torch.int64, device=a.device)
        with torch.no_grad():
            k = (a >= 0) & (a < n)  # 0,1
            inds = n * a[k].to(torch.int64) + b[k]
            self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n)

这是语义分割生成混淆矩阵的代码,详细流程:

1、传入参数a为 target.flatten(),b为 output.argmax(1).flatten(),b为模型预测值,找到概率最大的索引值作为最终预测值,所以会用到.argmax(1)

2、初始化混淆矩阵,大小为n×n

3、k返回的是一个布尔类型的一维向量,超过或小于类别范围返回False

4、根据k中False的索引值,将真实标签对应位置元素丢掉,即a[k]的长度可能会小于a,并且根据有意义的真实标签索引值找到对应预测值进行计算,乘以n相当于flatten的逆过程,方便reshape之后变成一个n×n矩阵

语义分割生成混淆矩阵_第1张图片

你可能感兴趣的:(pytorch,python,计算机视觉)