目录
克隆项目代码
配置环境
下载测试数据集和预训练模型
测试及训练
使用自己的数据集训练及测试
修改cyclegan生成图像的大小
如果太慢,也可以选择自己下载下来项目上传至服务器
git clone https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
cd pytorch-CycleGAN-and-pix2pix
配置基础环境,在此我们使用python3.8,pytorch的版本直接使用了目前官网最新的版本。
conda create -n cyclegan python=3.8
conda activate cyclegan
conda install pytorch torchvision torchaudio pytorch-cuda=11.7 -c pytorch -c nvidia
配置项目需要的环境,作者已经将需要的包名放入了requirements.txt文件中,只需使用pip遍历安装即可。
pip install -r requirements.txt
作者给我们写好了下载代码。在此使用apple2orange做测试。
`apple2orange`: 996 apple images and 1020 orange images downloaded from [ImageNet](http://www.image-net.org/) using the keywords `apple` and `navel orange`.
bash ./scripts/download_cyclegan_model.sh apple2orange #下载预训练模型
bash ./datasets/download_cyclegan_dataset.sh apple2orange # 下载数据集
如果服务器下载慢的话,可以自己下载,按照bash文件中的操作,自己操作即可。
模型下载地址:http://efrosgans.eecs.berkeley.edu/cyclegan/pretrained_models/apple2orange.pth
数据集下载地址:http://efrosgans.eecs.berkeley.edu/cyclegan/datasets/apple2orange.zip
如果自行下载,则把下载好的模型文件,放入 ./checkpoints/apple2orange_pretrained 并命名为latest_net_G.pth
mkdir ./checkpoints/apple2orange_pretrained
把下载的数据集zip文件传入项目目录,之后做如下操作。
mkdir ./datasets/apple2orange
unzip unzip apple2orange.zip -d ./datasets/
rm apple2orange.zip
接下来就可以测试了,顺利的话一次就可以成功了。
python test.py --dataroot datasets/apple2orange/testA --name apple2orange_pretrained --model test --no_dropout
成功后看到的结果是这样子的
---------- Networks initialized -------------
[Network G] Total number of parameters : 11.378 M
-----------------------------------------------
creating web directory ./results/apple2orange_cyclegan/test_latest
processing (0000)-th image... ['datasets/apple2orange/testA/n07740461_10011.jpg']
processing (0005)-th image... ['datasets/apple2orange/testA/n07740461_10371.jpg']
processing (0010)-th image... ['datasets/apple2orange/testA/n07740461_10940.jpg']
processing (0015)-th image... ['datasets/apple2orange/testA/n07740461_111.jpg']
processing (0020)-th image... ['datasets/apple2orange/testA/n07740461_1131.jpg']
processing (0025)-th image... ['datasets/apple2orange/testA/n07740461_11651.jpg']
processing (0030)-th image... ['datasets/apple2orange/testA/n07740461_11891.jpg']
processing (0035)-th image... ['datasets/apple2orange/testA/n07740461_12071.jpg']
processing (0040)-th image... ['datasets/apple2orange/testA/n07740461_12360.jpg']
processing (0045)-th image... ['datasets/apple2orange/testA/n07740461_12751.jpg']
(base) root@interactive93561:/opt/data/private/shaohua/pytorch-CycleGAN-and-pix2pix#
结果图片会放在results目录下:./results/apple2orange_cyclegan/test_latest/images。大家就可以看到风格迁移后的图像了。
那如果想自己去训练一下这个网络,也很简单,设定好数据集路径以及名称,即可开始训练。
python train.py --dataroot ./datasets/apple2orange --name apple2orange_cyclegan --model cycle_gan
如果直接这么训练可能会报错,但不会停止。因为代码中做了网页的实时可视化,大家在训练执勤啊可以先运行如下代码,第一次进入的时候会慢一点。之后大家就可以在http://localhost:8097中看到训练中间的一些结果了。
python -m visdom.server
如果需要使用自己的数据进行训练,我们只需要在datasets中创建自己任务的文件夹,按照它之前的命名方式进行命名就行了。
目录格式为
--datasets
----apple2orange
------testA
------testB
------trainA
------trainB
其中trainA和trainB文件夹中放入希望互相转换的两个域 的数据即可。如果只是训练的话testA和testB两个文件夹可以不创建。
如果想要使训练的图像大小更大一点(官方默认的话,会将图像裁剪到256*256),如果使用正方形的图片,可以在./options/base_options.py 找到设置 --load_size 以及 --crop_size,进行修改,即可实现对尺寸的修改。但因为我需要的图像是长方形的,又懒得去改配置文件,于是直接定位一下它是在哪里修改图像尺寸的,最后找到了在./data/base_dataset.py文件中第81行def get_transform 方法中。因此,我直接在此处修改了该方法,以达到自己的目的。修改第85行的代码
if 'resize' in opt.preprocess:
osize = [opt.load_size, opt.load_size]
中osize = [opt.load_size, opt.load_size]中的opt.load_size直接替换成了自己需要的大小。之后把91行到98行进行了注释,让其不再进行裁剪。大家可以根据自己的需求进行修改。
为了方便,修改存储的文件,仅存fake img,文件名和原始文件名相同。修改了./util/visualizer.py文件中的save_images方法。
def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256, use_wandb=False):
"""Save images to the disk.
Parameters:
webpage (the HTML class) -- the HTML webpage class that stores these imaegs (see html.py for more details)
visuals (OrderedDict) -- an ordered dictionary that stores (name, images (either tensor or numpy) ) pairs
image_path (str) -- the string is used to create image paths
aspect_ratio (float) -- the aspect ratio of saved images
width (int) -- the images will be resized to width x width
This function will save images stored in 'visuals' to the HTML file specified by 'webpage'.
"""
image_dir = webpage.get_image_dir()
short_path = ntpath.basename(image_path[0])
# 23.4.16修改代码,仅保存fake的图像,用于数据增强。
# name = os.path.splitext(short_path)[0]
webpage.add_header(short_path)
ims, txts, links = [], [], []
ims_dict = {}
for label, im_data in visuals.items():
if label == 'fake':
im = util.tensor2im(im_data)
# image_name = '%s_%s.png' % (name, label)
image_name = short_path
save_path = os.path.join(image_dir, image_name)
util.save_image(im, save_path, aspect_ratio=aspect_ratio)
ims.append(image_name)
txts.append(label)
links.append(image_name)
if use_wandb:
ims_dict[label] = wandb.Image(im)
webpage.add_images(ims, txts, links, width=width)
if use_wandb:
wandb.log(ims_dict)