python 复现 Unet 论文中的 Weight Map

 

相信大家对 Unet 论文中的 weight map 并不陌生。但是我翻遍 github上的几个 Unet 经典实现(包括论文开源的 caffe 版本),都没有找到损失加权图的实现代码。我今天尝试着实现了一下,得到了与论文中几乎一致的结果,分享给大家。转载请在文首注明出处,尊重原创,谢谢

 

作 者: 月牙眼的楼下小黑
联 系: zhanglf_tmac (WeChat)
声 明: 欢迎转载本文中的图片或文字,请说明出处


1. 实验图片的获取和加载

我没有论文中的数据集,我只是简单地从 pdf 中截取了示例图,并用 win10 自带的画图软件另存为单通道 bmp 文件。

实验图片

 

我习惯用 skimage 库的 io 函数加载图片:

from skimage import io
import matplotlib.pyplot as plt

gt = io.imread('bin_img.bmp')
gt = 1 * (gt >0)
plt.figure(figsize = (10,10))
plt.imshow(gt)
plt.show()

实验图的加载

2. 获得 class weight map

根据 unet 论文, class_weight_map 的作用是: to balance the class frequencies。 我们分别统计前景(即细胞)和 背景的像素数目 , 取倒数并归一化作为 class map。 这只是一种简单的做法,读者亦可设计更复杂的类别平衡权重计算机制。

import numpy as np

# 【1】计算细胞和背景的像素频率
c_weights = np.zeros(2)
c_weights[0] = 1.0 / ((gt == 0).sum())
c_weights[1] = 1.0 / ((gt == 1).sum())

# 【2】归一化
c_weights /= c_weights.max()

# 【3】得到 class_weight map(cw_map)
cw_map = np.where(gt==0, c_weights[0], c_weights[1])
plt.figure(figsize = (10,10))
plt.imshow(cw_map)
plt.show()

class_weight_map

3. 连通域分析

对每个实例(即每个细胞),我们需要获取只包含该实例的掩码。例如,如果一幅图像中有 2个细胞: 细胞 a, 细胞 b, 我们需要获得只包含细胞 a 的二值掩码图 msk1, 以及只包含细胞 b 的二值掩码图 msk2。所以我们需要对原图像进行连通域分析。

我们调用 skimage.measure下的 label函数实现二值图像的连通区域标记:

skimage.measure.label(image, connectivity= None) : connectivity=1代表4邻接, 2代表 8邻接

该函数会返回跟 image同等大小的标记矩阵,不同值的数字代表不同的连通域。

from skimage import measure,color

#【4】连通域分析,并彩色化
cells = measure.label(gt, connectivity=2)
cells_color = color.label2rgb(cells , bg_label = 0,  bg_color = (0, 0, 0)) 
plt.figure(figsize = (20,20))
plt.imshow(cells_color)
plt.show()

连通域分析并彩色化

4. 得到 distance_weight_map

解释代码是一件非常困难的事,有必要把公式再贴一遍,请读者对照代码和公式理解:

 

对一幅图像,我们可以得到一个二值掩码集合 :

{ msk1, msk2, ...mski,...mskn }

每个掩码只包含一个细胞 。 我们对每个二值掩码进行 距离变化, 使用 opencv 中的 distanceTransform 函数即可轻易实现。我们得到一个距离变换后的浮点图集合:

{dist1, dist2, ..., disti,...distn }, 其中 disti = distanceTransform(mski)

我们要求的 d1_map, 即公式中的 d1(x), 在 (i,j) 处的值为 :

d1_map(i,j) = min (dist1(i,j), dist2(i,j), ...distn(i,j))

笔者不知道求一个集合中 次最小 的数学运算符是什么,暂且用 mins 代替吧, 即对一个集合 amins(a) = min (a - {min(a)})

那么 d2_map, 即公式中的 d2(x), 在 (i,j) 处的值为 :

d2_map(i,j) = mins (dist1(i,j), dist2(i,j), ...distn(i,j))

得到 d1_map 和 d2_map, 我们不难获得对应公式中的第二项的 distance_weight_map

import cv2

#【5】计算得到 distance weight map (dw_map)
w0 = 10
sigma = 5
dw_map = np.zeros_like(gt)
maps = np.zeros((gt.shape[0], gt.shape[1], cells.max()))
if cells.max()>=2:
    for i in range(1, cells.max() + 1):
        maps[:,:,i-1] =  cv2.distanceTransform(1- (cells == i ).astype(np.uint8), cv2.DIST_L2, 3)
    maps = np.sort(maps, axis = 2)
    d1 = maps[:,:,0]
    d2 = maps[:,:,1]
    dis = ((d1 + d2)**2) / (2 * sigma * sigma)
    dw_map = w0*np.exp(-dis) * (cells == 0)
plt.figure(figsize = (10,10))
plt.imshow(dw_map, cmap = 'jet')
plt.show()

你可能感兴趣的:(图像特征,机器学习)