U-NET安装与训练问题实录

参考文章:pytorch复现U-Net 及常见问题汇总(2021.11.14亲测可行)_奶盖芒果的博客-CSDN博客_pytorch 无法复现

1.安装过程

1.1 github代码

代码网址:https://github.com/milesial/Pytorch-UNet

1.2 环境配置

requirements:
matplotlib
numpy
Pillow
torch
torchvision
tqdm
wandb

requirements中未包含具体版本信息,笔者亲测torch11.3 + python3.6可用

1.3 网络数据集验证模型可用

1.3.1 数据集地址

kaggle:Carvana Image Masking Challenge | Kaggle

U-NET安装与训练问题实录_第1张图片

数据集内容,这里我们下载tran.zip及train_masks.zip文件即可

下载完毕后,数据样式如下:

U-NET安装与训练问题实录_第2张图片

 U-NET安装与训练问题实录_第3张图片

 注意:蒙版图片格式为.gif文件,如果为jpg、png蒙版训练会出错,笔者会在错误总结中具体介绍

1.3.2 网络数据集训练

打开train.py文件,修改数据集路径

U-NET安装与训练问题实录_第4张图片

如果采用conda环境,在命令行执行命令如下:

conda activate u-net # u-net执行环境名
cd E:\u-net\Pytorch-UNet-master # u-net网络文件夹
python train.py

一切顺利的话,可以看到如下界面:

U-NET安装与训练问题实录_第5张图片

 1.3.3 网络数据集预测

将程序目录中checkpoint文件夹下需要模型文件复制到predict.py文件所在目录

(1)修改predict.py中模型文件名称

U-NET安装与训练问题实录_第6张图片

或者执行命令时添加-m,后接权重文件名称即可

例如:python predict.py -m checkpoint_epoch5.pth

(2)如果采用conda环境,在命令行执行命令如下:

conda activate u-net # u-net执行环境名
cd E:\u-net\Pytorch-UNet-master # u-net网络文件夹
python predict.py -i 541_36.png --vi -v # -i后为预测图像名称

运行结果如下:

U-NET安装与训练问题实录_第7张图片

 恭喜,代表环境配置成功,可以继续下一阶段了!

2.训练自己的数据集

2.1 数据集准备

(1)图片采用3通道RGB.png图像

通道数通过如下图查看

U-NET安装与训练问题实录_第8张图片

 笔者输入位深度为24位图像正确,输入位深度为32图像报错

(2)图像蒙版为.gif格式二值图

mask文件名称为对应图像名称+_mask,如图所示:

U-NET安装与训练问题实录_第9张图片

注意:如果采用.png格式或者.jpg格式二值图会报错,错误原因详见错误汇总

2.2 模型训练

2.2.1 predict文件修改

如1.3.2所述,修改相应地址名,epoches,batch_size,lr即可

因为笔者是二分类问题-背景+波,所以不需要修改classes

U-NET安装与训练问题实录_第10张图片

2.3 模型预测

如1.3.3所述,修改相应图片名即可

3.问题汇总

3.1 模型预测结果全黑

3.1.1 原因1:输入数据集负样本含量过多

U-NET安装与训练问题实录_第11张图片

 如图,笔者第一次训练时未剔除无效数据,导致预测结果全黑

解决方法:增加数据集中正样本数量即可

3.2 模型训练错误

3.2.1 原因1:

U-NET安装与训练问题实录_第12张图片

RuntimeError:CUDA error: device-side assert triggered

报错如上图所示

解决方法:

(1)类别不匹配,修改classes数量即可

(2)mask格式采用.png,.jpg格式,需要修改为.gif格式

因为.png格式二值图以数组形式读入结果为[0, 255]

U-NET安装与训练问题实录_第13张图片

 而.gif格式以数组形式读入结果为[0,1]

U-NET安装与训练问题实录_第14张图片

 笔者猜测是这个原因导致错误的发生

你可能感兴趣的:(深度学习,pytorch,人工智能,python)