Projecting images into the latent space of StyleGAN2 with pbaylies/stylegan-encoder

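Contents

    • Preface
    • Results comparison
      • Results of another third-party implementation
      • Results of this article's implementation
    • Code
      • train_encoder.py
      • encode_images_s1.py
    • Running
    • Postscript
      • Parameter tuning
      • Other attempts
        • run_projector.py
        • projector.py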

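Preface

You may remember that we previously used stylegan-encoder to find the latent code of an image and use it to control image generation.
With StyleGAN2, the official code ships its own run_projector.py for projecting an image to its latent code.
In practice, though, it is slow (it needs many iterations) and its output does not resemble the target closely; it is nowhere near as usable as the first-generation pbaylies/stylegan-encoder.
Let's compare the projection results of the official projector and pbaylies/stylegan-encoder (note that the last image is the Target).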

Results comparison

StyleGAN2 run_projector.py, iterations 1 to 100, 10 to 60 s (model loading time excluded)
[Image 1]
StyleGAN2 run_projector.py, iterations 1 to 1000 (the default), 10 to 430 s
[Image 2]

StyleGAN encode_images.py, iterations 1 to 100, 10 to 60 s
[Image 3]

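Results of another third-party implementation

Some will point out that robertluxemburg/stylegan2encoder describes itself as "This is an experimental port of pbaylies/stylegan-encoder for NVlabs/stylegan2."
So I tried it, and this is what it produced.

StyleGAN2 encode_images.py, iterations 1 to 100, 5 to 35 s
[Image 4]
More iterations might improve the result, but that defeats our purpose.
Comparing the source code shows that this repository adopts some of the optimizations from pbaylies/stylegan-encoder but not its use of a fine-tuned ResNet to predict the initial latent code. That initialization is exactly what lets pbaylies/stylegan-encoder produce good results without a large number of iterations.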
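
To make the role of the initial value concrete, here is a minimal toy sketch, not the encoder itself. It assumes a linear "generator" (a random matrix A) in place of StyleGAN2 and simulates the ResNet prediction with a guess whose error is 10% of the cold-start error; all names in it are hypothetical. It counts how many gradient-descent steps each starting point needs to reach the same loss.

```python
import numpy as np


def steps_to_converge(A, x, w_init, tol=1e-3, max_steps=5000):
    """Run gradient descent on 0.5 * ||A w - x||^2 and return the number of
    steps needed to push the loss below tol (max_steps if never reached)."""
    lr = 1.0 / np.linalg.norm(A, 2) ** 2  # 1 / largest eigenvalue of A^T A
    w = w_init.copy()
    for step in range(max_steps):
        residual = A @ w - x
        if 0.5 * (residual @ residual) < tol:
            return step
        w -= lr * (A.T @ residual)
    return max_steps


rng = np.random.default_rng(0)
A = rng.normal(size=(64, 16))   # toy stand-in for the generator (assumption)
w_true = rng.normal(size=16)    # latent code that produced the target
x = A @ w_true                  # target "image"

cold_start = np.zeros(16)       # no prior knowledge about the target
warm_start = 0.9 * w_true       # simulated encoder prediction (assumption)

cold_steps = steps_to_converge(A, x, cold_start)
warm_steps = steps_to_converge(A, x, warm_start)
print(f"cold start: {cold_steps} steps, warm start: {warm_steps} steps")
```

The real loss is far from quadratic, so the gap in practice depends on how good the predicted latent code is; the sketch only illustrates why a closer start can need fewer iterations.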

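Results of this article's implementation

So we set out to port the complete pbaylies/stylegan-encoder method to StyleGAN2 ourselves. The final result is shown here: the face is already close to the target at initialization, and the result after 100 iterations is on par with what the official projector reaches after 1000 iterations.

StyleGAN2 encode_images_s1.py, iterations 1 to 100, 10 to 60 s
[Image 5]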

Code

train_encoder.py

Fine-tunes a ResNet to train the inverse network (image to latent code); adapted from SimJeg's code.

import os
import numpy as np
import cv2

from keras.applications.imagenet_utils import preprocess_input
from keras.layers import Dense, Reshape
from keras.models import Sequential, Model, load_model
from keras.applications.resnet50 import ResNet50
from keras.optimizers import Adam

import pretrained_networks
import dnnlib.tflib as tflib


def get_batch(batch_size, Gs, image_size=224, Gs_minibatch_size=12, w_mix=None, latent_size=18):
    """
    Generate a batch of size n for the model to train
    returns a tuple (W, X) with W.shape = [batch_size, latent_size, 512] and X.shape = [batch_size, image_size, image_size, 3]
    If w_mix is not None, W = w_mix * W0 + (1 - w_mix) * W1 with
        - W0 generated from Z0 such that W0[:,i] = constant
        - W1 generated from Z1 such that W1[:,i] != constant

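    Parameters
    ----------
    batch_size : int
        batch size
    Gs
        StyleGAN2 generator
    image_size : int
        side length the generated images are resized to for the ImageNet model
    Gs_minibatch_size : int
        batch size for the generator
    w_mix : float
        weight of W0 in the mix; None disables style mixing
    latent_size : int
        number of style layers in W (18 for a 1024x1024 generator)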

    Returns
    -------
    tuple
        dlatent W, images X
    """

    # Generate W0 from Z0
    Z0 = np.random.randn(batch_size, Gs.input_shape[1])
    W0 = Gs.components.mapping.run(Z0, None, minibatch_size=Gs_minibatch_size)

    if w_mix is None:
        W = W0
    else:
        # Generate W1 from Z1
        Z1 = np.random.randn(latent_size * batch_size, Gs.input_shape[1])
        W1 = Gs.components.mapping.run(Z1, None, minibatch_size=Gs_minibatch_size)
        W1 = np.array([W1[batch_size * i:batch_size * (i + 1), i] for i in range(latent_size)]).transpose((1, 0, 2))

        # Mix styles between W0 and W1
        W = w_mix * W0 + (1 - w_mix) * W1

    # Generate X
    X = Gs.components.synthesis.run(W, randomize_noise=True, minibatch_size=Gs_minibatch_size, print_progress=True,
                                    output_transform=dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True))

    # Preprocess images X for the Imagenet model
    X = np.array([cv2.resize(x, (image_size, image_size)) for x in X])
    X = preprocess_input(X.astype('float'))

    return W, X


def finetune(save_path, image_size=224, base_model=ResNet50, batch_size=2048, test_size=1024, n_epochs=6,
             max_patience=5, models_dir='models/stylegan2-ffhq-config-f.pkl'):
    """
    Finetunes a ResNet50 to predict W[:, 0]

    Parameters
    ----------
    save_path : str
        path where to save the Resnet
    image_size : int
    base_model : keras model
    batch_size :  int
    test_size : int
    n_epochs : int
    max_patience : int

    Returns
    -------
    None

    """

    assert image_size >= 224

    # Load StyleGan generator
    _, _, Gs = pretrained_networks.load_networks(models_dir)

    # Build model
    if os.path.exists(save_path):
        print('Loading pretrained network')
        model = load_model(save_path, compile=False)
    else:
        base = base_model(include_top=False, pooling='avg', input_shape=(image_size, image_size, 3))
        model = Sequential()
        model.add(base)
        model.add(Dense(512))

    model.compile(loss='mse', metrics=[], optimizer=Adam(3e-4))
    model.summary()

    # Create a test set
    print('Creating test set')
    W_test, X_test = get_batch(test_size, Gs)

    # Iterate on batches of size batch_size
    print('Training model')
    patience = 0
    best_loss = np.inf

    while (patience <= max_patience):
        W_train, X_train = get_batch(batch_size, Gs)
        model.fit(X_train, W_train[:, 0], epochs=n_epochs, verbose=True)
        loss = model.evaluate(X_test, W_test[:, 0])
        if loss < best_loss:
            print(f'New best test loss : {loss:.5f}')
            model.save(save_path)
            patience = 0
            best_loss = loss
        else:
            print(f'-------- test loss : {loss:.5f}')
            patience += 1


def finetune_18(save_path, base_model=None, image_size=224, batch_size=2048, test_size=1024, n_epochs=6,
                max_patience=8, w_mix=0.7, latent_size=18, models_dir='models/stylegan2-ffhq-config-f.pkl'):
    """
    Finetunes a ResNet50 to predict W[:, :]

    Parameters
    ----------
    save_path : str
        path where to save the Resnet
    image_size : int
    base_model : str
        path to the first finetuned ResNet50
    batch_size :  int
    test_size : int
    n_epochs : int
    max_patience : int
    w_mix : float

    Returns
    -------
    None

    """

    assert image_size >= 224
    if not os.path.exists(save_path):
        assert base_model is not None

    # Load StyleGan generator
    _, _, Gs = pretrained_networks.load_networks(models_dir)

    # Build model
    if os.path.exists(save_path):
        print('Loading pretrained network')
        model = load_model(save_path, compile=False)
    else:
        base_model = load_model(base_model)
        hidden = Dense(latent_size * 512)(base_model.layers[-1].input)
        outputs = Reshape((latent_size, 512))(hidden)
        model = Model(base_model.input, outputs)
        # Set initialize layer
        W, b = base_model.layers[-1].get_weights()
        model.layers[-2].set_weights([np.hstack([W] * latent_size), np.hstack([b] * latent_size)])

    model.compile(loss='mse', metrics=[], optimizer=Adam(1e-4))
    model.summary()

    # Create a test set
    print('Creating test set')
    W_test, X_test = get_batch(test_size, Gs, w_mix=w_mix, latent_size=latent_size)

    # Iterate on batches of size batch_size
    print('Training model')
    patience = 0
    best_loss = np.inf

    while (patience <= max_patience):
        W_train, X_train = get_batch(batch_size, Gs, w_mix=w_mix, latent_size=latent_size)
        model.fit(X_train, W_train, epochs=n_epochs, verbose=True)
        loss = model.evaluate(X_test, W_test)
        if loss < best_loss:
            print(f'New best test loss : {loss:.5f}')
            model.save(save_path)
            patience = 0
            best_loss = loss
        else:
            print(f'-------- test loss : {loss:.5f}')
            patience += 1


if __name__ == '__main__':
    finetune('data/resnet.h5')
    finetune_18('data/resnet_18.h5', 'data/resnet.h5', w_mix=0.8)
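
The indexing in the w_mix branch of get_batch is easy to misread, so here is a small NumPy-only sketch of the same slicing. The mapping network is replaced by a hypothetical fake_mapping that broadcasts one vector to every layer, and the sizes are toy values; only the slicing and the blend mirror the listing above.

```python
import numpy as np


def mix_styles(W0, W1_all, w_mix):
    """Blend per-layer-constant latents W0 with per-layer-independent latents.

    W0     : [batch, layers, dim]; every layer of a sample holds the same vector
    W1_all : [layers * batch, layers, dim]; mapping output for layers * batch
             independent z vectors
    """
    batch, layers, _ = W0.shape
    # For layer i, take rows [batch*i, batch*(i+1)) and keep only their i-th
    # layer, so every (sample, layer) pair comes from a different z vector.
    W1 = np.array([W1_all[batch * i:batch * (i + 1), i] for i in range(layers)])
    W1 = W1.transpose((1, 0, 2))  # [layers, batch, dim] -> [batch, layers, dim]
    return w_mix * W0 + (1 - w_mix) * W1


def fake_mapping(z, layers):
    """Hypothetical stand-in for Gs.components.mapping.run: repeat z per layer."""
    return np.repeat(z[:, None, :], layers, axis=1)


rng = np.random.default_rng(0)
batch, layers, dim = 4, 18, 8
W0 = fake_mapping(rng.normal(size=(batch, dim)), layers)
W1_all = fake_mapping(rng.normal(size=(layers * batch, dim)), layers)

W = mix_styles(W0, W1_all, w_mix=0.8)
print(W.shape)                        # (4, 18, 8)
print(np.allclose(W[0, 0], W[0, 1]))  # False: layers now differ within a sample
```

With w_mix=1.0 the function returns W0 unchanged, which is the case finetune trains on; finetune_18 uses a value below 1 so that the network also sees latents whose layers differ.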

encode_images_s1.py

import os
import argparse
import pickle

from tqdm import tqdm
import PIL.Image
import numpy as np
import dnnlib
import dnnlib.tflib as tflib
from encoder_s1.generator_model import Generator
from encoder_s1.perceptual_model import PerceptualModel, load_images
from keras.models import load_model

import glob
import random


def split_to_batches(l, n):
    for i in range(0, len(l), n):
        yield l[i:i + n]


def main():
    parser = argparse.ArgumentParser(
        description='Find latent representation of reference images using perceptual losses',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('src_dir', help='Directory with images for encoding')
    parser.add_argument('generated_images_dir', help='Directory for storing generated images')
    parser.add_argument('dlatent_dir', help='Directory for storing dlatent representations')
    parser.add_argument('--data_dir', default='data', help='Directory for storing optional models')
    parser.add_argument('--mask_dir', default='masks', help='Directory for storing optional masks')
    parser.add_argument('--load_last', default='', help='Start with embeddings from directory')
    parser.add_argument('--dlatent_avg', default='',
                        help='Use dlatent from file specified here for truncation instead of dlatent_avg from Gs')
    parser.add_argument('--model_url', default='models/stylegan2-ffhq-config-f.pkl',
                        help='Fetch a StyleGAN model to train on from this URL')
    parser.add_argument('--model_res', default=1024, help='The dimension of images in the StyleGAN model', type=int)
    parser.add_argument('--batch_size', default=1, help='Batch size for generator and perceptual model', type=int)

    # Perceptual model params
    parser.add_argument('--image_size', default=256, help='Size of images for perceptual model', type=int)
    parser.add_argument('--resnet_image_size', default=224, help='Size of images for the Resnet model', type=int)
    parser.add_argument('--lr', default=0.02, help='Learning rate for perceptual model', type=float)
    parser.add_argument('--decay_rate', default=0.9, help='Decay rate for learning rate', type=float)
    parser.add_argument('--iterations', default=100, help='Number of optimization steps for each batch', type=int)
    parser.add_argument('--decay_steps', default=10,
                        help='Decay steps for learning rate decay (as a percent of iterations)', type=float)
    parser.add_argument('--load_effnet', default='data/finetuned_effnet.h5',
                        help='Model to load for EfficientNet approximation of dlatents')
    parser.add_argument('--load_resnet', default='data/resnet_18.h5',
                        help='Model to load for ResNet approximation of dlatents')

    # Loss function options
    parser.add_argument('--use_vgg_loss', default=0.4, help='Use VGG perceptual loss; 0 to disable, > 0 to scale.',
                        type=float)
    parser.add_argument('--use_vgg_layer', default=9, help='Pick which VGG layer to use.', type=int)
    parser.add_argument('--use_pixel_loss', default=1.5,
                        help='Use logcosh image pixel loss; 0 to disable, > 0 to scale.', type=float)
    parser.add_argument('--use_mssim_loss', default=100, help='Use MS-SIM perceptual loss; 0 to disable, > 0 to scale.',
                        type=float)
    parser.add_argument('--use_lpips_loss', default=100, help='Use LPIPS perceptual loss; 0 to disable, > 0 to scale.',
                        type=float)
    parser.add_argument('--use_l1_penalty', default=1, help='Use L1 penalty on latents; 0 to disable, > 0 to scale.',
                        type=float)

    # argparse's type=bool turns any non-empty string (even 'False') into True,
    # so boolean options are parsed explicitly
    def str2bool(value):
        return str(value).lower() in ('true', '1', 'yes')

    # Generator params
    parser.add_argument('--randomize_noise', default=False, help='Add noise to dlatents during optimization',
                        type=str2bool)
    parser.add_argument('--tile_dlatents', default=False, help='Tile dlatents to use a single vector at each scale',
                        type=str2bool)
    parser.add_argument('--clipping_threshold', default=2.0,
                        help='Stochastic clipping of gradient values outside of this threshold', type=float)

    # Masking params
    parser.add_argument('--load_mask', default=False, help='Load segmentation masks', type=str2bool)
    parser.add_argument('--face_mask', default=False, help='Generate a mask for predicting only the face area',
                        type=str2bool)
    parser.add_argument('--use_grabcut', default=True,
                        help='Use grabcut algorithm on the face mask to better segment the foreground', type=str2bool)
    parser.add_argument('--scale_mask', default=1.5, help='Look over a wider section of foreground for grabcut',
                        type=float)

    # Video params
    parser.add_argument('--video_dir', default='videos', help='Directory for storing training videos')
    parser.add_argument('--output_video', default=False, help='Generate videos of the optimization process',
                        type=str2bool)
    parser.add_argument('--video_codec', default='MJPG', help='FOURCC-supported video codec name')
    parser.add_argument('--video_frame_rate', default=24, help='Video frames per second', type=int)
    parser.add_argument('--video_size', default=512, help='Video size in pixels', type=int)
    parser.add_argument('--video_skip', default=1, help='Only write every n frames (1 = write every frame)', type=int)

    # parse_known_args() does not raise on options it does not recognize; the extras are kept in other_args for later use
    args, other_args = parser.parse_known_args()

    # Number of steps between learning-rate decays
    args.decay_steps *= 0.01 * args.iterations  # Calculate steps as a percent of total iterations

    if args.output_video:
        import cv2
        synthesis_kwargs = dict(output_transform=dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=False),
                                minibatch_size=args.batch_size)

    # Collect every file in src_dir into ref_images (the list of source images; a single image is fine)
    ref_images = [os.path.join(args.src_dir, x) for x in os.listdir(args.src_dir)]
    ref_images = list(filter(os.path.isfile, ref_images))

    if len(ref_images) == 0:
        raise Exception('%s is empty' % args.src_dir)

    # Create the working directories
    os.makedirs(args.data_dir, exist_ok=True)
    os.makedirs(args.mask_dir, exist_ok=True)
    os.makedirs(args.generated_images_dir, exist_ok=True)
    os.makedirs(args.dlatent_dir, exist_ok=True)
    os.makedirs(args.video_dir, exist_ok=True)

    # Initialize generator and perceptual model
    tflib.init_tf()
    # Load the StyleGAN model
    model_file = glob.glob(args.model_url)
    if len(model_file) == 1:
        model_file = open(model_file[0], "rb")
    else:
        raise Exception('Failed to find the model')
    generator_network, discriminator_network, Gs_network = pickle.load(model_file)

    # Wrap Gs in the Generator class, which takes part in building the perceptual model and produces generated_image
    # perceptual_model turns generated_image into generated_img_features, which enter the loss
    generator = Generator(Gs_network, args.batch_size, clipping_threshold=args.clipping_threshold,
                          tiled_dlatent=args.tile_dlatents, model_res=args.model_res,
                          randomize_noise=args.randomize_noise)
    if (args.dlatent_avg != ''):
        generator.set_dlatent_avg(np.load(args.dlatent_avg))

    perc_model = None
    if (args.use_lpips_loss > 0.00000001):  # '--use_lpips_loss', default = 100
        # Load the VGG16 perceptual model
        model_file = glob.glob('./models/vgg16_zhang_perceptual.pkl')
        if len(model_file) == 1:
            model_file = open(model_file[0], "rb")
        else:
            raise Exception('Failed to find the model')
        perc_model = pickle.load(model_file)

    # Build the VGG16 perceptual model
    perceptual_model = PerceptualModel(args, perc_model=perc_model, batch_size=args.batch_size)
    perceptual_model.build_perceptual_model(generator)

    ff_model = None
    # Optimize (only) dlatents by minimizing perceptual loss between reference and generated images in feature space
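    # tqdm is a fast, extensible progress bar for long-running Python loops
    # Split ref_images into batches of args.batch_size and call perceptual_model.optimize() on each batch to find the best dlatents for every source image
    # For each source image, optimization starts from an initial dlatents and uses the Adam optimizer to move step by step toward the dlatents with the lowest loss; values outside the clipping threshold are resampled from a normal distribution (stochastic clipping)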
    for images_batch in tqdm(split_to_batches(ref_images, args.batch_size), total=len(ref_images) // args.batch_size):
        # Base names (without extension) of the files in this batch
        names = [os.path.splitext(os.path.basename(x))[0] for x in images_batch]
        if args.output_video:
            video_out = {}
            for name in names:
                video_out[name] = cv2.VideoWriter(os.path.join(args.video_dir, f'{name}.avi'),
                                                  cv2.VideoWriter_fourcc(*args.video_codec), args.video_frame_rate,
                                                  (args.video_size, args.video_size))

        # Set the source images and the features VGG16 extracts from them (the reference for the loss)
        perceptual_model.set_reference_images(images_batch)
        dlatents = None
        if (args.load_last != ''):  # load previous dlatents for initialization
            for name in names:
                dl = np.expand_dims(np.load(os.path.join(args.load_last, f'{name}.npy')), axis=0)
                if (dlatents is None):
                    dlatents = dl
                else:
                    dlatents = np.vstack((dlatents, dl))
        else:
            if (ff_model is None):
                if os.path.exists(args.load_resnet):
                    print("Loading ResNet Model:")
                    ff_model = load_model(args.load_resnet)
                    from keras.applications.resnet50 import preprocess_input
            if (ff_model is None):
                if os.path.exists(args.load_effnet):
                    import efficientnet
                    print("Loading EfficientNet Model:")
                    ff_model = load_model(args.load_effnet)
                    from efficientnet import preprocess_input
            if (ff_model is not None):  # predict initial dlatents with ResNet model
                dlatents = ff_model.predict(
                    preprocess_input(load_images(images_batch, image_size=args.resnet_image_size)))
        # Set the initial dlatents for the optimization; they are predicted from the source images by ResNet50 or EfficientNet
        if dlatents is not None:
            generator.set_dlatents(dlatents)
        # Wrap the optimization loop in a tqdm progress bar for the current batch
        op = perceptual_model.optimize(generator.dlatent_variable, iterations=args.iterations)
        pbar = tqdm(op, leave=False, total=args.iterations)
        vid_count = 0
        best_loss = None
        best_dlatent = None
        # Optimize with the VGG16 perceptual_model for args.iterations steps, applying stochastic clipping after each step
        for loss_dict in pbar:
            pbar.set_description(" ".join(names) + ": " + "; ".join(["{} {:.4f}".format(k, v)
                                                                     for k, v in loss_dict.items()]))
            if best_loss is None or loss_dict["loss"] < best_loss:
                best_loss = loss_dict["loss"]
                best_dlatent = generator.get_dlatents()
            if args.output_video and (vid_count % args.video_skip == 0):
                batch_frames = generator.generate_images()
                for i, name in enumerate(names):
                    video_frame = PIL.Image.fromarray(batch_frames[i], 'RGB').resize((args.video_size, args.video_size),
                                                                                     PIL.Image.LANCZOS)
                    video_out[name].write(cv2.cvtColor(np.array(video_frame).astype('uint8'), cv2.COLOR_RGB2BGR))
            # Update dlatent_variable with stochastic clipping
            generator.stochastic_clip_dlatents()
        print(" ".join(names), " Loss {:.4f}".format(best_loss))

        if args.output_video:
            for name in names:
                video_out[name].release()

        # Generate images from found dlatents and save them
        generator.set_dlatents(best_dlatent)
        generated_images = generator.generate_images()
        generated_dlatents = generator.get_dlatents()
        for img_array, dlatent, img_name in zip(generated_images, generated_dlatents, names):
            img = PIL.Image.fromarray(img_array, 'RGB')
            img.save(os.path.join(args.generated_images_dir, f'{img_name}.png'), 'PNG')
            np.save(os.path.join(args.dlatent_dir, f'{img_name}.npy'), dlatent)

        generator.reset_dlatents()


if __name__ == "__main__":
    main()
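
The --decay_steps option is given as a percentage of --iterations, which is easy to overlook. The sketch below reproduces that conversion and, as an assumption about what the perceptual model does with it, applies a staircase exponential decay (the learning rate is multiplied by decay_rate once every decay_steps steps). The function names are hypothetical.

```python
import math


def decay_steps_from_percent(decay_percent, iterations):
    """Mirror `args.decay_steps *= 0.01 * args.iterations` from main()."""
    return decay_percent * (0.01 * iterations)


def staircase_lr(step, base_lr, decay_rate, decay_steps):
    """Assumed schedule: multiply by decay_rate once every decay_steps steps."""
    return base_lr * decay_rate ** math.floor(step / decay_steps)


# Script defaults: --decay_steps 10 --iterations 100 --lr 0.02 --decay_rate 0.9
decay_steps = decay_steps_from_percent(10, 100)
for step in (0, 15, 55, 99):
    print(step, round(staircase_lr(step, 0.02, 0.9, decay_steps), 5))
```

Under this assumption the defaults decay the rate every 10 steps, so raising --iterations without touching --decay_steps keeps the same number of decays over the run.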


Then copy the encoder folder from pbaylies/stylegan-encoder into the project directory and rename it to encoder_s1.
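
The main loop calls generator.stochastic_clip_dlatents() after every optimization step; that method comes from the copied encoder folder. As far as I can tell it resamples every latent value that has left the range [-clipping_threshold, clipping_threshold] from a standard normal distribution. The following is a NumPy sketch of that idea under this assumption, not the TensorFlow code from the repository.

```python
import numpy as np


def stochastic_clip(dlatents, threshold=2.0, rng=None):
    """Replace every entry whose magnitude exceeds threshold with a fresh
    standard-normal sample; entries inside the range are left untouched.
    A redrawn value can itself land outside the range and is then
    redrawn again on the next call."""
    rng = np.random.default_rng() if rng is None else rng
    out_of_range = np.abs(dlatents) > threshold
    resampled = rng.standard_normal(dlatents.shape)
    return np.where(out_of_range, resampled, dlatents)


rng = np.random.default_rng(0)
dlatents = np.array([[0.5, -3.0, 1.9, 4.2]])
clipped = stochastic_clip(dlatents, threshold=2.0, rng=rng)
print(clipped)  # entries 0 and 2 are unchanged, entries 1 and 3 are redrawn
```

Compared with hard clipping to the threshold, resampling avoids piling values up exactly at the boundary.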

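Running

Run python train_encoder.py to fine-tune the ResNet and train the inverse network (preferably on a server, otherwise expect OOM errors). The trained networks are saved to ./data/ by default.
Put the images to be processed into ./images/.
Run python encode_images_s1.py images/ generated_images/ latent_representations/ --load_resnet data/resnet_18.h5 --batch_size 1 --iterations 100 to find the latent codes. The latent codes are written to ./latent_representations/ and the regenerated images to ./generated_images/.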
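
Each latent code is saved with np.save as <image name>.npy, so the results can be read back with plain NumPy. The helper below is a hypothetical convenience function, and the (18, 512) shape in the demo is an assumption that holds for a 1024x1024 model with untiled dlatents.

```python
import os
import tempfile

import numpy as np


def load_dlatents(dlatent_dir):
    """Load every .npy file in dlatent_dir into a {name: array} dict."""
    dlatents = {}
    for fname in sorted(os.listdir(dlatent_dir)):
        if fname.endswith('.npy'):
            name = os.path.splitext(fname)[0]
            dlatents[name] = np.load(os.path.join(dlatent_dir, fname))
    return dlatents


# Self-contained demo: write a placeholder latent the way the script does,
# then read it back. With real results, pass 'latent_representations/' instead.
with tempfile.TemporaryDirectory() as tmp_dir:
    np.save(os.path.join(tmp_dir, 'example.npy'), np.zeros((18, 512), dtype=np.float32))
    loaded = load_dlatents(tmp_dir)
    for name, dlatent in loaded.items():
        print(name, dlatent.shape)
```

The same directory can be passed back to the script through --load_last to continue optimizing from the saved latent codes instead of the ResNet prediction.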

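Postscript

It may not be obvious from the examples above, but for images with complex backgrounds this method does much better than the official one: the result favors similarity of the face over similarity of the background.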

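Parameter tuning

Fine-tuning the ResNet felt too slow (nearly 12 hours on a Tesla P100), so I pulled out a model from halfway through training (loss=0.04557) and tried it. The comparison showed that it performed worse than a model trained to completion with the default parameters.

The default parameters in the ResNet fine-tuning code run many epochs on each batch of samples, which overfits within the batch, slows training down, and gives a poorer final result. Setting the parameters the way pbaylies/stylegan-encoder does for its ResNet fine-tuning gave better results.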

finetune('data/resnet.h5', n_epochs=2, max_patience=1)
finetune_18('data/resnet_18.h5', 'data/resnet.h5', w_mix=0.8, n_epochs=2, max_patience=1)
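
To see what max_patience changes, the stopping rule of the training loop can be replayed on a list of test losses. This is a pure-Python sketch with made-up loss values; rounds_until_stop is a hypothetical helper, and each "round" corresponds to one freshly generated batch trained for n_epochs epochs.

```python
def rounds_until_stop(test_losses, max_patience):
    """Replay the `while patience <= max_patience` loop on a fixed list of
    test losses; return (rounds executed, best loss seen)."""
    patience = 0
    best_loss = float('inf')
    rounds = 0
    for loss in test_losses:
        if patience > max_patience:
            break
        rounds += 1
        if loss < best_loss:
            best_loss = loss   # the model would be saved here
            patience = 0
        else:
            patience += 1
    return rounds, best_loss


losses = [0.20, 0.10, 0.12, 0.11, 0.09, 0.08]     # made-up values
print(rounds_until_stop(losses, max_patience=1))  # (4, 0.1)
print(rounds_until_stop(losses, max_patience=5))  # (6, 0.08)
```

Because the condition is patience <= max_patience, training stops only after max_patience + 1 consecutive rounds without improvement. A small max_patience ends training sooner but can miss a later improvement, as the first call shows.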

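Other attempts

I tried adding an initial latent code to the run_projector.py that ships with StyleGAN2, and found that the official algorithm cannot make direct use of an initial latent code.
[Image 6]
This one may show it more clearly: the run_projector.py algorithm does not consider the image produced from the ResNet-initialized latent code to be any closer to the Target than the line drawing (p200).
[Image]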

run_projector.py

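# Additions to the imports at the top of run_projector.py, needed by the modified function
from PIL import Image
from keras.models import load_model

def project_real_images(network_pkl, dataset_name, data_dir, num_images, num_snapshots):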
    print('Loading networks from "%s"...' % network_pkl)
    _G, _D, Gs = pretrained_networks.load_networks(network_pkl)
    proj = projector.Projector()
    proj.set_network(Gs)

    print('Loading images from "%s"...' % dataset_name)
    dataset_obj = dataset.load_dataset(data_dir=data_dir, tfrecord_dir=dataset_name, max_label_size=0, repeat=False, shuffle_mb=0)
    assert dataset_obj.shape == Gs.output_shape[1:]

    # ResNet encoder trained by train_encoder.py; it predicts the initial dlatents
    ff_model = load_model('data/resnet_18.h5')
    from keras.applications.resnet50 import preprocess_input

    for image_idx in range(num_images):
        print('Projecting image %d/%d ...' % (image_idx, num_images))
        images, _labels = dataset_obj.get_minibatch_np(1)

        # The dataset yields uint8 CHW images; convert to HWC and resize to the 224x224 ResNet input
        hwc_image = Image.fromarray(images[0].transpose(1, 2, 0)).resize((224, 224), Image.LANCZOS)
        new_images = np.expand_dims(np.asarray(hwc_image, dtype=np.float32), 0)
        dlatents = ff_model.predict(preprocess_input(new_images))
        # Overwrite the projector's dlatents variable with the predicted initial value
        tflib.set_vars({proj._dlatents_var: dlatents})

        images = misc.adjust_dynamic_range(images, [0, 255], [-1, 1])
        project_image(proj, targets=images, png_prefix=dnnlib.make_run_dir_path('image%04d-' % image_idx), num_snapshots=num_snapshots)

projector.py

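        # Original line, which resets the dlatents to dlatent_avg; commented out so the ResNet initialization is kept:
        # tflib.set_vars({self._target_images_var: target_images, self._dlatents_var: np.tile(self._dlatent_avg, [self._minibatch_size, 1, 1])})
        tflib.set_vars({self._target_images_var: target_images})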
