29.GAN

目录

  • GAN
  • WGAN
  • 代码
    • GAN
    • WGAN

29.GAN_第1张图片
training set是真实图像的分布Pr(x),generator根据Pg(x)分布,生成一个图像fake images,Discriminator通过学习两个分布,能够鉴别从两个分布中取出的样本图像。我们的目标是达到纳什均衡后,Pg(x)分布接近于Pr(x)。
29.GAN_第2张图片
对于取自真实分布Pr(x)的样本x,判别器的输出D(x)越大越好。

x经过D函数,得到一个数D(x)
z经过G函数,得到X’g,再经过D函数,得到D(G(z))

29.GAN_第3张图片
29.GAN_第4张图片
当满足右边条件(Pg = Pr,此时最优判别器 D*(x) = 1/2)时,L取得全局最小值 -2log2。

29.GAN_第5张图片
transposed convolution(转置卷积)有时也叫做deconvolution(反卷积),即这里所说的DC。

GAN

z是一个隐藏变量,输入是可以根据具体任务随意设定的,输出是真实分布的维度,输出的2包含了x坐标和y坐标。
z的2是隐藏层的,第一个线性层的2是随意设置的,最后一个2是可视化的2维分布。

yield dataset会把当前的batch返回给调用者,并保存函数的运行状态;下一次对迭代器调用next()时,会从上次yield的位置继续往下执行,而不是从头开始。

	torch.manual_seed(23)  # seed PyTorch's global random number generator
    np.random.seed(23)  # seed NumPy's global random number generator

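把随机种子的设置放在程序最前面,是因为训练过程带有随机性(网络参数的初始化、噪声z的采样等),每次运行的结果都会有些出入;在创建网络和数据之前固定随机种子,可以让同样的代码每次运行得到尽量一致的结果,便于复现和对比。如果数据生成还用到了Python自带的random模块(例如random.choice),还需要同时调用random.seed。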

Gan核心部分
第一步,先训练Discriminator

    for epoch in range(50000):
        # Five discriminator updates are run per outer iteration.
        for _ in range(5):
            # 1.1 train on real data
            # data_iter yields numpy batches; convert each one to a tensor
            xr=next(data_iter)
            xr=torch.from_numpy(xr)#.cuda()
            # [b,2] => one sigmoid score per sample
            predr=D(xr)
            # maximize predr by minimizing its negative
            lossr=-predr.mean()
            # 1.2 train on fake data
            z=torch.randn(batchsz,2)
            xf=G(z).detach()   # like tf.stop_gradient(): no gradient flows back into G
            predf=D(xf)
            lossf=predf.mean()

            # aggregate the real and fake terms
            loss_D=lossr+lossf

            # update the discriminator only
            optim_D.zero_grad()
            loss_D.backward()
            optim_D.step()

lossr前面加负号,是因为优化器做的是梯度下降(最小化loss):我们希望D(xr)越大越好,这等价于最小化-D(xr)。
其中.detach()函数,相当于闸门,梯度反向传播到xf这里就会断掉,不会继续往前传到Generator,可以减少计算量,并且避免了在训练Discriminator时把generator的梯度计算进来。

第二步,再训练Generator

    # Sample fresh latent vectors; xf is NOT detached here so gradients reach G.
    z=torch.randn(batchsz,2)
    xf=G(z)
    predf=D(xf)
    # Minimizing -D(G(z)) raises the discriminator's score on generated samples.
    loss_G=-predf.mean()
    
    # update the generator only
    optim_G.zero_grad()
    loss_G.backward()
    optim_G.step()
    
    # report both losses every 100 iterations
    if epoch%100==0:
    	print(loss_D.item(),loss_G.item())

generator的目标就只有一个,xf在D中的概率越大越好。

29.GAN_第6张图片
8个黄色的点是高斯混合模型。
29.GAN_第7张图片
绿色是sample出的点。
可以看到,这个GAN的训练很不稳定,此时还没有收敛。

WGAN

29.GAN_第8张图片
在GAN中的第一步后,再加上1.3步gradient penalty

# 1.3 gradient penalty, evaluated on random interpolations between xr and xf
    gp=gradient_penalty(D,xr,xf)

def gradient_penalty(D,xr,xf):
    '''
    Gradient penalty: mean of (||dD/dx||_2 - 1)^2 over random interpolations
    between the real and the fake batch.

    The batch size, device and dtype of the interpolation coefficients are
    taken from xr, so the function does not depend on any global batch size.

    :param D: callable mapping a [b,2] tensor to one score per sample
    :param xr: real batch, [b,2]
    :param xf: fake batch, [b,2]
    :return: scalar tensor, differentiable w.r.t. the parameters of D
    '''
    # One interpolation coefficient per sample: [b,1], broadcast to [b,2].
    t = torch.rand(xr.size(0), 1, device=xr.device, dtype=xr.dtype)
    t = t.expand_as(xr)
    # Random point on the segment between each real/fake pair.
    mid = t * xr + (1 - t) * xf
    # The gradient is taken with respect to the interpolated input itself.
    mid.requires_grad_()

    pred = D(mid)
    # create_graph=True keeps the graph of the gradient so that the penalty
    # can itself be back-propagated into D's parameters.
    grads = autograd.grad(outputs=pred, inputs=mid,
                          grad_outputs=torch.ones_like(pred),
                          create_graph=True, retain_graph=True)[0]
    # Penalize the deviation of each per-sample gradient norm from 1.
    gp = torch.pow(grads.norm(2, dim=1) - 1, 2).mean()

    return gp

29.GAN_第9张图片

代码

GAN

import torch
from torch import nn,optim,autograd
import numpy as np
import visdom
import random
from matplotlib import pylab as plt

# Width of every hidden Linear layer in Generator and Discriminator.
h_dim=400
# Number of samples per batch (real data, latent noise and plotted samples).
batchsz=512
# Visdom handle used for the loss curve and the contour plot.
viz=visdom.Visdom()


class Generator(nn.Module):
    """MLP mapping latent vectors z of shape [b,2] to 2-D points [b,2].

    Three hidden layers of width h_dim with in-place ReLU activations are
    followed by a linear output layer.
    """

    def __init__(self):
        super().__init__()

        # z: [b,2] => [b,2]
        layers = [
            nn.Linear(2, h_dim),
            nn.ReLU(inplace=True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(inplace=True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(inplace=True),
            nn.Linear(h_dim, 2),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        """Return the generated samples for the latent batch z."""
        return self.net(z)

class Discriminator(nn.Module):
    """MLP scoring 2-D points: [b,2] in, one sigmoid score per sample out.

    Three hidden layers of width h_dim with in-place ReLU activations are
    followed by a single-unit linear layer and a Sigmoid.
    """

    def __init__(self):
        super().__init__()

        layers = [
            nn.Linear(2, h_dim),
            nn.ReLU(inplace=True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(inplace=True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(inplace=True),
            nn.Linear(h_dim, 1),
            nn.Sigmoid(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the scores flattened to a 1-D tensor."""
        scores = self.net(x)
        return scores.view(-1)

def data_generator(batch_size=None):
    '''
    Endless generator of batches drawn from an 8-component Gaussian mixture.

    The component centers lie on a circle of radius 2; each point is a
    uniformly chosen center plus Gaussian noise with standard deviation 0.02,
    and the whole batch is finally divided by 1.414.

    All randomness comes from numpy's global RNG, so np.random.seed alone
    makes the stream of batches reproducible.

    :param batch_size: points per batch; defaults to the module-level batchsz
    :return: yields float32 arrays of shape [batch_size,2]
    '''
    if batch_size is None:
        batch_size = batchsz

    scale = 2.
    diag = 1. / np.sqrt(2)
    centers = scale * np.array([
        (1, 0),
        (-1, 0),
        (0, 1),
        (0, -1),
        (diag, diag),
        (diag, -diag),
        (-diag, diag),
        (-diag, -diag),
    ])

    while True:
        # Pick one mixture component per point, then add N(0, 0.02^2) noise.
        picks = np.random.randint(len(centers), size=batch_size)
        noise = np.random.randn(batch_size, 2) * 0.02

        dataset = (centers[picks] + noise).astype(np.float32)
        dataset /= 1.414
        yield dataset

def generator_image(D,G,xr,epoch):
    '''
    Plot the discriminator contour, the real batch and generated samples.

    D is evaluated on a 128x128 grid over [-3,3]^2 and drawn as labelled
    contour lines; the real batch xr is scattered in orange and batchsz fresh
    samples from G in green. The figure is sent to the visdom window
    'contour'.

    :param D: discriminator, called on a [N,2] float tensor
    :param G: generator, called on a [batchsz,2] latent tensor
    :param xr: real batch; columns 0 and 1 are plotted as x and y
    :param epoch: current epoch, shown in the plot title
    '''
    n_points = 128
    limit = 3
    plt.clf()

    # Evaluation grid: axis 0 sweeps x, axis 1 sweeps y.
    axis = np.linspace(-limit, limit, n_points)
    grid = np.zeros((n_points, n_points, 2), dtype='float32')
    grid[:, :, 0] = axis[:, None]
    grid[:, :, 1] = axis[None, :]
    grid = grid.reshape((-1, 2))  # [16384,2]

    # draw contour
    with torch.no_grad():
        disc_map = D(torch.Tensor(grid)).cpu().numpy()
    # Transposed so that rows follow y and columns follow x.
    surface = disc_map.reshape((len(axis), len(axis))).transpose()
    contours = plt.contour(axis, axis, surface)
    plt.clabel(contours, inline=1, fontsize=10)

    # draw samples
    with torch.no_grad():
        latent = torch.randn(batchsz, 2)
        fake = G(latent).cpu().numpy()
    plt.scatter(xr[:, 0], xr[:, 1], c='orange', marker='.')
    plt.scatter(fake[:, 0], fake[:, 1], c='green', marker='+')

    viz.matplot(plt, win='contour', opts=dict(title='p(x):%d' % epoch))


def main():
    '''
    Train the toy GAN and stream losses and plots to visdom.

    Each of the 50000 outer iterations runs 5 discriminator updates followed
    by one generator update. Every 100 iterations the latest losses are
    appended to the 'loss' window and printed, and a contour plot is drawn.
    '''
    # Seed every RNG this file imports (torch, numpy and the stdlib random
    # module) so that runs are reproducible as far as seeding allows.
    torch.manual_seed(23)
    np.random.seed(23)
    random.seed(23)

    data_iter = data_generator()

    G = Generator()
    D = Discriminator()
    optim_G = optim.Adam(G.parameters(), lr=5e-4, betas=(0.5, 0.9))
    optim_D = optim.Adam(D.parameters(), lr=5e-4, betas=(0.5, 0.9))

    viz.line([[0, 0]], [0], win='loss', opts=dict(title='loss', legend=['D', 'G']))

    for epoch in range(50000):
        # 1. train the discriminator for 5 steps
        for _ in range(5):
            # 1.1 real batch: numpy array -> tensor, [b,2]
            xr = torch.from_numpy(next(data_iter))  # .cuda()
            # Minimizing -D(xr) pushes the score of real samples up.
            lossr = -D(xr).mean()

            # 1.2 fake batch; detach() stops gradients from reaching G
            z = torch.randn(batchsz, 2)
            xf = G(z).detach()
            lossf = D(xf).mean()

            # aggregate the real and fake terms
            loss_D = lossr + lossf

            optim_D.zero_grad()
            loss_D.backward()
            optim_D.step()

        # 2. train the generator: raise D's score on generated samples
        z = torch.randn(batchsz, 2)
        xf = G(z)
        loss_G = -D(xf).mean()

        optim_G.zero_grad()
        loss_G.backward()
        optim_G.step()

        if epoch % 100 == 0:
            viz.line([[loss_D.item(), loss_G.item()]], [epoch], win='loss', update='append')

            print(loss_D.item(), loss_G.item())

            generator_image(D, G, xr, epoch)




# Start training only when the file is run as a script, not when imported.
if __name__ == '__main__':
    main()

WGAN

import torch
from torch import nn,optim,autograd
import numpy as np
import visdom
import random
from matplotlib import pylab as plt

# Width of every hidden Linear layer in Generator and Discriminator.
h_dim=400
# Number of samples per batch (real data, latent noise and plotted samples).
batchsz=512
# Visdom handle used for the loss curve and the contour plot.
viz=visdom.Visdom()


class Generator(nn.Module):
    """MLP mapping latent vectors z of shape [b,2] to 2-D points [b,2].

    Three hidden layers of width h_dim with in-place ReLU activations are
    followed by a linear output layer.
    """

    def __init__(self):
        super().__init__()

        # z: [b,2] => [b,2]
        layers = [
            nn.Linear(2, h_dim),
            nn.ReLU(inplace=True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(inplace=True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(inplace=True),
            nn.Linear(h_dim, 2),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        """Return the generated samples for the latent batch z."""
        return self.net(z)

class Discriminator(nn.Module):
    """MLP scoring 2-D points: [b,2] in, one sigmoid score per sample out.

    Three hidden layers of width h_dim with in-place ReLU activations are
    followed by a single-unit linear layer and a Sigmoid.
    """
    # NOTE(review): the WGAN formulation normally uses an unbounded critic
    # (no final Sigmoid); the Sigmoid is kept here to preserve behaviour.

    def __init__(self):
        super().__init__()

        layers = [
            nn.Linear(2, h_dim),
            nn.ReLU(inplace=True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(inplace=True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(inplace=True),
            nn.Linear(h_dim, 1),
            nn.Sigmoid(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the scores flattened to a 1-D tensor."""
        scores = self.net(x)
        return scores.view(-1)

def data_generator(batch_size=None):
    '''
    Endless generator of batches drawn from an 8-component Gaussian mixture.

    The component centers lie on a circle of radius 2; each point is a
    uniformly chosen center plus Gaussian noise with standard deviation 0.02,
    and the whole batch is finally divided by 1.414.

    All randomness comes from numpy's global RNG, so np.random.seed alone
    makes the stream of batches reproducible.

    :param batch_size: points per batch; defaults to the module-level batchsz
    :return: yields float32 arrays of shape [batch_size,2]
    '''
    if batch_size is None:
        batch_size = batchsz

    scale = 2.
    diag = 1. / np.sqrt(2)
    centers = scale * np.array([
        (1, 0),
        (-1, 0),
        (0, 1),
        (0, -1),
        (diag, diag),
        (diag, -diag),
        (-diag, diag),
        (-diag, -diag),
    ])

    while True:
        # Pick one mixture component per point, then add N(0, 0.02^2) noise.
        picks = np.random.randint(len(centers), size=batch_size)
        noise = np.random.randn(batch_size, 2) * 0.02

        dataset = (centers[picks] + noise).astype(np.float32)
        dataset /= 1.414
        yield dataset

def generator_image(D,G,xr,epoch):
    '''
    Plot the critic contour, the real batch and generated samples.

    D is evaluated on a 128x128 grid over [-3,3]^2 and drawn as labelled
    contour lines; the real batch xr is scattered in orange and batchsz fresh
    samples from G in green. The figure is sent to the visdom window
    'contour'.

    :param D: discriminator/critic, called on a [N,2] float tensor
    :param G: generator, called on a [batchsz,2] latent tensor
    :param xr: real batch; columns 0 and 1 are plotted as x and y
    :param epoch: current epoch, shown in the plot title
    '''
    n_points = 128
    limit = 3
    plt.clf()

    # Evaluation grid: axis 0 sweeps x, axis 1 sweeps y.
    axis = np.linspace(-limit, limit, n_points)
    grid = np.zeros((n_points, n_points, 2), dtype='float32')
    grid[:, :, 0] = axis[:, None]
    grid[:, :, 1] = axis[None, :]
    grid = grid.reshape((-1, 2))  # [16384,2]

    # draw contour
    with torch.no_grad():
        disc_map = D(torch.Tensor(grid)).cpu().numpy()
    # Transposed so that rows follow y and columns follow x.
    surface = disc_map.reshape((len(axis), len(axis))).transpose()
    contours = plt.contour(axis, axis, surface)
    plt.clabel(contours, inline=1, fontsize=10)

    # draw samples
    with torch.no_grad():
        latent = torch.randn(batchsz, 2)
        fake = G(latent).cpu().numpy()
    plt.scatter(xr[:, 0], xr[:, 1], c='orange', marker='.')
    plt.scatter(fake[:, 0], fake[:, 1], c='green', marker='+')

    viz.matplot(plt, win='contour', opts=dict(title='p(x):%d' % epoch))


def gradient_penalty(D,xr,xf):
    '''
    Gradient penalty: mean of (||dD/dx||_2 - 1)^2 over random interpolations
    between the real and the fake batch.

    The batch size, device and dtype of the interpolation coefficients are
    taken from xr, so the function does not depend on any global batch size.

    :param D: callable mapping a [b,2] tensor to one score per sample
    :param xr: real batch, [b,2]
    :param xf: fake batch, [b,2]
    :return: scalar tensor, differentiable w.r.t. the parameters of D
    '''
    # One interpolation coefficient per sample: [b,1], broadcast to [b,2].
    t = torch.rand(xr.size(0), 1, device=xr.device, dtype=xr.dtype)
    t = t.expand_as(xr)
    # Random point on the segment between each real/fake pair.
    mid = t * xr + (1 - t) * xf
    # The gradient is taken with respect to the interpolated input itself.
    mid.requires_grad_()

    pred = D(mid)
    # create_graph=True keeps the graph of the gradient so that the penalty
    # can itself be back-propagated into D's parameters.
    grads = autograd.grad(outputs=pred, inputs=mid,
                          grad_outputs=torch.ones_like(pred),
                          create_graph=True, retain_graph=True)[0]
    # Penalize the deviation of each per-sample gradient norm from 1.
    gp = torch.pow(grads.norm(2, dim=1) - 1, 2).mean()

    return gp


def main():
    '''
    Train the toy WGAN with gradient penalty, logging to visdom.

    Each of the 50000 outer iterations runs 5 discriminator updates (real
    term, fake term and 0.2 times the gradient penalty) followed by one
    generator update. Every 100 iterations the latest losses are appended to
    the 'loss' window and printed, and a contour plot is drawn.
    '''
    # Seed every RNG this file imports (torch, numpy and the stdlib random
    # module) so that runs are reproducible as far as seeding allows.
    torch.manual_seed(23)
    np.random.seed(23)
    random.seed(23)

    data_iter = data_generator()

    G = Generator()
    D = Discriminator()
    optim_G = optim.Adam(G.parameters(), lr=5e-4, betas=(0.5, 0.9))
    optim_D = optim.Adam(D.parameters(), lr=5e-4, betas=(0.5, 0.9))

    viz.line([[0, 0]], [0], win='loss', opts=dict(title='loss', legend=['D', 'G']))

    for epoch in range(50000):
        # 1. train the discriminator for 5 steps
        for _ in range(5):
            # 1.1 real batch: numpy array -> tensor, [b,2]
            xr = torch.from_numpy(next(data_iter))  # .cuda()
            # Minimizing -D(xr) pushes the score of real samples up.
            lossr = -D(xr).mean()

            # 1.2 fake batch; detach() stops gradients from reaching G
            z = torch.randn(batchsz, 2)
            xf = G(z).detach()
            lossf = D(xf).mean()

            # 1.3 gradient penalty on interpolations between xr and xf
            gp = gradient_penalty(D, xr, xf)

            # aggregate the real, fake and penalty terms
            loss_D = lossr + lossf + 0.2 * gp

            optim_D.zero_grad()
            loss_D.backward()
            optim_D.step()

        # 2. train the generator: raise D's score on generated samples
        z = torch.randn(batchsz, 2)
        xf = G(z)
        loss_G = -D(xf).mean()

        optim_G.zero_grad()
        loss_G.backward()
        optim_G.step()

        if epoch % 100 == 0:
            viz.line([[loss_D.item(), loss_G.item()]], [epoch], win='loss', update='append')

            print(loss_D.item(), loss_G.item())

            generator_image(D, G, xr, epoch)




# Start training only when the file is run as a script, not when imported.
if __name__ == '__main__':
    main()

你可能感兴趣的:(pytorch,生成对抗网络,机器学习,深度学习)