因为目前的课题了解到了Cycle GAN,所以最近去学习了相关的一些知识。
目前网上绝大多数的代码都是https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix,所以下面的复现过程也都是基于此代码完成的。
Cycle GAN(Cycle Generative Adversarial Network)---即循环生成对抗网络,基于无监督学习,不需要配对的数据,作用是对两个域(不同风格)之间的图像进行相互转换。
一般直接从GitHub上就可直接下载https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
一般来说我配置环境喜欢用Anaconda的Anaconda Powershell Prompt,感觉比较容易。TIPs: 我是小白所以都是在WINDOWS下进行的。
首先安装部分可以参考网上的一些教程,大致就是下载Anaconda,然后开始菜单里面找到Anaconda Powershell Prompt
用conda 创建Python虚拟环境(在conda prompt环境下完成)
conda create -n environment_name python=X.X
注:该命令只适用于Windows环境;“environment_name”是要创建的环境名;“python=X.X”是选择的Python版本)
激活虚拟环境(在conda prompt环境下完成)
conda activate your_env_name
然后按照requirements.txt和代码里导入的模块进行安装包。
比如Cycle GAN下载后的requirements.txt文件:
torch>=1.4.0
torchvision>=0.5.0
dominate>=2.4.0
visdom>=0.1.8.8
wandb
先安装Pytorch,这个容易出错,这边建议使用这个代码
# CUDA 10.0
pip install torch===1.4.0 torchvision===0.5.0 -f https://download.pytorch.org/whl/torch_stable.html
深度学习配环境有一些坑,比如你要是用GPU的话,需要提前安装显卡驱动支持的CUDA,我一般使用的是CUDA10.0,因为网上搜CUDA对于版本可以了解一些情况,这边不多加阐述,这边使用的是torch=1.4.0+torchvision=0.5.0,之后安装其他环境到时候可以直接换掉,后面的-f是使用的源地址,这边的换源方法网上也有很多,可以添加什么阿里、豆瓣、清华之类的镜像)
后面的其他包可以建一个XX.txt文件,比如需要
dominate>=2.4.0
visdom>=0.1.8.8
wandb
然后
pip install -r XX.txt
注意txt的地址,cd 改到txt的目录下就行。
基本上这些完成后就可以了
我一般会测试一下torch安装的是否成功,可以从conda prompt里测试
python
import torch
torch.cuda.is_available()
如果返回的是True就算成功了,False的话就需要重新看看哪步出问题了。
Cycle GAN的配置算是容易的,首先看一下结构
checkpoints
data
datasets
docs
imgs
models
options
results
scripts
util
test.py
train.py
需要创建checkpoints文件夹,数据存放在datasets里
datasets
| ├── # i.e. maps(将卫星图与导航图相互转换的数据集)
| | ├── trainA # 包含训练所需的A域图像 (i.e. 卫星图)
| | ├── trainB # 包含训练所需的B域图像 (i.e. 导航图)
| | ├── testA # 包含测试所需的A域图像 (i.e. 卫星图)
| | ├── testB # 包含测试所需的B域图像 (i.e. 导航图)
按照这样的结构存放好数据集,即可。网上有许多公共数据方便使用:http://efrosgans.eecs.berkeley.edu/cyclegan/datasets/可以从这个网站下载,下载以后按上面的方式存放即可。
也可以使用自己的数据集,方式同理。
在实际运行代码的时候要注意的是,导入的包里面含有visdom,这个包有点类似tensorflow的tensorbord,是一个可视化界面,所以需要在运行前启动一下,可以仍在conda prompt里激活环境
conda activate your_env_name
启动visdom
python -m visdom.server
第一次启动会有点慢,显示It's Alive!表明成功
It's Alive!
INFO:root:Application Started
You can navigate to http://localhost:8097
http://localhost:8097就是可视化界面的地址,打开即可看到训练过程中的变化。
默认的运行代码为:
python train.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan
这是使用datasets里的maps数据集,在checkpoints文件夹内建立maps_cyclegan文件用于存放训练过程记录的模型,里面包含loss_log.txt、train_opt.txt和 .pth模型,以及wed文件夹,wed文件夹里存放这训练过程中的实例比对,这部分在网页上也可以看到。
至于此,是按照代码的默认值进行训练的,完成训练后在刚生成的 .pth模型终会有latest_net_G_A.pth、latest_net_G_B.pth、latest_net_D_A.pth、latest_net_D_B.pth,这些是训练过程中的最新模型,其中G_A.pth是用来将A域图像转换成B域图像的模型参数,G_B.pth是将B域图像转换成A域图像的模型参数,主要也是使用这两个。
在checkpoints文件夹内创建文件夹例如:maps_pretrained, 将刚刚的latest_net_G_A.pth移至文件夹并改名latest_net_G.pth,终端运行代码
python test.py --dataroot datasets/maps/testA --name maps_pretrained --model test --no_dropout
即可完成将trainA转换成trainB的过程,反之使用latest_net_G_B.pth。
epoch=20时的训练效果
训练过程
Cycle GAN的复现过程大致就是这样,后续更多内容可以查看options文件夹内的设置进行更改。