Pytorch-UNet医学影像分割处理-踩坑及经验

目录

  • 基本情况介绍
  • 模型结构
  • 预处理
  • 代价函数选择
  • 优化器选择
  • 训练参数调整

基本情况介绍

下图为某病人心脏切片,希望在图上分割出其四个心室。
Pytorch-UNet医学影像分割处理-踩坑及经验_第1张图片
标签如下图:
Pytorch-UNet医学影像分割处理-踩坑及经验_第2张图片
总共9个心脏,平均每个18张切片,图片大小512512,人工裁剪到了272272来聚焦中心部位。

模型结构

  1. 我是直接用了github上别人写的模型,在使用时需要留意通道数;各层输出的shape;自己图片的size需不需要padding,需不需要调整kernel size。
  2. 用3D UNet的时候建议清下不用的变量,不然很容易爆显存。
  3. 这里整体上我选择two stage的思想,先把心脏的外轮廓分割出来,再在分割出的外轮廓内分割四个心室。两个stage都用UNet实现。

预处理

  1. 关于图像的mode,一开始是用RGB形式导入,输入模型的shape为3272272。后面发现使用灰度图减少通道数(1272272)可以小幅提升dice。
  2. 关于图像像素值,灰度图的像素值范围为0到255,理论上进行zscore标准化(Global or per image)可以提升结果,但实测没什么效果(但收敛速度确实提升比较明显),而且在展示图像时会给我带来麻烦。于是就直接以uint8的形式(0~255)导入了。
  3. 关于图像处理,选择使用CLAHE直方图均衡(但是在zscore标准化下,此操作应该是无效的,可以自己随便找张图试下),肉眼上可以比较好地提高对比度。
  4. 关于数据增强,这里因为数据特别少(训练集仅152张图),使用平移和旋转两种数据增强方式,扩充训练集到3倍,验证集保持不变(8张),理论上在dataloader读取图片时以概率形式改变图片也可以达到扩充训练集的效果,但不知道是因为我训练参数没调好还是epoch太少,总之没有生效。
  5. 这里其实还用到了一个trick,由于心脏分割特征在二维不太明显,把一张切片与前四张切片stack到一起,以序列的形式导入模型(shape5272272,前面没有就补0),实际上只是修改了通道数,试图模拟3d的效果,结果可以小幅提升。

代价函数选择

  1. 分割问题常用的Loss有Dice, CrossEntropy, IOU, Focal 等。
    对于first stage(分割外轮廓,二分类问题),做下来Focal loss的表现比Binary CrossEntropy要好,最终可以达到95%的准确率。对于second stage,做下来lovasz_softmax loss的表现比CE,IOU,Dice要好,CE的话我这边都不能正常work,Dice的波动较大(但其实大概因为数据不均衡的原因,波动都挺大的),IOU的话表现不太好。

    Pytorch-UNet医学影像分割处理-踩坑及经验_第3张图片

  2. 这里评价指标函数均使用Dice,即看最终Dice怎么样来评价模型好坏,但我个人觉得可能用别的更好一点,不过为了给非专业的展示,可能这样更直观一点。

  3. 有一点比较坑的是,一开始做出来Dice只有0.2左右,但是实际图片看上去至少也能有0.5,后来发现是在计算的时候没有ignore黑色背景区域,把黑色部分也一起算上dice了,这样一来对模型的要求就很高了,因为其他颜色区域改动的话等于黑色部分也在改动,对loss的影响是翻倍的。最后仅对四种颜色分别计算dice然后取平均,视作一张图的dice。

优化器选择

  1. 优化器这一块其实我不太熟悉,因此踩了很大的坑,最开始我copy的代码里用的是SGD,然后我Learning rate一直调的挺大的,基本都是0.01~0.1量级的,然后我尝试Adam的时候也是用了这样的初始lr,导致我一度认为Adam效果不好。后来总算取了1e-4的lr之后尝试Adam,发现效果还是不行,然后才发现之前SGD在训练过程中调整学习率的句子没注释掉,导致学习率很快就趋于0了。本身Adam就是自适应学习率,手动调整的话很容易搞崩。

训练参数调整

  1. 这一块我也不太熟悉,我只能根据自己搜到的简单结论来操作,比如Adam最合适的学习率是1e-4,Batchsize在分割问题中一般越大越好等等。我实操过程中感觉Batchsize对我的模型的改进不是很大,可能还是因为数据集的问题。
  2. 实际过程中有一点比较困惑的是,明明我各种seed都固定好了,训练集验证集也是我手动划分的,但每次跑出来的结果还是不一样。我查了一番告诉我是GPU的原因,具体我也不知道,但最终收敛结果相差不会超过2%,我也就没管了。

你可能感兴趣的:(深度学习,python,深度学习,机器学习,计算机视觉)