相信大家对 Unet
论文中的 weight map
并不陌生。但是我翻遍 github
上的几个 Unet
经典实现(包括论文开源的 caffe
版本),都没有找到损失加权图的实现代码。我今天尝试着实现了一下,得到了与论文中几乎一致的结果,分享给大家。转载请在文首注明出处,尊重原创,谢谢。
作 者: 月牙眼的楼下小黑
联 系: zhanglf_tmac (WeChat)
声 明: 欢迎转载本文中的图片或文字,请说明出处
1. 实验图片的获取和加载
我没有论文中的数据集,我只是简单地从 pdf
中截取了示例图,并用 win10
自带的画图软件另存为单通道 bmp
文件。
实验图片
我习惯用 skimage
库的 io
函数加载图片:
from skimage import io
import matplotlib.pyplot as plt
gt = io.imread('bin_img.bmp')
gt = 1 * (gt >0)
plt.figure(figsize = (10,10))
plt.imshow(gt)
plt.show()
实验图的加载
2. 获得 class weight map
根据 unet
论文, class_weight_map
的作用是: to balance the class frequencies
。 我们分别统计前景(即细胞)和 背景的像素数目 , 取倒数并归一化作为 class map
。 这只是一种简单的做法,读者亦可设计更复杂的类别平衡权重计算机制。
import numpy as np
# 【1】计算细胞和背景的像素频率
c_weights = np.zeros(2)
c_weights[0] = 1.0 / ((gt == 0).sum())
c_weights[1] = 1.0 / ((gt == 1).sum())
# 【2】归一化
c_weights /= c_weights.max()
# 【3】得到 class_weight map(cw_map)
cw_map = np.where(gt==0, c_weights[0], c_weights[1])
plt.figure(figsize = (10,10))
plt.imshow(cw_map)
plt.show()
class_weight_map
3. 连通域分析
对每个实例(即每个细胞),我们需要获取只包含该实例的掩码。例如,如果一幅图像中有 2
个细胞: 细胞 a
, 细胞 b
, 我们需要获得只包含细胞 a
的二值掩码图 msk1
, 以及只包含细胞 b
的二值掩码图 msk2
。所以我们需要对原图像进行连通域分析。
我们调用 skimage.measure
下的 label
函数实现二值图像的连通区域标记:
skimage.measure.label(image, connectivity= None)
: connectivity=1
代表4
邻接, 2
代表 8
邻接
该函数会返回跟 image
同等大小的标记矩阵,不同值的数字代表不同的连通域。
from skimage import measure,color
#【4】连通域分析,并彩色化
cells = measure.label(gt, connectivity=2)
cells_color = color.label2rgb(cells , bg_label = 0, bg_color = (0, 0, 0))
plt.figure(figsize = (20,20))
plt.imshow(cells_color)
plt.show()
连通域分析并彩色化
4. 得到 distance_weight_map
解释代码是一件非常困难的事,有必要把公式再贴一遍,请读者对照代码和公式理解:
对一幅图像,我们可以得到一个二值掩码集合 :
{ msk1, msk2, ...mski,...mskn }
每个掩码只包含一个细胞 。 我们对每个二值掩码进行 距离变化, 使用 opencv
中的 distanceTransform
函数即可轻易实现。我们得到一个距离变换后的浮点图集合:
{dist1, dist2, ..., disti,...distn }, 其中 disti = distanceTransform(mski)。
我们要求的 d1_map, 即公式中的 d1(x), 在 (i,j)
处的值为 :
d1_map(i,j) = min (dist1(i,j), dist2(i,j), ...distn(i,j))
笔者不知道求一个集合中 次最小 的数学运算符是什么,暂且用 mins 代替吧, 即对一个集合 a
, mins(a) = min (a - {min(a)})。
那么 d2_map, 即公式中的 d2(x), 在 (i,j)
处的值为 :
d2_map(i,j) = mins (dist1(i,j), dist2(i,j), ...distn(i,j))
得到 d1_map 和 d2_map, 我们不难获得对应公式中的第二项的 distance_weight_map
。
import cv2
#【5】计算得到 distance weight map (dw_map)
w0 = 10
sigma = 5
dw_map = np.zeros_like(gt)
maps = np.zeros((gt.shape[0], gt.shape[1], cells.max()))
if cells.max()>=2:
for i in range(1, cells.max() + 1):
maps[:,:,i-1] = cv2.distanceTransform(1- (cells == i ).astype(np.uint8), cv2.DIST_L2, 3)
maps = np.sort(maps, axis = 2)
d1 = maps[:,:,0]
d2 = maps[:,:,1]
dis = ((d1 + d2)**2) / (2 * sigma * sigma)
dw_map = w0*np.exp(-dis) * (cells == 0)
plt.figure(figsize = (10,10))
plt.imshow(dw_map, cmap = 'jet')
plt.show()