使用Cycle GAN训练自己的数据

目录

克隆项目代码

 配置环境

下载测试数据集和预训练模型

测试及训练

使用自己的数据集训练及测试

修改cyclegan生成图像的大小


克隆项目代码

如果太慢,也可以选择自己下载下来项目上传至服务器

git clone https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
cd pytorch-CycleGAN-and-pix2pix

 配置环境

配置基础环境,在此我们使用python3.8,pytorch的版本直接使用了目前官网最新的版本。

conda create -n cyclegan python=3.8
conda activate cyclegan
conda install pytorch torchvision torchaudio pytorch-cuda=11.7 -c pytorch -c nvidia

配置项目需要的环境,作者已经将需要的包名放入了requirements.txt文件中,只需使用pip遍历安装即可。

pip install -r requirements.txt

下载测试数据集和预训练模型

作者给我们写好了下载代码。在此使用apple2orange做测试。

`apple2orange`: 996 apple images and 1020 orange images downloaded from [ImageNet](http://www.image-net.org/) using the keywords `apple` and `navel orange`.

bash ./scripts/download_cyclegan_model.sh apple2orange #下载预训练模型
bash ./datasets/download_cyclegan_dataset.sh apple2orange # 下载数据集

如果服务器下载慢的话,可以自己下载,按照bash文件中的操作,自己操作即可。

模型下载地址:http://efrosgans.eecs.berkeley.edu/cyclegan/pretrained_models/apple2orange.pth

数据集下载地址:http://efrosgans.eecs.berkeley.edu/cyclegan/datasets/apple2orange.zip

如果自行下载,则把下载好的模型文件,放入 ./checkpoints/apple2orange_pretrained 并命名为latest_net_G.pth

mkdir ./checkpoints/apple2orange_pretrained

把下载的数据集zip文件传入项目目录,之后做如下操作。 

mkdir ./datasets/apple2orange
unzip unzip apple2orange.zip -d ./datasets/
rm apple2orange.zip

测试及训练

接下来就可以测试了,顺利的话一次就可以成功了。

python test.py --dataroot datasets/apple2orange/testA --name apple2orange_pretrained --model test --no_dropout

成功后看到的结果是这样子的

---------- Networks initialized -------------
[Network G] Total number of parameters : 11.378 M
-----------------------------------------------
creating web directory ./results/apple2orange_cyclegan/test_latest
processing (0000)-th image... ['datasets/apple2orange/testA/n07740461_10011.jpg']
processing (0005)-th image... ['datasets/apple2orange/testA/n07740461_10371.jpg']
processing (0010)-th image... ['datasets/apple2orange/testA/n07740461_10940.jpg']
processing (0015)-th image... ['datasets/apple2orange/testA/n07740461_111.jpg']
processing (0020)-th image... ['datasets/apple2orange/testA/n07740461_1131.jpg']
processing (0025)-th image... ['datasets/apple2orange/testA/n07740461_11651.jpg']
processing (0030)-th image... ['datasets/apple2orange/testA/n07740461_11891.jpg']
processing (0035)-th image... ['datasets/apple2orange/testA/n07740461_12071.jpg']
processing (0040)-th image... ['datasets/apple2orange/testA/n07740461_12360.jpg']
processing (0045)-th image... ['datasets/apple2orange/testA/n07740461_12751.jpg']
(base) root@interactive93561:/opt/data/private/shaohua/pytorch-CycleGAN-and-pix2pix# 

结果图片会放在results目录下:./results/apple2orange_cyclegan/test_latest/images。大家就可以看到风格迁移后的图像了。

那如果想自己去训练一下这个网络,也很简单,设定好数据集路径以及名称,即可开始训练。

python train.py --dataroot ./datasets/apple2orange --name apple2orange_cyclegan --model cycle_gan

如果直接这么训练可能会报错,但不会停止。因为代码中做了网页的实时可视化,大家在训练执勤啊可以先运行如下代码,第一次进入的时候会慢一点。之后大家就可以在http://localhost:8097中看到训练中间的一些结果了。

python -m visdom.server

使用自己的数据集训练及测试

如果需要使用自己的数据进行训练,我们只需要在datasets中创建自己任务的文件夹,按照它之前的命名方式进行命名就行了。

目录格式为

--datasets
----apple2orange
------testA
------testB
------trainA
------trainB

其中trainA和trainB文件夹中放入希望互相转换的两个域 的数据即可。如果只是训练的话testA和testB两个文件夹可以不创建。

修改cyclegan生成图像的大小

如果想要使训练的图像大小更大一点(官方默认的话,会将图像裁剪到256*256),如果使用正方形的图片,可以在./options/base_options.py 找到设置 --load_size 以及 --crop_size,进行修改,即可实现对尺寸的修改。但因为我需要的图像是长方形的,又懒得去改配置文件,于是直接定位一下它是在哪里修改图像尺寸的,最后找到了在./data/base_dataset.py文件中第81行def get_transform 方法中。因此,我直接在此处修改了该方法,以达到自己的目的。修改第85行的代码

if 'resize' in opt.preprocess:
        osize = [opt.load_size, opt.load_size]

中osize = [opt.load_size, opt.load_size]中的opt.load_size直接替换成了自己需要的大小。之后把91行到98行进行了注释,让其不再进行裁剪。大家可以根据自己的需求进行修改。

修改save_images

为了方便,修改存储的文件,仅存fake img,文件名和原始文件名相同。修改了./util/visualizer.py文件中的save_images方法。

def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256, use_wandb=False):
    """Save images to the disk.

    Parameters:
        webpage (the HTML class) -- the HTML webpage class that stores these imaegs (see html.py for more details)
        visuals (OrderedDict)    -- an ordered dictionary that stores (name, images (either tensor or numpy) ) pairs
        image_path (str)         -- the string is used to create image paths
        aspect_ratio (float)     -- the aspect ratio of saved images
        width (int)              -- the images will be resized to width x width

    This function will save images stored in 'visuals' to the HTML file specified by 'webpage'.
    """
    image_dir = webpage.get_image_dir()
    short_path = ntpath.basename(image_path[0])
    # 23.4.16修改代码,仅保存fake的图像,用于数据增强。
    # name = os.path.splitext(short_path)[0]

    webpage.add_header(short_path)
    ims, txts, links = [], [], []
    ims_dict = {}
    for label, im_data in visuals.items():
        if label == 'fake':
            im = util.tensor2im(im_data)
            # image_name = '%s_%s.png' % (name, label)
            image_name = short_path
            save_path = os.path.join(image_dir, image_name)
            util.save_image(im, save_path, aspect_ratio=aspect_ratio)
            ims.append(image_name)
            txts.append(label)
            links.append(image_name)
            if use_wandb:
                ims_dict[label] = wandb.Image(im)
    webpage.add_images(ims, txts, links, width=width)
    if use_wandb:
        wandb.log(ims_dict)

你可能感兴趣的:(生成对抗网络,pytorch,深度学习)