Pytorch——生成式对抗网络的实例

1.GAN

        生成器的最终目标是要欺骗判别器,混淆真伪图像;而判别器的目标是发现他何时被欺骗了,同时告知生成器在生成图像的过程中可识别的错误。注意无论是判别器获胜还是生成器获胜,都不是字面意义上的获胜。两个网络都是基于彼此的训练结果来推动参数优化的。

        这项技术已经被证明可以产生一个生成器,他只从一个噪声和一个条件信号生成逼真的图像。这里的条件信号是指属性或者其他图片,例如:年轻的、女性的、戴着眼镜的。这样的生成器产生的图片可以满足人的审美。

2.CycleGAN

        CycleGAN是指循环生成式对抗网络,他可以将一个领域的图片转换为另一个领域的图像,而不需要我们显示的提供匹配对。

        我们来看一个CycleGAN的实例,将马转换为斑马。

        第一个生成器学习从一个从属于不同分布域的图像生成符合目标域的图像,因此判别器无法来分别出是否真的是斑马的图片;产生的假的斑马的图片通过另一个生成器发送到另一个判别器,由另一个判别器来判别真伪。创建这样的一个循环可以极大地使训练过程稳定。

2.1定义生成器

import torch
import torch.nn as nn

class ResNetBlock(nn.Module): # <1>

    def __init__(self, dim):
        super(ResNetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim)

    def build_conv_block(self, dim):
        conv_block = []

        conv_block += [nn.ReflectionPad2d(1)]

        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True),
                       nn.InstanceNorm2d(dim),
                       nn.ReLU(True)]

        conv_block += [nn.ReflectionPad2d(1)]

        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True),
                       nn.InstanceNorm2d(dim)]

        return nn.Sequential(*conv_block)

    def forward(self, x):
        out = x + self.conv_block(x) # <2>
        return out


class ResNetGenerator(nn.Module):

    def __init__(self, input_nc=3, output_nc=3, ngf=64, n_blocks=9): # <3> 

        assert(n_blocks >= 0)
        super(ResNetGenerator, self).__init__()

        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf

        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=True),
                 nn.InstanceNorm2d(ngf),
                 nn.ReLU(True)]

        n_downsampling = 2
        for i in range(n_downsampling):
            mult = 2**i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
                                stride=2, padding=1, bias=True),
                      nn.InstanceNorm2d(ngf * mult * 2),
                      nn.ReLU(True)]

        mult = 2**n_downsampling
        for i in range(n_blocks):
            model += [ResNetBlock(ngf * mult)]

        for i in range(n_downsampling):
            mult = 2**(n_downsampling - i)
            model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
                                         kernel_size=3, stride=2,
                                         padding=1, output_padding=1,
                                         bias=True),
                      nn.InstanceNorm2d(int(ngf * mult / 2)),
                      nn.ReLU(True)]

        model += [nn.ReflectionPad2d(3)]
        model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        model += [nn.Tanh()]

        self.model = nn.Sequential(*model)

    def forward(self, input): # <3>
        return self.model(input)

2.2加载数据集和模型张量参数

netG = ResNetGenerator()
model_path = '../data/p1ch2/horse2zebra_0.4.0.pth'
model_data = torch.load(model_path)
netG.load_state_dict(model_data)

Pytorch——生成式对抗网络的实例_第1张图片

 接下来把网络设置为eval

netG.eval()

Pytorch——生成式对抗网络的实例_第2张图片

2.3导入图片并预处理

from PIL import Image
from torchvision import transforms
preprocess = transforms.Compose([transforms.Resize(256),
                                 transforms.ToTensor()])
img = Image.open("../data/p1ch2/horse.jpg")
img

Pytorch——生成式对抗网络的实例_第3张图片

2.4部署网络并执行 

img_t = preprocess(img)
batch_t = torch.unsqueeze(img_t, 0)
batch_out = netG(batch_t)
out_t = (batch_out.data.squeeze() + 1.0) / 2.0
out_img = transforms.ToPILImage()(out_t)
# out_img.save('../data/p1ch2/zebra.jpg')
out_img

Pytorch——生成式对抗网络的实例_第4张图片

 

你可能感兴趣的:(深度学习,人工智能,计算机视觉,pytorch)