edges2shoes Dataset Download_GAN Compression 3: How to Reproduce the Results

This is finally the last part of the series. Having gone through the paper and the code, let's now look at how to reproduce this work:

If you want to know more about this work, see the original paper (arxiv.org) and the two earlier posts in this series by 科技猛兽 on zhuanlan.zhihu.com: "GAN Compression 1: 原理分析" (the principles) and "GAN Compression 2: 代码解读" (the code walkthrough).

Official code: https://github.com/mit-han-lab/gan-compression/blob/master/README.md

Prerequisites:

Linux

Python 3

CPU or NVIDIA GPU + CUDA CuDNN

Now let's get started!

First, clone the code to your machine:

git clone git@github.com:mit-han-lab/gan-compression.git

cd gan-compression

1. Install PyTorch 1.4 and its companion libraries (torchvision):

If you look closely, you'll notice the repository contains a requirements.txt, which looks like this:

absl-py==0.9.0

blessings==1.7

certifi==2019.11.28

dominate==2.4.0

grpcio==1.16.1

Markdown==3.1.1

numpy==1.18.1

nvidia-ml-py3==7.352.0

olefile==0.46

opencv-python==4.2.0.32

Pillow==7.0.0

protobuf==3.11.3

psutil==5.7.0

scipy==1.4.1

six==1.14.0

tensorboard==2.0.0

tensorboardX==2.0

torch==1.4.0

torchvision==0.5.0

torchprofile==0.0.1

tqdm==4.42.1

Werkzeug==1.0.0

wget==3.2

No, your eyes aren't deceiving you: all of these need to be installed first.

If you're a pip user, run:

pip install -r requirements.txt

If you're a conda user, open the scripts folder, where you'll find an executable script named conda_deps.sh, and run:

scripts/conda_deps.sh

So what exactly is this conda_deps.sh?

Opening it reveals:

#!/usr/bin/env bash

set -ex

conda install pytorch==1.4.0 torchvision==0.5.0 -c pytorch

conda install tqdm scipy tensorboard

conda install -c conda-forge tensorboardx

pip install opencv-python dominate wget

Still the same pile of libraries.

I recommend creating a fresh conda virtual environment dedicated to this experiment. To create one:

conda create -n your_env_name python=X.X   # X.X is the Python version, e.g. 3.6 (this project needs Python 3)

List the virtual environments you have:

conda info -e

List the packages installed in the current environment:

conda list
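
To double-check that the environment really matches the pinned versions listed earlier, a small script can compare them against what is installed. This is only a sketch: `parse_pins` and `check_pins` are hypothetical helper names of mine, not part of the repository, and `importlib.metadata` assumes Python 3.8 or newer.

```python
from importlib import metadata


def parse_pins(text):
    """Parse 'name==version' lines into a {name: version} dict."""
    pins = {}
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#") or "==" not in line:
            continue  # skip blanks, comments, and unpinned entries
        name, version = line.split("==", 1)
        pins[name.strip()] = version.strip()
    return pins


def check_pins(pins):
    """Return {name: (wanted, installed_or_None)} for every mismatch."""
    mismatches = {}
    for name, wanted in pins.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None  # package is not installed at all
        if installed != wanted:
            mismatches[name] = (wanted, installed)
    return mismatches
```

From the repository root, `check_pins(parse_pins(open("requirements.txt").read()))` should come back empty when everything lines up. The names in the file (for example opencv-python) are distribution names, which is what `metadata.version` expects.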

Next, install torchprofile: https://github.com/zhijian-liu/torchprofile

pip install --upgrade git+https://github.com/mit-han-lab/torchprofile.git

This is the library used to measure a model's MACs (multiply-accumulate operations).
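
To get a feel for what a MACs count means, here is a hand calculation for a single convolution layer. It is a sketch of the arithmetic only, not torchprofile's API; the function names and the 64-channel, 64x64 layer shape are made up for illustration. It also shows why a MobileNet-style (depthwise separable) layer is much cheaper than a standard convolution.

```python
def conv2d_macs(c_in, c_out, kernel, h_out, w_out, groups=1):
    """MACs of a 2-D convolution (bias ignored): one multiply-accumulate
    per kernel weight per output position."""
    return (c_in // groups) * kernel * kernel * c_out * h_out * w_out


def separable_conv_macs(c_in, c_out, kernel, h_out, w_out):
    """Depthwise conv (groups == c_in) followed by a 1x1 pointwise conv."""
    depthwise = conv2d_macs(c_in, c_in, kernel, h_out, w_out, groups=c_in)
    pointwise = conv2d_macs(c_in, c_out, 1, h_out, w_out)
    return depthwise + pointwise


# Hypothetical layer: 64 -> 64 channels, 3x3 kernel, 64x64 output map.
standard = conv2d_macs(64, 64, 3, 64, 64)
separable = separable_conv_macs(64, 64, 3, 64, 64)
print(standard, separable, round(standard / separable, 2))
```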

2. With all the dependencies installed, it's time to prepare the datasets and try out the authors' pre-trained models:

2.1 CycleGAN

Download the CycleGAN dataset (e.g., horse2zebra):

bash datasets/download_cyclegan_dataset.sh horse2zebra

Get the statistics of the ground-truth images, which are needed to compute FID (the CycleGAN datasets use FID as the evaluation metric).

bash datasets/download_real_stat.sh horse2zebra A

bash datasets/download_real_stat.sh horse2zebra B

Before training, we can first try the authors' pre-trained models to see how well they do.

First, download the pre-trained models (the original model and the compressed model):

python scripts/download_model.py --model cycle_gan --task horse2zebra --stage full

python scripts/download_model.py --model cycle_gan --task horse2zebra --stage compressed

Test the uncompressed full model:

bash scripts/cycle_gan/horse2zebra/test_full.sh

Test the compressed model:

bash scripts/cycle_gan/horse2zebra/test_compressed.sh

Here is what this script contains:

#!/usr/bin/env bash

python test.py --dataroot database/horse2zebra/valA \

--dataset_mode single \

--results_dir results-pretrained/cycle_gan/horse2zebra/compressed \

--config_str 16_16_32_16_32_32_16_16 \

--restore_G_path pretrained/cycle_gan/horse2zebra/compressed/latest_net_G.pth \

--need_profile \

--real_stat_path real_stat/horse2zebra_B.npz

Finally, measure the latency:

bash scripts/cycle_gan/horse2zebra/latency_full.sh

bash scripts/cycle_gan/horse2zebra/latency_compressed.sh
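
I haven't looked inside the latency scripts here, but the general idea behind a latency measurement is easy to sketch: run a few warm-up calls, then average the wall-clock time over many runs. `measure_latency` below is a hypothetical helper of mine, and the lambda is just a stand-in workload.

```python
import time


def measure_latency(fn, warmup=10, runs=50):
    """Average wall-clock seconds per call of fn(), after warm-up calls."""
    for _ in range(warmup):
        fn()  # warm-up: exclude one-off costs such as caching or lazy init
    start = time.perf_counter()
    for _ in range(runs):
        fn()
    return (time.perf_counter() - start) / runs


# Stand-in workload; in practice fn would run one forward pass of the generator.
latency = measure_latency(lambda: sum(i * i for i in range(10_000)))
print(f"{latency * 1e3:.3f} ms per call")
```

When timing on a GPU, the pending kernels have to finish before the clock is read (for example by calling `torch.cuda.synchronize()`), otherwise the number only reflects how fast the work was queued.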

2.2 Pix2Pix

Download the pix2pix dataset (e.g., edges2shoes).

bash datasets/download_pix2pix_dataset.sh edges2shoes-r

Get the statistics of the ground-truth images, which are needed to compute FID (the pix2pix datasets use FID as the evaluation metric).

bash datasets/download_real_stat.sh edges2shoes-r B

Before training, we can first try the authors' pre-trained models to see how well they do.

First, download the pre-trained models (the original model and the compressed model):

python scripts/download_model.py --model pix2pix --task edges2shoes-r --stage full

python scripts/download_model.py --model pix2pix --task edges2shoes-r --stage compressed

Test the uncompressed full model:

bash scripts/pix2pix/edges2shoes-r/test_full.sh

Test the compressed model:

bash scripts/pix2pix/edges2shoes-r/test_compressed.sh

Here is what this script contains:

#!/usr/bin/env bash

python test.py --dataroot database/edges2shoes-r \

--results_dir results-pretrained/pix2pix/edges2shoes-r/compressed \

--restore_G_path pretrained/pix2pix/edges2shoes-r/compressed/latest_net_G.pth \

--config_str 32_32_48_32_48_48_16_16 \

--real_stat_path real_stat/edges2shoes-r_B.npz \

--need_profile --num_test 500

Finally, measure the latency:

bash scripts/pix2pix/edges2shoes-r/latency_full.sh

bash scripts/pix2pix/edges2shoes-r/latency_compressed.sh

2.3 GauGAN

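Prepare the Cityscapes dataset (it has to be downloaded manually; see the Cityscapes Dataset part at the end of this section), then get the statistics of the ground-truth images: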

bash datasets/download_real_stat.sh cityscapes A

Before training, we can first try the authors' pre-trained models to see how well they do.

First, download the pre-trained models (the original model and the compressed model):

python scripts/download_model.py --model gaugan --task cityscapes --stage full

python scripts/download_model.py --model gaugan --task cityscapes --stage compressed

Test the uncompressed full model:

bash scripts/gaugan/cityscapes/test_full.sh

Test the compressed model:

bash scripts/gaugan/cityscapes/test_compressed.sh

Here is what this script contains:

#!/usr/bin/env bash

python test.py --dataroot database/cityscapes-origin \

--config_str 32_32_48_32_32_32_32_24 \

--model spade --dataset_mode cityscapes \

--results_dir results-pretrained/gaugan/cityscapes/compressed \

--ngf 48 --netG sub_mobile_spade \

--restore_G_path pretrained/gaugan/cityscapes/compressed/latest_net_G.pth \

--real_stat_path real_stat/cityscapes_A.npz \

--drn_path drn-d-105_ms_cityscapes.pth \

--cityscapes_path database/cityscapes-origin \

--table_path datasets/table.txt --need_profile

Finally, measure the latency:

bash scripts/gaugan/cityscapes/latency_full.sh

bash scripts/gaugan/cityscapes/latency_compressed.sh

Cityscapes Dataset

This dataset has to be downloaded manually from the Cityscapes site: Semantic Understanding of Urban Street Scenes (cityscapes-dataset.com).

Once it is downloaded, run the following commands:

python datasets/get_trainIds.py database/cityscapes-origin/gtFine/

python datasets/prepare_cityscapes_dataset.py \

--gtFine_dir database/cityscapes-origin/gtFine \

--leftImg8bit_dir database/cityscapes-origin/leftImg8bit \

--output_dir database/cityscapes \

--table_path datasets/table.txt

These commands assume that you have downloaded gtFine_trainvaltest.zip and leftImg8bit_trainvaltest.zip and unzipped them into the same folder, which here is database/cityscapes-origin.
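
Before running the preparation commands, it can save some time to check that both archives really ended up in the expected place. The helper below is a hypothetical sketch of mine (not part of the repository); it only looks for the two directory names that appear in the commands above.

```python
from pathlib import Path


def missing_cityscapes_dirs(root):
    """Return the expected sub-directories that are absent under root."""
    expected = ["gtFine", "leftImg8bit"]
    return [name for name in expected if not (Path(root) / name).is_dir()]


missing = missing_cityscapes_dirs("database/cityscapes-origin")
print("missing:", missing)
```

An empty list means both folders are present; anything else names what still has to be unzipped.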

The pre-trained model needed here has to be downloaded manually from: http://go.yf.io/drn-cityscapes-models

It is the DRN model named drn-d-105_ms_cityscapes.pth.


3. Training:

According to the paper, the training process can be divided into the following three parts:

"Train a MobileNet Teacher Model"

"Pre-distillation"

"Fine-tuning the Best Model"

3.1 Pix2pix Model Compression:

Taking edges2shoes-r as an example:

3.1.1 Train a MobileNet Teacher Model:

Step one: train a MobileNet-style teacher model from scratch:

bash scripts/pix2pix/edges2shoes-r/train_mobile.sh

train_mobile.sh contains:

#!/usr/bin/env bash

python train.py --dataroot database/edges2shoes-r \

--model pix2pix \

--log_dir logs/pix2pix/edges2shoes-r/mobile \

--real_stat_path real_stat/edges2shoes-r_B.npz

The following two steps are optional:

The authors provide a pre-trained model for each dataset:

python scripts/download_model.py --model pix2pix --task edges2shoes-r --stage mobile

Once it is downloaded, you can test it:

bash scripts/pix2pix/edges2shoes-r/test_mobile.sh

test_mobile.sh contains:

#!/usr/bin/env bash

python test.py --dataroot database/edges2shoes-r \

--results_dir results-pretrained/pix2pix/edges2shoes-r/mobile \

--ngf 64 --netG mobile_resnet_9blocks \

--restore_G_path pretrained/pix2pix/edges2shoes-r/mobile/latest_net_G.pth \

--real_stat_path real_stat/edges2shoes-r_B.npz \

--need_profile --num_test 500

3.1.2 Distill and prune the original MobileNet-style model to make the model compact

bash scripts/pix2pix/edges2shoes-r/distill.sh

distill.sh contains:

#!/usr/bin/env bash

python distill.py --dataroot database/edges2shoes-r \

--distiller resnet \

--log_dir logs/pix2pix/edges2shoes-r/distill \

--batch_size 4 \

--restore_teacher_G_path logs/pix2pix/edges2shoes-r/mobile/checkpoints/latest_net_G.pth \

--restore_pretrained_G_path logs/pix2pix/edges2shoes-r/mobile/checkpoints/latest_net_G.pth \

--restore_D_path logs/pix2pix/edges2shoes-r/mobile/checkpoints/latest_net_D.pth \

--real_stat_path real_stat/edges2shoes-r_B.npz

The authors provide an already distilled student model. Download it with:

python scripts/download_model.py --model pix2pix --task edges2shoes-r --stage distill

Test it:

bash scripts/pix2pix/edges2shoes-r/test_distill.sh

test_distill.sh contains:

#!/usr/bin/env bash

python test.py --dataroot database/edges2shoes-r \

--results_dir results-pretrained/pix2pix/edges2shoes-r/distill \

--ngf 48 --netG mobile_resnet_9blocks \

--restore_G_path pretrained/pix2pix/edges2shoes-r/distill/latest_net_G.pth \

--real_stat_path real_stat/edges2shoes-r_B.npz \

--need_profile --num_test 500

3.1.3 "Once-for-all" Network Training

Starting from the distilled student model, train a "once-for-all" network:

bash scripts/pix2pix/edges2shoes-r/train_supernet.sh

train_supernet.sh contains:

#!/usr/bin/env bash

python train_supernet.py --dataroot database/edges2shoes-r \

--supernet resnet \

--log_dir logs/pix2pix/edges2shoes-r/supernet \

--batch_size 4 \

--restore_teacher_G_path logs/pix2pix/edges2shoes-r/mobile/checkpoints/latest_net_G.pth \

--restore_student_G_path logs/pix2pix/edges2shoes-r/distill/checkpoints/latest_net_G.pth \

--restore_D_path logs/pix2pix/edges2shoes-r/distill/checkpoints/latest_net_D.pth \

--real_stat_path real_stat/edges2shoes-r_B.npz \

--nepochs 10 --nepochs_decay 30 \

--teacher_ngf 64 --student_ngf 48 \

--config_set channels-48

The authors provide an already trained "once-for-all" network. Download it with:

python scripts/download_model.py --model pix2pix --task edges2shoes-r --stage supernet

3.1.4 Select the Best Model

Evaluate all the candidate sub-networks given a specific configuration (e.g., MAC, FID)

bash scripts/pix2pix/edges2shoes-r/search.sh

search.sh contains:

#!/usr/bin/env bash

python search.py --dataroot database/edges2shoes-r \

--restore_G_path logs/pix2pix/edges2shoes-r/supernet/checkpoints/latest_net_G.pth \

--output_path logs/pix2pix/edges2shoes-r/supernet/result.pkl \

--ngf 48 --batch_size 32 \

--config_set channels-48 \

--real_stat_path real_stat/edges2shoes-r_B.npz

The results are stored in pickle format, as a list in which each element is a dict:

{'config_str': $config_str, 'macs': $macs, 'fid'/'mIoU': $fid_or_mIoU}

For example:

{'config_str': '32_32_48_32_48_48_16_16', 'macs': 4993843200, 'fid': 25.224261423597483}

'config_str' is a channel configuration description that identifies a specific subnet within the "once-for-all" network.
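
Judging from the examples, a config_str is just a list of channel counts joined by underscores. A minimal sketch of turning it back into numbers follows; `parse_config_str` is a hypothetical helper of mine, and which layer each entry corresponds to is defined by the repository and not covered here.

```python
def parse_config_str(config_str):
    """Split an underscore-separated channel description into integers."""
    return [int(part) for part in config_str.split("_")]


channels = parse_config_str("32_32_48_32_48_48_16_16")
print(channels)
```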

Select the best architecture:

# Keep sub-networks with macs <= 5.7e9 (about 10x fewer MACs) and fid <= 30 (lower FID is better)

python select_arch.py --macs 5.7e9 --fid 30 \

--pkl_path logs/pix2pix/edges2shoes-r/supernet/result.pkl
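
To make the selection step concrete, here is a rough sketch of filtering such a result list by hand. It is not the implementation of select_arch.py: `select_candidates` is a hypothetical function, it handles only the 'fid' key (not 'mIoU'), and it assumes lower FID is better, so the FID value is treated as an upper bound. Apart from the first entry, which is the example shown earlier, the data is made up.

```python
import pickle


def select_candidates(results, max_macs, max_fid):
    """Keep entries within both budgets, best (lowest) FID first."""
    kept = [r for r in results if r["macs"] <= max_macs and r["fid"] <= max_fid]
    return sorted(kept, key=lambda r: r["fid"])


# Made-up entries for illustration, except the first one.
results = [
    {"config_str": "32_32_48_32_48_48_16_16", "macs": 4993843200, "fid": 25.224261423597483},
    {"config_str": "48_48_48_48_48_48_48_48", "macs": 9.0e9, "fid": 24.0},
    {"config_str": "16_16_16_16_16_16_16_16", "macs": 2.0e9, "fid": 40.0},
]
# Round-trip through pickle to mimic reading result.pkl from disk.
results = pickle.loads(pickle.dumps(results))

best = select_candidates(results, max_macs=5.7e9, max_fid=30)
print(best[0]["config_str"])
```

With a real result file, the list would come from `pickle.load` on the file opened in binary mode.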

3.1.5 Fine-tuning the Best Model

bash scripts/pix2pix/edges2shoes-r/finetune.sh 32_32_48_32_48_48_16_16

This step uses train_supernet.py; finetune.sh contains:

#!/usr/bin/env bash

python train_supernet.py --dataroot database/edges2shoes-r \

--supernet resnet \

--log_dir logs/pix2pix/edges2shoes-r/finetune \

--batch_size 4 \

--restore_teacher_G_path logs/pix2pix/edges2shoes-r/mobile/checkpoints/latest_net_G.pth \

--restore_student_G_path logs/pix2pix/edges2shoes-r/supernet/checkpoints/latest_net_G.pth \

--restore_D_path logs/pix2pix/edges2shoes-r/supernet/checkpoints/latest_net_D.pth \

--real_stat_path real_stat/edges2shoes-r_B.npz \

--nepochs 5 --nepochs_decay 15 \

--teacher_ngf 64 --student_ngf 48 \

--config_str $1

3.1.6 Export the model:

bash scripts/pix2pix/edges2shoes-r/export.sh 32_32_48_32_48_48_16_16
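
Putting the steps above together, the per-stage scripts follow a regular naming pattern, so the whole sequence can be sketched as a list of commands. `pipeline_commands` is a hypothetical helper of mine, not part of the repository; the script names are the ones used in this walkthrough.

```python
def pipeline_commands(model, task, config_str):
    """List the per-stage shell commands in the order used in this walkthrough."""
    base = f"bash scripts/{model}/{task}"
    return [
        f"{base}/train_mobile.sh",
        f"{base}/distill.sh",
        f"{base}/train_supernet.sh",
        f"{base}/search.sh",
        f"{base}/finetune.sh {config_str}",
        f"{base}/export.sh {config_str}",
    ]


for cmd in pipeline_commands("pix2pix", "edges2shoes-r", "32_32_48_32_48_48_16_16"):
    print(cmd)
```

Note that config_str is only known after running select_arch.py on the output of search.sh, so in practice the last two commands come after that manual choice.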

3.2 CycleGAN Model Compression:

Taking horse2zebra as an example:

3.2.1 Train a MobileNet Teacher Model:

Step one: train a MobileNet-style teacher model from scratch:

bash scripts/cycle_gan/horse2zebra/train_mobile.sh

train_mobile.sh contains:

#!/usr/bin/env bash

python train.py --dataroot database/horse2zebra \

--model cycle_gan \

--log_dir logs/cycle_gan/horse2zebra/mobile \

--real_stat_A_path real_stat/horse2zebra_A.npz \

--real_stat_B_path real_stat/horse2zebra_B.npz

The following two steps are optional:

The authors provide a pre-trained model for each dataset:

python scripts/download_model.py --model cycle_gan --task horse2zebra --stage mobile

Once it is downloaded, you can test it:

bash scripts/cycle_gan/horse2zebra/test_mobile.sh

test_mobile.sh contains:

#!/usr/bin/env bash

python test.py --dataroot database/horse2zebra/valA \

--dataset_mode single \

--results_dir results-pretrained/cycle_gan/horse2zebra/mobile \

--ngf 64 --netG mobile_resnet_9blocks \

--restore_G_path pretrained/cycle_gan/horse2zebra/mobile/latest_net_G.pth \

--need_profile \

--real_stat_path real_stat/horse2zebra_B.npz

3.2.2 Distill and prune the original MobileNet-style model to make the model compact

bash scripts/cycle_gan/horse2zebra/distill.sh

distill.sh contains:

#!/usr/bin/env bash

python distill.py --dataroot database/horse2zebra \

--dataset_mode unaligned \

--distiller resnet \

--log_dir logs/cycle_gan/horse2zebra/distill \

--gan_mode lsgan \

--student_ngf 32 --ndf 64 \

--restore_teacher_G_path logs/cycle_gan/horse2zebra/mobile/checkpoints/latest_net_G_A.pth \

--restore_pretrained_G_path logs/cycle_gan/horse2zebra/mobile/checkpoints/latest_net_G_A.pth \

--restore_D_path logs/cycle_gan/horse2zebra/mobile/checkpoints/latest_net_D_A.pth \

--real_stat_path real_stat/horse2zebra_B.npz \

--lambda_recon 10 \

--lambda_distill 0.01 \

--nepochs 100 --nepochs_decay 100 \

--save_epoch_freq 20
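
The --lambda_recon and --lambda_distill flags in the script above look like the weights of the reconstruction and distillation terms in the training objective. As I understand the paper, the overall objective has the form below; the mapping from these flags to the terms is my assumption and was not checked against distill.py.

```latex
\mathcal{L} = \mathcal{L}_{\text{cGAN}}
  + \lambda_{\text{recon}} \, \mathcal{L}_{\text{recon}}
  + \lambda_{\text{distill}} \, \mathcal{L}_{\text{distill}}
```

With the values in this script, that would mean \lambda_{\text{recon}} = 10 and \lambda_{\text{distill}} = 0.01.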

The authors provide an already distilled student model. Download it with:

python scripts/download_model.py --model cycle_gan --task horse2zebra --stage distill

Test it:

bash scripts/cycle_gan/horse2zebra/test_distill.sh

test_distill.sh contains:

#!/usr/bin/env bash

python test.py --dataroot database/horse2zebra/valA \

--dataset_mode single \

--results_dir results-pretrained/cycle_gan/horse2zebra/distill \

--ngf 32 --netG mobile_resnet_9blocks \

--restore_G_path pretrained/cycle_gan/horse2zebra/distill/latest_net_G.pth \

--need_profile \

--real_stat_path real_stat/horse2zebra_B.npz

3.2.3 "Once-for-all" Network Training

Starting from the distilled student model, train a "once-for-all" network:

bash scripts/cycle_gan/horse2zebra/train_supernet.sh

train_supernet.sh contains:

#!/usr/bin/env bash

python train_supernet.py --dataroot database/horse2zebra \

--dataset_mode unaligned \

--supernet resnet \

--log_dir logs/cycle_gan/horse2zebra/supernet \

--gan_mode lsgan \

--student_ngf 32 --ndf 64 \

--restore_teacher_G_path logs/cycle_gan/horse2zebra/mobile/checkpoints/latest_net_G_A.pth \

--restore_student_G_path logs/cycle_gan/horse2zebra/distill/checkpoints/latest_net_G.pth \

--restore_D_path logs/cycle_gan/horse2zebra/distill/checkpoints/latest_net_D.pth \

--real_stat_path real_stat/horse2zebra_B.npz \

--lambda_recon 10 --lambda_distill 0.01 \

--nepochs 200 --nepochs_decay 200 \

--save_epoch_freq 20 \

--config_set channels-32

The authors provide an already trained "once-for-all" network. Download it with:

python scripts/download_model.py --model cycle_gan --task horse2zebra --stage supernet

3.2.4 Select the Best Model

Evaluate all the candidate sub-networks given a specific configuration (e.g., MAC, FID)

bash scripts/cycle_gan/horse2zebra/search.sh

search.sh contains:

#!/usr/bin/env bash

python search.py --dataroot database/horse2zebra/valA \

--dataset_mode single \

--restore_G_path logs/cycle_gan/horse2zebra/supernet/checkpoints/latest_net_G.pth \

--output_path logs/cycle_gan/horse2zebra/supernet/result.pkl \

--ngf 32 --batch_size 32 \

--config_set channels-32 \

--real_stat_path real_stat/horse2zebra_B.npz

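Select the best architecture with select_arch.py, as in the pix2pix case (replace the placeholders with your own MACs budget and FID threshold):

python select_arch.py --macs <macs_budget> --fid <fid_threshold> \

--pkl_path logs/cycle_gan/horse2zebra/supernet/result.pkl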

3.2.5 Fine-tuning the Best Model

bash scripts/cycle_gan/horse2zebra/finetune.sh 16_16_32_16_32_32_16_16

In fact, this step can be skipped.

3.2.6 Export the model:

bash scripts/cycle_gan/horse2zebra/export.sh 16_16_32_16_32_32_16_16

4. Evaluation:

To compute FID you need statistics of the ground-truth images, which you can generate with the following command:

python get_real_stat.py \

--dataroot database/edges2shoes-r \

--output_path real_stat/edges2shoes-r_B.npz \

--direction AtoB

This approach works for paired image-to-image translation (pix2pix and GauGAN).

For unpaired image-to-image translation (CycleGAN), FID is computed between the generated test images and the real training+test images.
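
For reference, this is the standard definition of FID between real and generated images. Here \mu_r, \Sigma_r and \mu_g, \Sigma_g are the mean and covariance of Inception features of the real and generated images respectively; that the .npz files hold the real-image statistics in exactly this form is my assumption. Lower values are better.

```latex
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2
  + \operatorname{Tr}\!\left( \Sigma_r + \Sigma_g - 2 \left( \Sigma_r \Sigma_g \right)^{1/2} \right)
```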

Appendix: the models and their corresponding datasets:

Pix2Pix: edges2shoes-r, map2sat, cityscapes

CycleGAN: horse2zebra

GauGAN: cityscapes
