深度卷积生成对抗网络(DCGAN)|完整代码实现

生成对抗网络(GAN)由Ian Goodfellow在2014年提出。GAN通过训练两个神经网络解决了非监督问题。这两个网络一个称为生成网络,一个称为判别网络。

事实上,该网络的训练过程很有趣。我们可以借助一个例子来理解。最初,伪造者(生成网络)向警察(判别网络)展示假币,警察识别出币是假的,伪造者根据接收到的反馈制造了新的假币,如此重复多次,直到伪造者可以造出警察无法识别的假币。

深度卷积生成对抗网络(DCGAN)|完整代码实现_第1张图片

在GAN中,也就是最后得到了可以生成和真实图片非常类似的生成网络,以及可以高度识别伪造品的判别网络

训练的过程就是两个网络互相博弈的过程,最后达到纳什均衡

DCGAN则是GAN的一个变体,它在生成网络和判别网络中使用了卷积层和转置卷积层。

代码如下:

import torchimport torchvisionfrom torch import nnfrom torch import optimfrom torchvision import transformsfrom torchvision.datasets import CIFAR10import matplotlib.pyplot as plt
lr = 0.0002nz = 100 # noise dimensionimage_size = 64nc = 3 # chanel of img ngf = 64 # generate channelndf = 64 # discriminative channelbeta1 = 0.5BatchSize = 32max_epoch = 2 # 
transform=transforms.Compose([                transforms.Resize(64) ,                transforms.ToTensor(),                transforms.Normalize([0.5]*3,[0.5]*3)                ])
dataset=CIFAR10(root='cifar10/',transform=transform,download=True)
dataloader=torch.utils.data.DataLoader(dataset,BatchSize,shuffle = True)

def weights_init(m):    classname=m.__class__.__name__    if classname.find('Conv')!=-1:        m.weight.data.normal_(0.0,0.02)    elif classname.find('BatchNorm')!=-1:        m.weight.data.normal_(1.0,0.02)        m.bias.data.fill_(0)        # define modelclass Generator(nn.Module):    def __init__(self):        super(Generator,self).__init__()        self.main = nn.Sequential(            nn.ConvTranspose2d(nz,ngf*8,4,1,0,bias=False),            nn.BatchNorm2d(ngf*8),            nn.ReLU(True),
            nn.ConvTranspose2d(ngf*8,ngf*4,4,2,1,bias=False),            nn.BatchNorm2d(ngf*4),            nn.ReLU(True),
            nn.ConvTranspose2d(ngf*4,ngf*2,4,2,1,bias=False),            nn.BatchNorm2d(ngf*2),            nn.ReLU(True),
            nn.ConvTranspose2d(ngf*2,ngf,4,2,1,bias=False),            nn.BatchNorm2d(ngf),            nn.ReLU(True),
            nn.ConvTranspose2d(ngf,nc,4,2,1,bias=False),            nn.Tanh()        )    def forward(self,input):        output=self.main(input)        return outputnetG=Generator()netG.apply(weights_init)

        class Discriminator(nn.Module):    def __init__(self):        super(Discriminator,self).__init__()        self.main = nn.Sequential(            nn.Conv2d(nc,ndf,4,2,1,bias=False),            nn.LeakyReLU(0.2,inplace=True),
            nn.Conv2d(ndf,ndf*2,4,2,1,bias=False),            nn.BatchNorm2d(ndf*2),            nn.LeakyReLU(0.2,inplace=True),
            nn.Conv2d(ndf*2,ndf*4,4,2,1,bias=False),            nn.BatchNorm2d(ndf*4),            nn.LeakyReLU(0.2,inplace=True),
            nn.Conv2d(ndf*4,ndf*8,4,2,1,bias=False),            nn.BatchNorm2d(ndf*8),            nn.LeakyReLU(0.2,inplace=True),
            nn.Conv2d(ndf*8,1,4,1,0,bias=False),            nn.Sigmoid()        )    def forward(self,input):        output=self.main(input)        return output.view(-1,1).squeeze(1)    netD=Discriminator()netD.apply(weights_init)

# optimizeroptimizerD = optim.Adam(netD.parameters(),lr,betas=(beta1,0.999))optimizerG = optim.Adam(netG.parameters(),lr,betas=(beta1,0.999))
# criterioncriterion = nn.BCELoss()


fix_noise = torch.randn(BatchSize,nz,1,1).normal_(0,1)
    if torch.cuda.is_available():    fix_noise = fix_noise.cuda()    netG.cuda()    netD.cuda()    criterion.cuda()             print('begin training, be patient')
for epoch in range(max_epoch):    for ii, data in enumerate(dataloader,0):        real,_=data        batch_size=real.size(0)        input=real        label = torch.ones(batch_size) # 1 for real        label2 = torch.zeros(batch_size)        noise = torch.randn(batch_size,nz,1,1).normal_(0,1)                if torch.cuda.is_available:            input = input.cuda()            label = label.cuda()            label2 = label.cuda()            noise = noise.cuda()                 # ----- train netd -----        netD.zero_grad()                ## train netd with real img        output=netD(input)        errorD_real=criterion(output,label)        errorD_real.backward()        D_x=output.data.mean()                ## train netd with fake img        fake_pic=netG(noise)        output2=netD(fake_pic.detach())                errorD_fake=criterion(output2,label2)        errorD_fake.backward()                D_x2=output2.data.mean()        error_D=errorD_real+errorD_fake                optimizerD.step()                # ------ train netg -------        netG.zero_grad()        fake_pic=netG(noise)        output=netD(fake_pic)        error_G=criterion(output,label)        error_G.backward()        D_G_z2=output.data.mean()        optimizerG.step()
#生成图片fake_u=netG(fix_noise)imgs = torchvision.utils.make_grid(fake_u*0.5+0.5).cpu()plt.imshow(imgs.permute(1,2,0).numpy()) plt.show()

GAN在多个应用领域都取得了许多令人振奋的结果,比如,利用CycleGAN进行图像转换,利用StackGAN自动从文本中制作逼真的图像,利用SRGAN通过预训练模型提高图像品质...

你可能感兴趣的:(pytorch,生成对抗网络,人工智能,神经网络)