[极简]pytorch版Unet训练自己的数据集

一、准备数据集

VOCdevkit
    VOC2007
         JPEGImages
         SegmentationClass
         ImageSets
              Segmentation
                   test.txt
                   train.txt
                   trainval.txt
                   val.txt

把所有jpg原图片放到JPEGImages
把所有png标注图片放到SegmentationClass
运行voc2unet.py脚本

二、训练

train.py
关键参数
1、num_classes=(类别数+1)
2、dice_loss=True/False
# 种类少(几类)时,设置为True
# 种类多(十几类)时,如果batch_size比较大(10以上),那么设置为True
# 种类多(十几类)时,如果batch_size比较小(10以下),那么设置为False
3、pretraind=True/False(是否使用预训练权重)
model_path = r"model_data/unet_voc.pth"(预训练权重路径)
4、lr = 1e-4
Init_Epoch = 0
Interval_Epoch = 25
Batch_size = 2

lr = 1e-5
Interval_Epoch = 25
Epoch = 50
Batch_size = 2
主干特征提取网络特征通用,冻结训练可以加快训练速度
也可以在训练初期防止权值被破坏。
Init_Epoch为起始世代
Interval_Epoch为冻结训练的世代
Epoch总训练世代
提示OOM或者显存不足请调小Batch_size

三、测试

predict.py(可自动保存)

unet.py
video.py

你可能感兴趣的:(计算机视觉,人工智能,深度学习)