FDA: Paper Walkthrough and Code Implementation


Paper walkthrough/translation: FDA
Code on GitHub: github
Goal: transfer the style of Cityscapes to ACDC/night and train a semantic segmentation network.

0. Setting up the conda environment

conda create -n FDA python=3.6
conda activate FDA
conda install pytorch=0.4.0 torchvision cuda91 -y -c pytorch
pip install tensorboard tensorboardX

1. Preparing the data

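For both training and validation, data are read through the list.txt files in dataset; each txt contains the paths or names of the corresponding images.
readfile.py generates the data_list.txt files: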

import os

# ACDC night folders (adjust to where the dataset was extracted)
rootdir1 = './pytorch-CycleGAN-and-pix2pix/datasets/ACDC/rbg_anon_trainvaltest/rgb_anon/night/train/'
rootdir2 = './pytorch-CycleGAN-and-pix2pix/datasets/ACDC/rbg_anon_trainvaltest/rgb_anon/night/train_ref/'
valdir = './pytorch-CycleGAN-and-pix2pix/datasets/ACDC/rbg_anon_trainvaltest/rgb_anon/night/val/'
listdir = './FDA/dataset/ACDC_list/night/'


def write_list(rootdir, txt_path):
    """Write one '<sequence folder>/<file name>' line per image under rootdir."""
    with open(txt_path, 'w') as f:
        for folder in sorted(os.listdir(rootdir)):
            path = os.path.join(rootdir, folder)
            if not os.path.isdir(path):
                continue  # only sequence folders contain images
            for filename in sorted(os.listdir(path)):
                if os.path.isfile(os.path.join(path, filename)):
                    f.write(folder + '/' + filename + '\n')


os.makedirs(listdir, exist_ok=True)
write_list(rootdir1, os.path.join(listdir, 'train.txt'))
write_list(rootdir2, os.path.join(listdir, 'train_ref.txt'))
write_list(valdir, os.path.join(listdir, 'val.txt'))  # val.txt is needed below

# label.txt is derived from val.txt by renaming each entry:
# <name>_rgb_anon.png -> <name>_gt_labelIds.png (106 images in the night val split).
# The label files themselves live in .../ACDC/gt_trainval/gt/night/val/
with open(os.path.join(listdir, 'val.txt')) as f:
    img_ids = [line.strip() for line in f if line.strip()]
with open(os.path.join(listdir, 'label.txt'), 'w') as f:
    for name in img_ids:
        f.write(name.replace('rgb_anon', 'gt_labelIds') + '\n')
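Because label.txt is produced by renaming the val.txt entries, a quick check that every listed file really exists can save a failed run later. The following is a minimal sketch; `check_list` is a hypothetical helper, and the folder pairs assume the directory layout used in the script above.

```python
import os


def check_list(list_path, root_dir):
    """Return the entries of list_path that do not exist as files under root_dir."""
    with open(list_path) as f:
        entries = [line.strip() for line in f if line.strip()]
    return [e for e in entries if not os.path.isfile(os.path.join(root_dir, e))]


if __name__ == '__main__':
    acdc = './pytorch-CycleGAN-and-pix2pix/datasets/ACDC/'
    lists = './FDA/dataset/ACDC_list/night/'
    # (list file, folder its entries are relative to); assumed layout
    pairs = [
        (lists + 'train.txt', acdc + 'rbg_anon_trainvaltest/rgb_anon/night/train/'),
        (lists + 'val.txt', acdc + 'rbg_anon_trainvaltest/rgb_anon/night/val/'),
        (lists + 'label.txt', acdc + 'gt_trainval/gt/night/val/'),
    ]
    for list_path, root in pairs:
        if os.path.isfile(list_path):
            missing = check_list(list_path, root)
            print(list_path, '-> missing files:', len(missing))
```

An empty result for all three lists means the txt files and the folders agree.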

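The image paths are set in train_options.py. The main arguments to change are the source image directory '--data-dir', the path of the txt file listing the source image names '--data-list', the target image directory '--data-dir-target', and the path of the txt file listing the target image names '--data-list-target'. All image sets are stored under './pytorch-CycleGAN-and-pix2pix'.
Images are loaded through './data/__init__.py', which is also where the image sizes are set. ACDC target images are 1920x1080, and the size is changed in __init__.py (the 'cityscapes' entry corresponds to the target image size; it was originally 1024x512, and half of the original size, 960x540, later did not work as well as before). The source images must be larger than the target images. Training data are then obtained by reading train.txt in './data/ACDC_night_dataset.py'. In './data/ACDC_night_SSL.py' and './data/ACDC_night_dataset_label.py' (lines 32-34), set the label directory 'gt' and replace 'rgb_anon' in the train.txt file names with 'gt_labelIds' to read the labels.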
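The two conventions mentioned above, the size table and the label renaming, can be summarized in a short sketch. It is an illustration under assumptions, not the repository's code: the size table is assumed to store (width, height), matching the 1024x512 default quoted above, 960x540 is only an example value, and `label_name` and `torch_size` are hypothetical helpers.

```python
# Example size table in the style of data/__init__.py, stored as (width, height).
# 960x540 (half of ACDC's 1920x1080) is only an example value.
image_sizes = {'cityscapes': (960, 540)}


def label_name(image_name):
    """Map an ACDC image entry to its label entry (*_rgb_anon.png -> *_gt_labelIds.png)."""
    return image_name.replace('rgb_anon', 'gt_labelIds')


def torch_size(key):
    """PyTorch resizing calls take (height, width), the reverse of the table."""
    width, height = image_sizes[key]
    return (height, width)


print(label_name('GOPR0351/GOPR0351_frame_000159_rgb_anon.png'))
print(torch_size('cityscapes'))
```

Keeping the (width, height) to (height, width) swap in one helper avoids mixing up the two orders when output maps are resized.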

2. Train using FDA (single beta)

Three models with different beta values need to be trained, each with DeepLab. (When training with beta=0.09, setting entW=0 can improve performance.)

python3 train.py --snapshot-dir='./checkpoints/FDA/cityscapes2night/beta_0.01' --GPU='0' --init-weights='./checkpoints/FDA/init_weight/DeepLab_init.pth' --LB=0.01 --entW=0.005 --ita=2.0 --switch2entropy=0
python3 train.py --snapshot-dir='./checkpoints/FDA/cityscapes2night/beta_0.05' --GPU='0' --init-weights='./checkpoints/FDA/init_weight/DeepLab_init.pth' --LB=0.05 --entW=0.005 --ita=2.0 --switch2entropy=0
python3 train.py --snapshot-dir='./checkpoints/FDA/cityscapes2night/beta_0.09' --GPU='0' --init-weights='./checkpoints/FDA/init_weight/DeepLab_init.pth' --LB=0.09 --entW=0.005 --ita=2.0 --switch2entropy=0
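The `--LB` flag is the beta of FDA: the low-frequency part of a source image's amplitude spectrum is replaced by that of a target image while the source phase is kept. The NumPy sketch below illustrates that swap; it is written for this post and is not the repository's implementation, the function name `fda_source_to_target` is hypothetical, and the window half-width floor(beta * min(H, W)) is an assumption of the sketch.

```python
import numpy as np


def fda_source_to_target(src, trg, beta=0.01):
    """Swap the low-frequency amplitude of src with that of trg.

    src, trg: float arrays of shape (C, H, W) with the same size.
    beta: fraction of min(H, W) used as the half-width of the swapped window.
    """
    fft_src = np.fft.fft2(src, axes=(-2, -1))
    fft_trg = np.fft.fft2(trg, axes=(-2, -1))
    amp_src, pha_src = np.abs(fft_src), np.angle(fft_src)
    amp_trg = np.abs(fft_trg)

    # Centre the spectra so that low frequencies lie around (H // 2, W // 2)
    amp_src = np.fft.fftshift(amp_src, axes=(-2, -1))
    amp_trg = np.fft.fftshift(amp_trg, axes=(-2, -1))
    _, h, w = src.shape
    b = int(np.floor(min(h, w) * beta))
    ch, cw = h // 2, w // 2
    amp_src[:, ch - b:ch + b + 1, cw - b:cw + b + 1] = \
        amp_trg[:, ch - b:ch + b + 1, cw - b:cw + b + 1]
    amp_src = np.fft.ifftshift(amp_src, axes=(-2, -1))

    # Recombine the mixed amplitude with the source phase and return to image space
    mixed = amp_src * np.exp(1j * pha_src)
    return np.real(np.fft.ifft2(mixed, axes=(-2, -1)))
```

Because the window is symmetric around the zero frequency, the result stays (numerically) real. A larger beta copies more of the target's global appearance into the source image, which is why three models with different betas are trained and later combined.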

The model weights are stored in separate folders under './FDA/checkpoints/FDA/cityscapes2night/'.
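The training commands in this section also pass `--entW` and `--ita`. In FDA these control an entropy penalty on target predictions: `--entW` weights the term and `--ita` is the exponent of a Charbonnier-style penalty. The NumPy sketch below shows such a robust entropy; the normalisation by log(C) and the small constant `eps` are assumptions of this illustration, and `robust_entropy` is a hypothetical name.

```python
import numpy as np


def robust_entropy(logits, ita=2.0, eps=1e-3):
    """Charbonnier-penalised pixel entropy of softmax predictions.

    logits: array of shape (C, H, W). Returns the mean over pixels of
    (ent**2 + eps**2) ** ita, where ent is normalised by log(C) into [0, 1].
    """
    z = logits - logits.max(axis=0, keepdims=True)          # numerical stability
    log_p = z - np.log(np.exp(z).sum(axis=0, keepdims=True))  # log-softmax
    p = np.exp(log_p)
    ent = -(p * log_p).sum(axis=0) / np.log(logits.shape[0])
    return float(((ent ** 2 + eps ** 2) ** ita).mean())
```

With ita > 1 the penalty is small for pixels that are already confident and grows quickly for uncertain ones; setting entW=0 simply switches the term off.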

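3. Evaluating Segmentation Networks with Multi-band Transfer

Evaluation of the segmentation networks adapted with multi-band transfer (multiple betas), i.e. the networks trained on FDA transforms with three different betas.
The parameters are mainly changed in test_options.py. evaluation_multi.py validates on the ACDC val set and reports mIoU to judge segmentation quality; a higher mIoU means better segmentation. The val segmentation results are saved under the --save path. There are 19 classes in total, and the json is the one in the cityscapes folder of the original code. (Model weight paths are given without the '.pth' suffix.)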

python3 evaluation_multi.py --model='DeepLab' --save='./results' --restore-opt1="./checkpoints/FDA/cityscapes2night/beta_0.01/cityscapes_100000" --restore-opt2="./checkpoints/FDA/cityscapes2night/beta_0.05/cityscapes_100000" --restore-opt3="./checkpoints/FDA/cityscapes2night/beta_0.09/cityscapes_100000"
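The mIoU reported by the evaluation script is normally computed from a confusion matrix accumulated over all val images. The following is an independent NumPy sketch of that standard computation (function names are illustrative, not the script's); the optional `classes` argument shows how a score over a class subset, such as the mIoU16/mIoU13 figures quoted later, would be obtained, without assuming which classes those subsets contain.

```python
import numpy as np


def fast_hist(gt, pred, n_classes):
    """Confusion matrix (rows: ground truth, cols: prediction); labels >= n_classes are ignored."""
    k = (gt >= 0) & (gt < n_classes)
    return np.bincount(n_classes * gt[k].astype(int) + pred[k].astype(int),
                       minlength=n_classes ** 2).reshape(n_classes, n_classes)


def per_class_iou(hist):
    """IoU_c = TP / (TP + FP + FN) for every class c (nan if the class never occurs)."""
    with np.errstate(divide='ignore', invalid='ignore'):
        return np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))


def miou(hist, classes=None):
    """Mean IoU over all classes, or over a subset of class indices."""
    ious = per_class_iou(hist)
    if classes is not None:
        ious = ious[list(classes)]
    return float(np.nanmean(ious))
```

In practice `fast_hist` is summed over all images (flattened label and prediction maps) before `miou` is called once at the end.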

Note: at line 149 of evaluation_multi.py, change the weights applied to the outputs of the three beta models, and find the combination that maximizes mIoU:

a,b=0.5,0.25
output=a*output1+b*output2+(1.0-a-b)*output3
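Finding the best (a, b) by hand is tedious; a simple grid search can do it. This is a minimal sketch: `score_fn` is a hypothetical wrapper you would write around the evaluation loop that returns the mIoU for given weights, and the step size is arbitrary.

```python
import itertools


def search_weights(score_fn, step=0.05):
    """Grid-search ensemble weights (a, b, 1-a-b); return (best score, a, b).

    score_fn(a, b) should evaluate output = a*output1 + b*output2 + (1-a-b)*output3
    and return the resulting mIoU.
    """
    n = int(round(1 / step))
    best = None
    for i, j in itertools.product(range(n + 1), repeat=2):
        if i + j > n:
            continue  # the third weight 1-a-b must stay non-negative
        a, b = i / n, j / n
        score = score_fn(a, b)
        if best is None or score > best[0]:
            best = (score, a, b)
    return best
```

Since only the weighted sum changes between grid points, caching each model's softmax output on the val set once and reusing it inside `score_fn` keeps the search cheap.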

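In addition, at line 152 change the output image size to the target image size; note that width and height are swapped (the size is given as (height, width)). Also, because of the PyTorch version, nn.functional.interpolate has to be changed to nn.functional.upsample (interpolate was only added in PyTorch 0.4.1).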

output=nn.functional.upsample(output, (1080,1920), mode='bilinear', align_corners=True).cpu().data[0].numpy()

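In the first run, the target image size in data/__init__.py had not been changed to half the original size, and the beta=0.05 and beta=0.09 models were interrupted unexpectedly at iteration 65000, so the results were poor: mIoU19 = 22.92, mIoU16 = 25.08, mIoU13 = 27.22. cityscapes_100000_mIoU.txt is saved in the checkpoints folder that holds the .pth files.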

4. Get pseudo labels

In getSudoLabel_multi.py, set the weights a and b at line 67 to the same values as in evaluation_multi.py, change the image size at line 70 to (540,960), i.e. half the ACDC resolution, and change interpolate to upsample.

a, b = 0.5, 0.25
output = a*output1 + b*output2 + (1.0-a-b)*output3

output = nn.functional.upsample(output, (540,960), mode='bilinear', align_corners=True).cpu().data[0].numpy()
python3 getSudoLabel_multi.py --model='DeepLab' --data-list-target='./dataset/ACDC_list/night/train.txt' --set='train' --restore-opt1="./checkpoints/FDA/cityscapes2night/beta_0.01/cityscapes_100000" --restore-opt2="./checkpoints/FDA/cityscapes2night/beta_0.05/cityscapes_100000" --restore-opt3="./checkpoints/FDA/cityscapes2night/beta_0.09/cityscapes_100000"
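Pseudo labels are only useful if unreliable pixels are left out. One common way to do this, shown below as a sketch, is a class-wise confidence threshold: for each predicted class, take a quantile of its confidences as the threshold and mark pixels below it as ignored. The 0.66 quantile, the 0.9 cap and the ignore value 255 are assumptions of this illustration; check getSudoLabel_multi.py for the values it actually uses.

```python
import numpy as np


def make_pseudo_labels(probs, quantile=0.66, max_thres=0.9, ignore=255):
    """Turn ensemble softmax outputs into pseudo labels with class-wise thresholds.

    probs: array (N, C, H, W) of per-image class probabilities.
    Pixels whose confidence is below the threshold of their predicted class get `ignore`.
    """
    label = probs.argmax(axis=1).astype(np.int64)  # (N, H, W)
    conf = probs.max(axis=1)                       # (N, H, W)
    n_classes = probs.shape[1]
    thres = np.zeros(n_classes)
    for c in range(n_classes):
        x = np.sort(conf[label == c])
        if x.size:
            thres[c] = x[min(int(round(x.size * quantile)), x.size - 1)]
    thres = np.minimum(thres, max_thres)  # never demand more than max_thres
    label[conf < thres[label]] = ignore
    return label, thres
```

The labels would be cast to uint8 before being written as PNG files.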

5. Self-supervised training

python3 SStrain.py --model='DeepLab' --snapshot-dir='./checkpoints/FDA/SSmodel/' --save='./results/SSmodel' --init-weights='./checkpoints/FDA/init_weight/DeepLab_init.pth' --label-folder='cs_pseudo_label' --LB=0.01 --entW=0.005 --ita=2.0

The self-supervised model is stored in './checkpoints/FDA/SSmodel/'.
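In the self-supervised stage the pseudo labels act as ground truth for the ACDC images, and pixels marked as ignored should contribute nothing to the loss. The NumPy sketch below shows a pixel-wise cross-entropy with such an ignore value (255 is assumed here; in PyTorch the same effect comes from the `ignore_index` argument of `nn.CrossEntropyLoss`). `ce_with_ignore` is a hypothetical name.

```python
import numpy as np


def ce_with_ignore(logits, target, ignore=255):
    """Mean pixel-wise cross-entropy that skips pixels labelled `ignore`.

    logits: (C, H, W) scores; target: (H, W) integer labels.
    """
    z = logits - logits.max(axis=0, keepdims=True)
    log_p = z - np.log(np.exp(z).sum(axis=0, keepdims=True))  # log-softmax
    valid = target != ignore
    if not valid.any():
        return 0.0
    t = target[valid]         # labels of the kept pixels
    lp = log_p[:, valid]      # (C, n_valid)
    return float(-lp[t, np.arange(t.size)].mean())
```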

6. Results

Check the mIoU with evaluation_multi.py:

python3 evaluation_multi.py --model='DeepLab' --save='./results' --restore-opt1="./checkpoints/FDA/SSmodel/cityscapes_100000" 

At the same time, change the weights at line 149 of evaluation_multi.py:

a,b=1,0  # only the restore-opt1 output is used
output=a*output1+b*output2+(1.0-a-b)*output3

The mIoU.txt result is stored in './checkpoints/FDA/SSmodel/'; the results are not satisfactory.
[Figure 1]
