pix2pixHD local训练

pix2pixHD是18年的一个精度比较好的生成网络,生成器方面主要是有两个网络组成(g1:global network,g2:local network)

pix2pixHD local训练_第1张图片

 

官网开源的代码默认训练是训练global network,也就是G1,就目前来说,G1本身就能达到一个精度不错的效果,不过既然pix2pixHD的精髓是g1+g2

那么,就需要联合训练一下,按照论文里的说法,需要先训练一个低分辨率的g1,然后将g1加入到g2中,然后只微调g2,最好是联合训练,将g1+g2都一起训练,能达到一个精度更好的效果,由于开源出来的代码默认是global network,因此要想训练local network之前需要训练一个g1,值得注意的是,g1的分辨率比起g2需要少一倍,比如g1如果是256,那么g2的分辨率就是512

想要训练local的话,需要加载一下之前训练的g1的模型权重:

model_global=GlobalGenerator(input_nc, output_nc, ngf_global, n_downsample_global, n_blocks_global, norm_layer)
model_global.load_state_dict(torch.load('./checkpoints/beauty/latest_net_G.pth'))
print('gobal net load')
model_global=model_global.model 

从论文里面我们可以知道,先训练的g1:

python train.py --name tg --gpu_ids 0,2 --no_instance --label_nc 0 --loadSize 256   --batchSize 2 --lr 0.0002 --gan_mode ls --lambda_vgg 10 --netG global --ngf 48

然后结合g1来训练g2:

python train.py --name tl --gpu_ids 2 --no_instance --label_nc 0 --loadSize 512 --batchSize 1 --lr 0.0002 --gan_mode ls --lambda_vgg 10 --netG local --ngf 24 --niter_fix_global 1 --niter 1 --niter_decay 1

值得注意的是--niter_fix_global这个参数表示的是联合训练以前需要微调g2网络的迭代轮数,--ngf参数正好差2倍。

你可能感兴趣的:(深度学习,深度学习,人工智能,pix2pixHD)