FCN源码解读之score.py

转载自 https://blog.csdn.net/qq_21368481/article/details/80424754

score.py是FCN中用于测试测试集/验证集的,并输出相应的像素准确度、平均准确度、mean IU和频率加权交并比(frequency weighted IU)四个指标的python文件。

score.py的源码如下:


   
     
     
     
     
  1. from __future__ import division
  2. import caffe
  3. import numpy as np
  4. import os
  5. import sys
  6. from datetime import datetime
  7. from PIL import Image
  8. def fast_hist(a, b, n):
  9. k = (a >= 0) & (a < n)
  10. return np.bincount(n * a[k].astype(int) + b[k], minlength=n** 2).reshape(n, n)
  11. def compute_hist(net, save_dir, dataset, layer='score', gt='label'):
  12. n_cl = net.blobs[layer].channels
  13. if save_dir:
  14. os.mkdir(save_dir)
  15. hist = np.zeros((n_cl, n_cl))
  16. loss = 0
  17. for idx in dataset:
  18. net.forward()
  19. hist += fast_hist(net.blobs[gt].data[ 0, 0].flatten(),
  20. net.blobs[layer].data[ 0].argmax( 0).flatten(),
  21. n_cl)
  22. if save_dir:
  23. im = Image.fromarray(net.blobs[layer].data[ 0].argmax( 0).astype(np.uint8), mode= 'P')
  24. im.save(os.path.join(save_dir, idx + '.png'))
  25. # compute the loss as well
  26. loss += net.blobs[ 'loss'].data.flat[ 0]
  27. return hist, loss / len(dataset)
  28. def seg_tests(solver, save_format, dataset, layer='score', gt='label'):
  29. print '>>>', datetime.now(), 'Begin seg tests'
  30. solver.test_nets[ 0].share_with(solver.net)
  31. do_seg_tests(solver.test_nets[ 0], solver. iter, save_format, dataset, layer, gt)
  32. def do_seg_tests(net, iter, save_format, dataset, layer='score', gt='label'):
  33. n_cl = net.blobs[layer].channels
  34. if save_format:
  35. save_format = save_format.format(iter)
  36. hist, loss = compute_hist(net, save_format, dataset, layer, gt)
  37. # mean loss
  38. print '>>>', datetime.now(), 'Iteration', iter, 'loss', loss
  39. # overall accuracy
  40. acc = np.diag(hist).sum() / hist.sum()
  41. print '>>>', datetime.now(), 'Iteration', iter, 'overall accuracy', acc
  42. # per-class accuracy
  43. acc = np.diag(hist) / hist.sum( 1)
  44. print '>>>', datetime.now(), 'Iteration', iter, 'mean accuracy', np.nanmean(acc)
  45. # per-class IU
  46. iu = np.diag(hist) / (hist.sum( 1) + hist.sum( 0) - np.diag(hist))
  47. print '>>>', datetime.now(), 'Iteration', iter, 'mean IU', np.nanmean(iu)
  48. freq = hist.sum( 1) / hist.sum()
  49. print '>>>', datetime.now(), 'Iteration', iter, 'fwavacc', \
  50. (freq[freq > 0] * iu[freq > 0]).sum()
  51. return hist

详细解读如下:

(1)fast_hist()函数


   
     
     
     
     
  1. '''
  2. 产生n×n的分类统计表
  3. 参数a:标签图(转换为一行输入),即真实的标签
  4. 参数b:score层输出的预测图(转换为一行输入),即预测的标签
  5. 参数n:类别数
  6. '''
  7. def fast_hist(a, b, n):
  8. #k为掩膜(去除了255这些点(即标签图中的白色的轮廓),其中的a>=0是为了防止bincount()函数出错)
  9. k = (a >= 0) & (a < n)
  10. #bincount()函数用于统计数组内每个非负整数的个数
  11. #详见https://docs.scipy.org/doc/numpy/reference/generated/numpy.bincount.html
  12. return np.bincount(n * a[k].astype(int) + b[k], minlength=n** 2).reshape(n, n)

此函数用于产生n*n的分类统计表,还不理解的可以看如下分析:

假如输入的标签图a是3*3的,如下左图,图中的数字表示该像素点的归属,即每个像素点所属的类别(其中n=3,即共有三种类别);预测标签图b的大小和a相同,如右图所示(图中的数字也代表每个像素点的类别归属)。

             FCN源码解读之score.py_第1张图片      FCN源码解读之score.py_第2张图片

直观上看,b中预测的标签有两个像素点预测出错,即

b01,b20b01,b20

表示属于第i类的所有像素数目

对应do_seg_tests()函数中的源码,相信大家肯定能更好理解和掌握这四个指标的计算技巧。

其中还有一点,就是交并比IU,为所有真实属于第i类的像素点所组成的集合A与所有预测属于第i类的像素点所组成的集合B的交集和并集之比,如下图

FCN源码解读之score.py_第3张图片

你可能感兴趣的:(Deep,Learning,FCN,caffe)