This post uses a 3D U-Net (UNet3D) to segment liver tumors in abdominal CT scans from the LiTS dataset.
The training set contains 130 cases, all in NIfTI (.nii) format. The raw CT volumes are named volume-*.nii and the segmentation ground truth segmentation-*.nii, where label 0 is background, 1 is liver, and 2 is tumor. Note that not every case actually contains a tumor.
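Because some cases contain no tumor voxels at all, it can be worth checking each ground-truth volume up front. A minimal sketch (using a synthetic array in place of a segmentation-*.nii volume loaded via SimpleITK; the helper name is my own):

```python
import numpy as np

def case_has_tumor(seg_array):
    """Return True if the segmentation volume contains any tumor voxels (label 2)."""
    return bool(np.any(seg_array == 2))

# Synthetic stand-in for a loaded segmentation volume (0=background, 1=liver, 2=tumor)
seg = np.zeros((4, 8, 8), dtype=np.int8)
seg[1, 2:5, 2:5] = 1              # a liver region only
print(case_has_tumor(seg))        # False: liver but no tumor
seg[2, 3, 3] = 2                  # add one tumor voxel
print(case_has_tumor(seg))        # True
```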
I originally planned to segment this dataset with the Keras U-Net implementation from https://blog.csdn.net/py184473894/article/details/88558886, but I found that the tumor Dice score was computed incorrectly during training, so no usable tumor segmentation could be trained; I am not sure whether the problem is in my code or a bug in Keras. I opened an issue on the Keras GitHub repository, but nobody has replied yet. If anyone can solve this, please contact me or reply under the issue.
To build the training set, I used the following script from the MICCAI-LITS2017 repository, which finds the slices containing liver or tumor and then takes n extra slices above and below them:
https://github.com/assassint2017/MICCAI-LITS2017/blob/master/data_prepare/get_fix_data.py
```python
def fix_data(self):
    upper = 200            # upper HU clipping threshold
    lower = -200           # lower HU clipping threshold
    expand_slice = 20      # number of slices to expand outward along the axial direction
    size = 48              # number of slices per sample
    stride = 3             # sampling stride (unused in this excerpt)
    down_scale = 0.5       # (unused in this excerpt)
    slice_thickness = 2    # (unused in this excerpt)
    for ct_file in os.listdir(self.row_root_path + 'data/'):
        print(ct_file)
        # Read the CT volume and the ground truth into memory
        ct = sitk.ReadImage(os.path.join(self.row_root_path + 'data/', ct_file), sitk.sitkInt16)
        ct_array = sitk.GetArrayFromImage(ct)
        seg = sitk.ReadImage(os.path.join(self.row_root_path + 'label/', ct_file.replace('volume', 'segmentation')),
                             sitk.sitkInt8)
        seg_array = sitk.GetArrayFromImage(seg)
        print(ct_array.shape, seg_array.shape)

        # Merge the liver and tumor labels into a single foreground label
        seg_array[seg_array > 0] = 1

        # Clip grey values outside the threshold window
        ct_array[ct_array > upper] = upper
        ct_array[ct_array < lower] = lower

        # Find the first and last slice of the liver region, then expand outward in both directions
        z = np.any(seg_array, axis=(1, 2))
        start_slice, end_slice = np.where(z)[0][[0, -1]]
        if start_slice - expand_slice < 0:
            start_slice = 0
        else:
            start_slice -= expand_slice
        if end_slice + expand_slice >= seg_array.shape[0]:
            end_slice = seg_array.shape[0] - 1
        else:
            end_slice += expand_slice
        print(str(start_slice) + '--' + str(end_slice))

        # If fewer than `size` slices remain, discard the case; this happens rarely
        if end_slice - start_slice + 1 < size:
            print('!!!!!!!!!!!!!!!!')
            print(ct_file, 'too little slice')
            print('!!!!!!!!!!!!!!!!')
            continue

        ct_array = ct_array[start_slice:end_slice + 1, :, :]
        # Re-read the segmentation to restore the original labels (they were merged above)
        seg_array = sitk.GetArrayFromImage(seg)
        seg_array = seg_array[start_slice:end_slice + 1, :, :]

        new_ct = sitk.GetImageFromArray(ct_array)
        new_seg = sitk.GetImageFromArray(seg_array)
        sitk.WriteImage(new_ct, os.path.join(self.data_root_path + 'data/', ct_file))
        sitk.WriteImage(new_seg,
                        os.path.join(self.data_root_path + 'label/', ct_file.replace('volume', 'segmentation')))
```
This is essentially unchanged from the original source code.
First, the 130 cases are randomly split into a training set (0.8), a validation set (0.1), and a test set (0.1). Data loading then proceeds as follows:
1. Read the volume and segmentation.
2. Scale the images down to reduce resolution.
3. Randomly crop n 3D blocks of size (depth, height, width) from each case to form one input batch.
4. Normalize the data to the range 0-1.
5. Wrap the reading function in a Dataset/DataLoader.
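Steps 2-4 above can be sketched as follows. This is a minimal NumPy version modeled on the pipeline described, not the repository's exact implementation; the downscaling here is crude strided subsampling, whereas real code would interpolate:

```python
import numpy as np

def downscale_3d(img):
    """Step 2 (crude): halve the in-plane resolution by strided subsampling.
    A real pipeline would use interpolation (e.g. scipy.ndimage.zoom)."""
    return img[:, ::2, ::2]

def random_crop_3d(img, label, crop_size):
    """Step 3: crop the same random (depth, height, width) block from both volumes."""
    d, h, w = crop_size
    zs = np.random.randint(0, img.shape[0] - d + 1)
    ys = np.random.randint(0, img.shape[1] - h + 1)
    xs = np.random.randint(0, img.shape[2] - w + 1)
    return (img[zs:zs + d, ys:ys + h, xs:xs + w],
            label[zs:zs + d, ys:ys + h, xs:xs + w])

def normalize(img, lower=-200.0, upper=200.0):
    """Step 4: map the clipped HU window [lower, upper] to [0, 1]."""
    return (np.clip(img, lower, upper) - lower) / (upper - lower)

# Example on synthetic data standing in for a preprocessed CT volume and its label
img = np.random.randint(-400, 400, (60, 128, 128)).astype(np.float32)
label = np.random.randint(0, 3, (60, 128, 128))
img, label = downscale_3d(img), downscale_3d(label)        # both become (60, 64, 64)
sub_img, sub_label = random_crop_3d(img, label, (48, 48, 48))
sub_img = normalize(sub_img)
print(sub_img.shape)  # (48, 48, 48)
```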
The main function used at training time is:
```python
def next_train_batch_3d_sub_by_index(self, train_batch_size, crop_size, index, resize_scale=1):
    train_imgs = np.zeros([train_batch_size, crop_size[0], crop_size[1], crop_size[2], 1])
    train_labels = np.zeros([train_batch_size, crop_size[0], crop_size[1], crop_size[2], self.n_labels])
    img, label = self.get_np_data_3d(self.train_name_list[index], resize_scale=resize_scale)
    for i in range(train_batch_size):
        # Crop a random 3D block and one-hot encode its label
        sub_img, sub_label = util.random_crop_3d(img, label, crop_size)
        sub_img = sub_img[:, :, :, np.newaxis]
        sub_label_onehot = make_one_hot_3d(sub_label, self.n_labels)
        train_imgs[i] = sub_img
        train_labels[i] = sub_label_onehot
    return train_imgs, train_labels
```
The validation set is handled similarly.
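`make_one_hot_3d` is not shown above; here is a minimal sketch of what it presumably does (the name and signature come from the call site, the body is my assumption):

```python
import numpy as np

def make_one_hot_3d(label, n_labels):
    """Convert a (D, H, W) integer label volume to a (D, H, W, n_labels) one-hot volume."""
    onehot = np.zeros(label.shape + (n_labels,), dtype=np.float32)
    for c in range(n_labels):
        onehot[..., c] = (label == c)
    return onehot

label = np.array([[[0, 1], [2, 1]]])          # shape (1, 2, 2)
onehot = make_one_hot_3d(label, 3)            # shape (1, 2, 2, 3)
print(onehot[0, 0, 1])                        # [0. 1. 0.] -- this voxel is labelled 1
```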
There is not much to say about the network itself. It is built from a few standard modules (ResBlock, SEBlock, RecombinationBlock, DenseBlock, etc.), and the upsampling can be chosen as either linear interpolation or deconvolution.
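As an illustration of one of these modules, a minimal 3D squeeze-and-excitation block could look like the following. This is a sketch of the general SE pattern, not the repository's actual SEBlock:

```python
import torch
import torch.nn as nn

class SEBlock3d(nn.Module):
    """Squeeze-and-excitation for 3D feature maps: global-average-pool each channel,
    pass the channel vector through a small bottleneck MLP, and rescale the channels."""
    def __init__(self, channels, reduction=4):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool3d(1)   # squeeze: (N, C, D, H, W) -> (N, C, 1, 1, 1)
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels),
            nn.Sigmoid(),                     # per-channel gates in (0, 1)
        )

    def forward(self, x):
        n, c = x.shape[:2]
        w = self.fc(self.pool(x).view(n, c)).view(n, c, 1, 1, 1)
        return x * w                          # excite: reweight channels

x = torch.randn(2, 16, 8, 8, 8)
out = SEBlock3d(16)(x)
print(out.shape)  # torch.Size([2, 16, 8, 8, 8])
```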
```python
class UNet(nn.Module):
    def __init__(self, in_channels, filter_num_list, class_num, conv_block=RecombinationBlock, net_mode='2d'):
        super(UNet, self).__init__()
        if net_mode == '2d':
            conv = nn.Conv2d
        elif net_mode == '3d':
            conv = nn.Conv3d
        else:
            conv = None
        self.inc = conv(in_channels, 16, 1)
        # down
        self.down1 = Down(16, filter_num_list[0], conv_block=conv_block, net_mode=net_mode)
        self.down2 = Down(filter_num_list[0], filter_num_list[1], conv_block=conv_block, net_mode=net_mode)
        self.down3 = Down(filter_num_list[1], filter_num_list[2], conv_block=conv_block, net_mode=net_mode)
        self.down4 = Down(filter_num_list[2], filter_num_list[3], conv_block=conv_block, net_mode=net_mode)
        self.bridge = conv_block(filter_num_list[3], filter_num_list[4], net_mode=net_mode)
        # up
        self.up1 = Up(filter_num_list[4], filter_num_list[3], filter_num_list[3], conv_block=conv_block,
                      net_mode=net_mode)
        self.up2 = Up(filter_num_list[3], filter_num_list[2], filter_num_list[2], conv_block=conv_block,
                      net_mode=net_mode)
        self.up3 = Up(filter_num_list[2], filter_num_list[1], filter_num_list[1], conv_block=conv_block,
                      net_mode=net_mode)
        self.up4 = Up(filter_num_list[1], filter_num_list[0], filter_num_list[0], conv_block=conv_block,
                      net_mode=net_mode)
        self.class_conv = conv(filter_num_list[0], class_num, 1)

    def forward(self, input):
        x = input
        x = self.inc(x)
        # encoder: each Down returns the pre-pooling features (for the skip connection) and the pooled output
        conv1, x = self.down1(x)
        conv2, x = self.down2(x)
        conv3, x = self.down3(x)
        conv4, x = self.down4(x)
        x = self.bridge(x)
        # decoder: each Up upsamples and merges the matching skip connection
        x = self.up1(x, conv4)
        x = self.up2(x, conv3)
        x = self.up3(x, conv2)
        x = self.up4(x, conv1)
        x = self.class_conv(x)
        x = nn.Softmax(1)(x)
        return x
```
With Keras, logging losses and metrics was very convenient, but with PyTorch I have to write it myself. I wanted to use TensorBoard for this, so I adopted the logger from https://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/04-utils/tensorboard/logger.py, which lets you record whichever losses and metrics you want:
```python
logger.scalar_summary('val_loss', val_loss, epoch)
logger.scalar_summary('val_dice0', val_dice0, epoch)  # background Dice
logger.scalar_summary('val_dice1', val_dice1, epoch)  # liver Dice
logger.scalar_summary('val_dice2', val_dice2, epoch)  # tumor Dice
```
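The three Dice values above are per-class scores (background, liver, tumor). A minimal sketch of how such a per-class Dice can be computed on integer label volumes (my assumption about the metric, not the repository's exact code):

```python
import numpy as np

def dice_per_class(pred, target, class_id, eps=1e-5):
    """Dice = 2*|P & T| / (|P| + |T|) for one label; eps avoids division by zero."""
    p = (pred == class_id)
    t = (target == class_id)
    return (2.0 * np.logical_and(p, t).sum() + eps) / (p.sum() + t.sum() + eps)

pred   = np.array([[0, 1, 1], [2, 2, 0]])
target = np.array([[0, 1, 0], [2, 1, 0]])
print(dice_per_class(pred, target, 1))  # liver: 2*1/(2+2) = 0.5
print(dice_per_class(pred, target, 2))  # tumor: 2*1/(2+1) ~ 0.667
```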
The full code can be found on my GitHub and is still being improved. Because tumor voxels are scarce, the tumor segmentation results are not yet good, and I am still working on that:
https://github.com/panxiaobai/lits_pytorch