This is the final part of the series. Having gone through the paper and the code, let's now look at how to reproduce the work:
If you want to learn more about this work, see the original paper (arxiv.org) and the earlier posts in this series on zhuanlan.zhihu.com: 科技猛兽: GAN Compression 1: 原理分析 (analysis of the method) and 科技猛兽: GAN Compression 2: 代码解读 (code walkthrough).
Official code: https://github.com/mit-han-lab/gan-compression/blob/master/README.md
Requirements:
Linux
Python 3
CPU or NVIDIA GPU + CUDA CuDNN
Let's get started!
First, clone the code to your local machine:
git clone [email protected]:mit-han-lab/gan-compression.git
cd gan-compression
1. Install PyTorch 1.4 and the other dependencies (torchvision, etc.):
If you look closely, you'll notice the repo contains a requirements.txt that reads:
absl-py==0.9.0
blessings==1.7
certifi==2019.11.28
dominate==2.4.0
grpcio==1.16.1
Markdown==3.1.1
numpy==1.18.1
nvidia-ml-py3==7.352.0
olefile==0.46
opencv-python==4.2.0.32
Pillow==7.0.0
protobuf==3.11.3
psutil==5.7.0
scipy==1.4.1
six==1.14.0
tensorboard==2.0.0
tensorboardX==2.0
torch==1.4.0
torchvision==0.5.0
torchprofile==0.0.1
tqdm==4.42.1
Werkzeug==1.0.0
wget==3.2
No, your eyes aren't deceiving you: all of these need to be installed first.
If you're a pip user, just run:
pip install -r requirements.txt
If you're a conda user, open the scripts folder; it contains an executable script conda_deps.sh, so you just run:
scripts/conda_deps.sh
So what exactly is conda_deps.sh?
Opening it, you'll find:
#!/usr/bin/env bash
set -ex
conda install pytorch==1.4.0 torchvision==0.5.0 -c pytorch
conda install tqdm scipy tensorboard
conda install -c conda-forge tensorboardx
pip install opencv-python dominate wget
The same set of libraries again.
I recommend creating a fresh conda virtual environment dedicated to this experiment:
conda create -n your_env_name python=X.X (e.g., 3.6)
To list your virtual environments:
conda info -e
To see which packages are installed:
conda list
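After the environment is set up, a quick sanity check (a generic snippet, not part of this repo) confirms that the pinned PyTorch version is active and the GPU is visible:
import torch
import torchvision

print(torch.__version__, torchvision.__version__)  # expect 1.4.0 and 0.5.0, as pinned above
print(torch.cuda.is_available())                    # True if CUDA/cuDNN are set up; CPU-only also works, just slower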
Next, install torchprofile: https://github.com/zhijian-liu/torchprofile
pip install --upgrade git+https://github.com/mit-han-lab/torchprofile.git
This is the library for counting MACs (multiply-accumulate operations).
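To get a feel for what torchprofile does, here is a toy example on a standard torchvision model (an illustration of the library's profile_macs API, not the repo's profiling code):
import torch
from torchvision.models import resnet18
from torchprofile import profile_macs

model = resnet18().eval()
inputs = torch.randn(1, 3, 224, 224)   # a single 224x224 RGB image
macs = profile_macs(model, inputs)     # count multiply-accumulate operations
print(f'{macs / 1e9:.2f} GMACs')       # roughly 1.8 GMACs for ResNet-18
The --need_profile flag passed to the test scripts below reports this kind of MACs statistic for the generator under test.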
2. With all the dependencies installed, it's time to prepare the datasets and try the authors' pre-trained models:
2.1 CycleGAN
Download the CycleGAN dataset (e.g., horse2zebra):
bash datasets/download_cyclegan_dataset.sh horse2zebra
Get the statistics of the ground-truth images for computing FID (the CycleGAN datasets use FID as the evaluation metric):
bash datasets/download_real_stat.sh horse2zebra A
bash datasets/download_real_stat.sh horse2zebra B
Before training, we can try the pre-trained models provided by the authors to see how they perform.
First, download the pre-trained models (the original model and the compressed one):
python scripts/download_model.py --model cycle_gan --task horse2zebra --stage full
python scripts/download_model.py --model cycle_gan --task horse2zebra --stage compressed
Test the uncompressed full model:
bash scripts/cycle_gan/horse2zebra/test_full.sh
Test the compressed model:
bash scripts/cycle_gan/horse2zebra/test_compressed.sh
Let's look at what test_compressed.sh contains:
#!/usr/bin/env bash
python test.py --dataroot database/horse2zebra/valA \
--dataset_mode single \
--results_dir results-pretrained/cycle_gan/horse2zebra/compressed \
--config_str 16_16_32_16_32_32_16_16 \
--restore_G_path pretrained/cycle_gan/horse2zebra/compressed/latest_net_G.pth \
--need_profile \
--real_stat_path real_stat/horse2zebra_B.npz
Finally, measure the latency:
bash scripts/cycle_gan/horse2zebra/latency_full.sh
bash scripts/cycle_gan/horse2zebra/latency_compressed.sh
2.2 Pix2Pix
Download the pix2pix dataset (e.g., edges2shoes).
bash datasets/download_pix2pix_dataset.sh edges2shoes-r
Get the statistics of the ground-truth images for computing FID (the pix2pix edges2shoes-r dataset uses FID as the evaluation metric):
bash datasets/download_real_stat.sh edges2shoes-r B
Before training, we can try the pre-trained models provided by the authors to see how they perform.
First, download the pre-trained models (the original model and the compressed one):
python scripts/download_model.py --model pix2pix --task edges2shoes-r --stage full
python scripts/download_model.py --model pix2pix --task edges2shoes-r --stage compressed
Test the uncompressed full model:
bash scripts/pix2pix/edges2shoes-r/test_full.sh
Test the compressed model:
bash scripts/pix2pix/edges2shoes-r/test_compressed.sh
Let's look at what test_compressed.sh contains:
#!/usr/bin/env bash
python test.py --dataroot database/edges2shoes-r \
--results_dir results-pretrained/pix2pix/edges2shoes-r/compressed \
--restore_G_path pretrained/pix2pix/edges2shoes-r/compressed/latest_net_G.pth \
--config_str 32_32_48_32_48_48_16_16 \
--real_stat_path real_stat/edges2shoes-r_B.npz \
--need_profile --num_test 500
Finally, measure the latency:
bash scripts/pix2pix/edges2shoes-r/latency_full.sh
bash scripts/pix2pix/edges2shoes-r/latency_compressed.sh
2.3 GauGAN
Prepare the Cityscapes dataset (it must be downloaded manually; see the Cityscapes Dataset note below), then get the statistics of the ground-truth images for computing FID:
bash datasets/download_real_stat.sh cityscapes A
Before training, we can try the pre-trained models provided by the authors to see how they perform.
First, download the pre-trained models (the original model and the compressed one):
python scripts/download_model.py --model gaugan --task cityscapes --stage full
python scripts/download_model.py --model gaugan --task cityscapes --stage compressed
Test the uncompressed full model:
bash scripts/gaugan/cityscapes/test_full.sh
Test the compressed model:
bash scripts/gaugan/cityscapes/test_compressed.sh
Let's look at what test_compressed.sh contains:
#!/usr/bin/env bash
python test.py --dataroot database/cityscapes-origin \
--config_str 32_32_48_32_32_32_32_24 \
--model spade --dataset_mode cityscapes \
--results_dir results-pretrained/gaugan/cityscapes/compressed \
--ngf 48 --netG sub_mobile_spade \
--restore_G_path pretrained/gaugan/cityscapes/compressed/latest_net_G.pth \
--real_stat_path real_stat/cityscapes_A.npz \
--drn_path drn-d-105_ms_cityscapes.pth \
--cityscapes_path database/cityscapes-origin \
--table_path datasets/table.txt --need_profile
Finally, measure the latency:
bash scripts/gaugan/cityscapes/latency_full.sh
bash scripts/gaugan/cityscapes/latency_compressed.sh
Cityscapes Dataset
This dataset must be downloaded manually from the Cityscapes website (Semantic Understanding of Urban Street Scenes): cityscapes-dataset.com. Download gtFine_trainvaltest.zip and leftImg8bit_trainvaltest.zip and unzip them into the same folder, database/cityscapes-origin.
Then run the following commands:
python datasets/get_trainIds.py database/cityscapes-origin/gtFine/
python datasets/prepare_cityscapes_dataset.py \
--gtFine_dir database/cityscapes-origin/gtFine \
--leftImg8bit_dir database/cityscapes-origin/leftImg8bit \
--output_dir database/cityscapes \
--table_path datasets/table.txt
These commands take the raw annotations and images you unzipped into database/cityscapes-origin and write the prepared dataset to database/cityscapes.
The DRN segmentation model used for the Cityscapes evaluation (the --drn_path argument above) must also be downloaded manually from: http://go.yf.io/drn-cityscapes-models
The file is the DRN model drn-d-105_ms_cityscapes.pth.
3. Training:
According to the paper, the training pipeline consists of the following stages: train a MobileNet-style teacher model, pre-distillation (distill and prune), "once-for-all" network training, selecting the best sub-network, and fine-tuning the best model.
3.1 Pix2pix Model Compression:
Let's use edges2shoes-r as the example:
3.1.1 Train a MobileNet Teacher Model:
Step one: train a MobileNet-style teacher model from scratch:
bash scripts/pix2pix/edges2shoes-r/train_mobile.sh
train_mobile.sh contains:
#!/usr/bin/env bash
python train.py --dataroot database/edges2shoes-r \
--model pix2pix \
--log_dir logs/pix2pix/edges2shoes-r/mobile \
--real_stat_path real_stat/edges2shoes-r_B.npz
The following two steps are optional.
The authors provide a pre-trained model for each dataset:
python scripts/download_model.py --model pix2pix --task edges2shoes-r --stage mobile
Once it's downloaded, you can test it:
bash scripts/pix2pix/edges2shoes-r/test_mobile.sh
test_mobile.sh contains:
#!/usr/bin/env bash
python test.py --dataroot database/edges2shoes-r \
--results_dir results-pretrained/pix2pix/edges2shoes-r/mobile \
--ngf 64 --netG mobile_resnet_9blocks \
--restore_G_path pretrained/pix2pix/edges2shoes-r/mobile/latest_net_G.pth \
--real_stat_path real_stat/edges2shoes-r_B.npz \
--need_profile --num_test 500
3.1.2 Distill and prune the original MobileNet-style model to make the model compact
bash scripts/pix2pix/edges2shoes-r/distill.sh
distill.sh contains:
#!/usr/bin/env bash
python distill.py --dataroot database/edges2shoes-r \
--distiller resnet \
--log_dir logs/pix2pix/edges2shoes-r/distill \
--batch_size 4 \
--restore_teacher_G_path logs/pix2pix/edges2shoes-r/mobile/checkpoints/latest_net_G.pth \
--restore_pretrained_G_path logs/pix2pix/edges2shoes-r/mobile/checkpoints/latest_net_G.pth \
--restore_D_path logs/pix2pix/edges2shoes-r/mobile/checkpoints/latest_net_D.pth \
--real_stat_path real_stat/edges2shoes-r_B.npz
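Conceptually, distill.py transfers knowledge from the teacher's intermediate activations to the narrower student; because the channel counts differ, a learnable 1x1 convolution maps the student features to the teacher's channel dimension before matching them. A toy sketch of that loss (the shapes and layer choice are made up for illustration; this is not the repository's implementation):
import torch
import torch.nn as nn
import torch.nn.functional as F

# e.g. one intermediate activation of the teacher (ngf=64) and of the student (ngf=48)
teacher_feat = torch.randn(4, 256, 64, 64)
student_feat = torch.randn(4, 192, 64, 64)

# learnable 1x1 conv that lifts the student channels to the teacher's, trained jointly
mapping = nn.Conv2d(192, 256, kernel_size=1)

distill_loss = F.mse_loss(mapping(student_feat), teacher_feat)
# In the real training this term is combined with the GAN and reconstruction losses
# (cf. --lambda_distill / --lambda_recon in the CycleGAN scripts later on).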
The authors also provide a distilled student model; download it with:
python scripts/download_model.py --model pix2pix --task edges2shoes-r --stage distill
Test it:
bash scripts/pix2pix/edges2shoes-r/test_distill.sh
test_distill.sh contains:
#!/usr/bin/env bash
python test.py --dataroot database/edges2shoes-r \
--results_dir results-pretrained/pix2pix/edges2shoes-r/distill \
--ngf 48 --netG mobile_resnet_9blocks \
--restore_G_path pretrained/pix2pix/edges2shoes-r/distill/latest_net_G.pth \
--real_stat_path real_stat/edges2shoes-r_B.npz \
--need_profile --num_test 500
3.1.3 "Once-for-all" Network Training
Train a "once-for-all" network from the distilled student model:
bash scripts/pix2pix/edges2shoes-r/train_supernet.sh
train_supernet.sh contains:
#!/usr/bin/env bash
python train_supernet.py --dataroot database/edges2shoes-r \
--supernet resnet \
--log_dir logs/pix2pix/edges2shoes-r/supernet \
--batch_size 4 \
--restore_teacher_G_path logs/pix2pix/edges2shoes-r/mobile/checkpoints/latest_net_G.pth \
--restore_student_G_path logs/pix2pix/edges2shoes-r/distill/checkpoints/latest_net_G.pth \
--restore_D_path logs/pix2pix/edges2shoes-r/distill/checkpoints/latest_net_D.pth \
--real_stat_path real_stat/edges2shoes-r_B.npz \
--nepochs 10 --nepochs_decay 30 \
--teacher_ngf 64 --student_ngf 48 \
--config_set channels-48
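The idea behind the "once-for-all" network (following the paper) is that all candidate sub-networks share one set of weights: a narrower sub-network simply uses a slice of each convolution's weight tensor, and a random channel configuration is sampled at every training step. A toy illustration of this weight sharing (not the repo's implementation; the candidate widths below are only an example):
import random
import torch
import torch.nn as nn
import torch.nn.functional as F

class SlimmableConv2d(nn.Conv2d):
    """Conv layer that can run with any number of output channels up to its
    maximum by slicing the shared weight tensor."""
    def forward(self, x, out_channels):
        weight = self.weight[:out_channels, :x.size(1)]
        bias = self.bias[:out_channels] if self.bias is not None else None
        return F.conv2d(x, weight, bias, self.stride, self.padding)

conv = SlimmableConv2d(48, 48, kernel_size=3, padding=1)
x = torch.randn(1, 48, 64, 64)

# each step samples a width for every layer (cf. --config_set channels-48);
# the sampled widths of all layers form the config_str used later during search
width = random.choice([16, 24, 32, 48])
y = conv(x, width)
print(y.shape)  # the channel dimension equals the sampled width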
The authors provide a trained "once-for-all" network; download it with:
python scripts/download_model.py --model pix2pix --task edges2shoes-r --stage supernet
3.1.4 Select the Best Model
Evaluate all the candidate sub-networks given a specific configuration (e.g., MAC, FID)
bash scripts/pix2pix/edges2shoes-r/search.sh
search.sh contains:
#!/usr/bin/env bash
python search.py --dataroot database/edges2shoes-r \
--restore_G_path logs/pix2pix/edges2shoes-r/supernet/checkpoints/latest_net_G.pth \
--output_path logs/pix2pix/edges2shoes-r/supernet/result.pkl \
--ngf 48 --batch_size 32 \
--config_set channels-48 \
--real_stat_path real_stat/edges2shoes-r_B.npz
The results are stored in pickle format: a list in which each element is a dict of the form:
{'config_str': $config_str, 'macs': $macs, 'fid'/'mIoU': $fid_or_mIoU}
For example:
{'config_str': '32_32_48_32_48_48_16_16', 'macs': 4993843200, 'fid': 25.224261423597483}
'config_str' is a channel configuration that identifies a specific sub-network within the "once-for-all" network.
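To make this format concrete, here is a minimal sketch (not select_arch.py itself) that loads result.pkl, keeps only the sub-networks under a MACs budget, and ranks them by FID:
import pickle

with open('logs/pix2pix/edges2shoes-r/supernet/result.pkl', 'rb') as f:
    results = pickle.load(f)   # list of dicts: {'config_str', 'macs', 'fid' or 'mIoU'}

# keep sub-networks under a 5.7e9 MACs budget and rank by FID (lower is better)
candidates = [r for r in results if r['macs'] <= 5.7e9]
candidates.sort(key=lambda r: r['fid'])

best = candidates[0]
print(best['config_str'], best['macs'], best['fid'])

# each number in config_str is the channel width of one layer of that sub-network
channels = [int(c) for c in best['config_str'].split('_')]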
Select the best architecture (the thresholds below constrain the MACs to at most 5.7e9, roughly a 10x compression, and apply an FID threshold of 30):
python select_arch.py --macs 5.7e9 --fid 30 \
    --pkl_path logs/pix2pix/edges2shoes-r/supernet/result.pkl
3.1.5 Fine-tuning the Best Model
bash scripts/pix2pix/edges2shoes-r/finetune.sh 32_32_48_32_48_48_16_16
This step also uses train_supernet.py; finetune.sh contains:
#!/usr/bin/env bash
python train_supernet.py --dataroot database/edges2shoes-r \
--supernet resnet \
--log_dir logs/pix2pix/edges2shoes-r/finetune \
--batch_size 4 \
--restore_teacher_G_path logs/pix2pix/edges2shoes-r/mobile/checkpoints/latest_net_G.pth \
--restore_student_G_path logs/pix2pix/edges2shoes-r/supernet/checkpoints/latest_net_G.pth \
--restore_D_path logs/pix2pix/edges2shoes-r/supernet/checkpoints/latest_net_D.pth \
--real_stat_path real_stat/edges2shoes-r_B.npz \
--nepochs 5 --nepochs_decay 15 \
--teacher_ngf 64 --student_ngf 48 \
--config_str $1
3.1.6 Export the model:
bash scripts/pix2pix/edges2shoes-r/export.sh 32_32_48_32_48_48_16_16
3.2 CycleGAN Model Compression:
Let's use horse2zebra as the example:
3.2.1 Train a MobileNet Teacher Model:
Step one: train a MobileNet-style teacher model from scratch:
bash scripts/cycle_gan/horse2zebra/train_mobile.sh
train_mobile.sh contains:
#!/usr/bin/env bash
python train.py --dataroot database/horse2zebra \
--model cycle_gan \
--log_dir logs/cycle_gan/horse2zebra/mobile \
--real_stat_A_path real_stat/horse2zebra_A.npz \
--real_stat_B_path real_stat/horse2zebra_B.npz
The following two steps are optional.
The authors provide a pre-trained model for each dataset:
python scripts/download_model.py --model cycle_gan --task horse2zebra --stage mobile
Once it's downloaded, you can test it:
bash scripts/cycle_gan/horse2zebra/test_mobile.sh
test_mobile.sh contains:
#!/usr/bin/env bash
python test.py --dataroot database/horse2zebra/valA \
--dataset_mode single \
--results_dir results-pretrained/cycle_gan/horse2zebra/mobile \
--ngf 64 --netG mobile_resnet_9blocks \
--restore_G_path pretrained/cycle_gan/horse2zebra/mobile/latest_net_G.pth \
--need_profile \
--real_stat_path real_stat/horse2zebra_B.npz
3.2.2 Distill and prune the original MobileNet-style model to make the model compact
bash scripts/cycle_gan/horse2zebra/distill.sh
distill.sh contains:
#!/usr/bin/env bash
python distill.py --dataroot database/horse2zebra \
--dataset_mode unaligned \
--distiller resnet \
--log_dir logs/cycle_gan/horse2zebra/distill \
--gan_mode lsgan \
--student_ngf 32 --ndf 64 \
--restore_teacher_G_path logs/cycle_gan/horse2zebra/mobile/checkpoints/latest_net_G_A.pth \
--restore_pretrained_G_path logs/cycle_gan/horse2zebra/mobile/checkpoints/latest_net_G_A.pth \
--restore_D_path logs/cycle_gan/horse2zebra/mobile/checkpoints/latest_net_D_A.pth \
--real_stat_path real_stat/horse2zebra_B.npz \
--lambda_recon 10 \
--lambda_distill 0.01 \
--nepochs 100 --nepochs_decay 100 \
--save_epoch_freq 20
The authors also provide a distilled student model; download it with:
python scripts/download_model.py --model cycle_gan --task horse2zebra --stage distill
Test it:
bash scripts/cycle_gan/horse2zebra/test_distill.sh
test_distill.sh contains:
#!/usr/bin/env bash
python test.py --dataroot database/horse2zebra/valA \
--dataset_mode single \
--results_dir results-pretrained/cycle_gan/horse2zebra/distill \
--ngf 32 --netG mobile_resnet_9blocks \
--restore_G_path pretrained/cycle_gan/horse2zebra/distill/latest_net_G.pth \
--need_profile \
--real_stat_path real_stat/horse2zebra_B.npz
3.2.3 "Once-for-all" Network Training
Train a "once-for-all" network from the distilled student model:
bash scripts/cycle_gan/horse2zebra/train_supernet.sh
train_supernet.sh contains:
#!/usr/bin/env bash
python train_supernet.py --dataroot database/horse2zebra \
--dataset_mode unaligned \
--supernet resnet \
--log_dir logs/cycle_gan/horse2zebra/supernet \
--gan_mode lsgan \
--student_ngf 32 --ndf 64 \
--restore_teacher_G_path logs/cycle_gan/horse2zebra/mobile/checkpoints/latest_net_G_A.pth \
--restore_student_G_path logs/cycle_gan/horse2zebra/distill/checkpoints/latest_net_G.pth \
--restore_D_path logs/cycle_gan/horse2zebra/distill/checkpoints/latest_net_D.pth \
--real_stat_path real_stat/horse2zebra_B.npz \
--lambda_recon 10 --lambda_distill 0.01 \
--nepochs 200 --nepochs_decay 200 \
--save_epoch_freq 20 \
--config_set channels-32
The authors provide a trained "once-for-all" network; download it with:
python scripts/download_model.py --model cycle_gan --task horse2zebra --stage supernet
3.2.4 Select the Best Model
Evaluate all the candidate sub-networks given a specific configuration (e.g., MAC, FID)
bash scripts/cycle_gan/horse2zebra/search.sh
search.sh contains:
#!/usr/bin/env bash
python search.py --dataroot database/horse2zebra/valA \
--dataset_mode single \
--restore_G_path logs/cycle_gan/horse2zebra/supernet/checkpoints/latest_net_G.pth \
--output_path logs/cycle_gan/horse2zebra/supernet/result.pkl \
--ngf 32 --batch_size 32 \
--config_set channels-32 \
--real_stat_path real_stat/horse2zebra_B.npz
Select the best architecture: as in the pix2pix case, run select_arch.py on the generated logs/cycle_gan/horse2zebra/supernet/result.pkl with your MACs/FID constraints to obtain a config_str such as 16_16_32_16_32_32_16_16.
3.2.5 Fine-tuning the Best Model
bash scripts/cycle_gan/horse2zebra/finetune.sh 16_16_32_16_32_32_16_16
This step can actually be skipped.
3.2.6 Export the model:
bash scripts/cycle_gan/horse2zebra/export.sh 16_16_32_16_32_32_16_16
4. Evaluation:
To compute FID, you need the statistics of the ground-truth images, which can be obtained with:
python get_real_stat.py \
--dataroot database/edges2shoes-r \
--output_path real_stat/edges2shoes-r_B.npz \
--direction AtoB
For paired image-to-image translation (pix2pix and GauGAN), the method above applies directly.
For unpaired image-to-image translation (CycleGAN), the FID is computed between the generated test images and the real training + test images.
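For reference, FID compares the mean and covariance of InceptionV3 features of real and generated images, and the real_stat/*.npz files cache these statistics for the real images. A minimal sketch of the final formula (the loading of the statistics and the key names are assumptions for illustration, not the repo's code):
import numpy as np
from scipy import linalg

def frechet_distance(mu1, sigma1, mu2, sigma2):
    # FID between two Gaussians N(mu1, sigma1) and N(mu2, sigma2)
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if np.iscomplexobj(covmean):       # numerical noise can produce tiny imaginary parts
        covmean = covmean.real
    return diff.dot(diff) + np.trace(sigma1 + sigma2 - 2.0 * covmean)

# mu1/sigma1 would come from e.g. real_stat/edges2shoes-r_B.npz, and mu2/sigma2
# from the InceptionV3 activations of the generated images.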
Appendix: each model and its corresponding datasets:
Pix2Pix: edges2shoes-r, map2sat, cityscapes
CycleGAN: horse2zebra
GauGAN: cityscapes