[PyTorch][chapter 56][GAN 代码实现]

前言:

      整个工程分为两个文件:

      gan.py:  网络模型搭建

      main.py:  数据集生成,模型训练

  注意: 针对loss 中的生成器G的损失部分,Goodfellow 早期提出了版本1,后期提出了版本2

   V(G,D)= E_{x \sim p_r}[log(D(x))]+E_{x \sim p_g}[log(1-D(x))]

   E_{x \sim p_g}[log(1-D(x))]版本1

  E_{x \sim p_g}[-log(D(x))]  版本2 


目录:

  1.      GAN 网络结构
  2.      gan.py
  3.       main.py

一  GAN 网络结构

       

      1.1 训练D

            V(G,D)= E_{x \sim p_r}[log D(x)]+E_{x \sim p_g}[log (1-D(x)]

            最大化V

     1.2  训练G

            固定G, 最小化

             E_{x \sim p_z}[log(1-D(x)]


二 gan.py

   功能:

        实现 鉴别器D

        实现 生成器G

# -*- coding: utf-8 -*-
"""
Created on Thu Sep 28 11:10:19 2023

@author: chengxf2
"""

import torch
from   torch import nn



#生成器模型
class Generator(nn.Module):
    
    def __init__(self):
        
        super(Generator,self).__init__()
        # z: [batch,input_features]
        h_dim = 400
        self.net = nn.Sequential(
            nn.Linear(2, h_dim),
            nn.ReLU(True),
            nn.Linear( h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, 2)
            )
        
    def forward(self, z):
        
        output = self.net(z)
        return output
    
#鉴别器模型
class Discriminator(nn.Module):
    
    def __init__(self):
        
        super(Discriminator,self).__init__()
        
        hDim=400
        # x: [batch,input_features]
        self.net = nn.Sequential(
            nn.Linear(2, hDim),
            nn.ReLU(True),
            nn.Linear(hDim, hDim),
            nn.ReLU(True),
            nn.Linear(hDim, hDim),
            nn.ReLU(True),
            nn.Linear(hDim, 1),
            nn.Sigmoid()
            )
        
    def forward(self, x):
        
        #x:[batch,1]
        output = self.net(x)
        
        out = output.view(-1)
        return out
    




三  main.py

 功能:

      生成数据

       训练网络

# -*- coding: utf-8 -*-
"""
Created on Thu Sep 28 11:28:32 2023

@author: chengxf2
"""


import visdom
from gan  import  Discriminator
from gan  import Generator
import numpy as np
import random
import torch
from   torch import nn, optim
from    matplotlib import pyplot as plt


h_dim =400
batchSize = 256
viz = visdom.Visdom()

#viz = visdom.Visdom()


def weights_init(net):
   if isinstance(net, nn.Linear):
         # net.weight.data.normal_(0.0, 0.02)
         nn.init.kaiming_normal_(net.weight)
         net.bias.data.fill_(0)

def data_generator():
    """
    8- gaussian destribution

    Returns
    -------
    None.

    """
    scale = 2
    a = np.sqrt(2.0)
    centers =[
         (1,0),
         (-1,0),
         (0,1),
         (0,-1),
         (1/a,1/a),
         (1/a,-1/a),
         (-1/a, 1/a),
         (-1/a,-1/a)
        ]
    
    centers = [(scale*x, scale*y) for x,y in centers]
    
    while True:
        
         dataset =[]
         
         for i in range(batchSize):
             
             point = np.random.randn(2)*0.02
             center = random.choice(centers)
             point[0] += center[0]
             point[1] += center[1]
             dataset.append(point)
         dataset = np.array(dataset).astype(np.float32)
         dataset /=a
         #生成器函数是一个特殊的函数,可以返回一个迭代器
         yield dataset


def generate_image(D, G, xr, epoch):      #xr表示真实的sample
    """
    Generates and saves a plot of the true distribution, the generator, and the
    critic.
    """
    N_POINTS = 128
    RANGE = 3
    plt.clf()

    points = np.zeros((N_POINTS, N_POINTS, 2), dtype='float32')
    points[:, :, 0] = np.linspace(-RANGE, RANGE, N_POINTS)[:, None]
    points[:, :, 1] = np.linspace(-RANGE, RANGE, N_POINTS)[None, :]
    points = points.reshape((-1, 2))             # (16384, 2)
    x = y = np.linspace(-RANGE, RANGE, N_POINTS)
    N = len(x)
    # draw contour
    with torch.no_grad():
        points = torch.Tensor(points)      # [16384, 2]
        disc_map = D(points).cpu().numpy() # [16384]
   
    plt.contour(x, y, disc_map.reshape((N, N)).transpose())
    #plt.clabel(cs, inline=1, fontsize=10)
    plt.colorbar()


    # draw samples
    with torch.no_grad():
        z = torch.randn(batchSize, 2)                 # [b, 2]
        samples = G(z).cpu().numpy()                # [b, 2]
    plt.scatter(xr[:, 0], xr[:, 1], c='green', marker='.')
    plt.scatter(samples[:, 0], samples[:, 1], c='red', marker='+')

    viz.matplot(plt, win='contour', opts=dict(title='p(x):%d'%epoch))
    
         
         
def main():
  
    maxIter = 1000
    torch.manual_seed(10)
    np.random.seed(10)
    data_iter  = data_generator()
    
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    G = Generator().to(device)
    D = Discriminator().to(device)
    G.apply(weights_init)
    D.apply(weights_init)
    optim_G = optim.Adam(G.parameters(),lr =5e-4, betas=(0.5,0.9))
    optim_D = optim.Adam(D.parameters(),lr =5e-4, betas=(0.5,0.9))
    K = 5
 
    

    
   
    viz.line([[0,0]], [0], win='loss', opts=dict(title='loss', legend=['D', 'G']))

    for epoch in range(maxIter):
        
        #1: train Discrimator fistly
        for k in range(K):
            
            #1.1: train on real data
            xr = next(data_iter)
            xr = torch.from_numpy(xr).to(device)
            predr = D(xr)
            
            vr = torch.log(predr)
            #max(predr) == min(-predr)
            lossr = vr.mean()
            
            #1.2: train on fake data
            z = torch.randn(batchSize,2).to(device) #[b,2] 随机产生的噪声
            xf = G(z).detach() #固定G,不更新G参数 tf.stop_gradient()
            predf =D(xf) #min predf
            
            vf = torch.log(1e-4+1.0-predf)
            lossf = vf.mean()
            loss_D =-(lossr+lossf)
            
            optim_D.zero_grad()
            loss_D.backward()
            optim_D.step()
            #print("\n Discriminator 训练结束 ",loss_D.item())
        
        # 2 train  Generator,max V(G,D)
        
        #2.1 train on fake data
        z = torch.randn(batchSize, 2)
        xf = G(z)
        predf =D(xf) #max predf
        
        vf = torch.log(1e-4+1.0-predf)
        loss_G= predf.mean()
        
        #optimize
        optim_G.zero_grad()
        loss_G.backward()
        optim_G.step()
        
        if epoch %100 ==0:
            viz.line([[loss_D.item(), loss_G.item()]], [epoch], win='loss', update='append')
            generate_image(D, G, xr, epoch)
            print("\n epoch: %d"%epoch,"\t lossD: %7.4f"%loss_D.item(),"\t lossG: %7.4f"%loss_G.item())
         
        
 

    
    
    

if __name__ == "__main__":
    
    main()
         
    


三 训练效果

   里面的损失函数按照最早的论文里面的,跟其它版本有所区别

 效果:

      生成器G 训练的loss 最后稳定在一个固定值,无法更新生成器

      鉴别器: 因为生成器很弱,很容易鉴别出真实数据 和 fake 数据,导致loss 也迅速降低为0

      实际生成效果:

                 生成器生成出来的数据红色部分,和真实的数据分布绿色 有较大差距。

    生成器很弱。

[PyTorch][chapter 56][GAN 代码实现]_第1张图片

   

参考:

课时127 GAN实战-GD实现_哔哩哔哩_bilibili

https://www.cnblogs.com/cxq1126/p/13538409.html

你可能感兴趣的:(pytorch,生成对抗网络,人工智能)