所谓的生成对抗网络,就是一种可以生成特定分布数据的神经网络模型
如上图所示, G A N GAN GAN网络结构中,最重要的是两个模块: G G G和 D D D,输入的数据,通过 G G G生成了一些伪数据 G ( z ) G(z) G(z),然后与真实数据 X X X一同输入到 D D D模块,然后进行判断是真实数据还是伪数据
从上面的训练目的可以看出,GAN的训练与传统的监督学习训练模型并不相同,传统的监督学习训练模型只能做数据的映射,输入的数据通过模型后得到输出值,随后构造损失函数衡量输出值与真实标签中间的差异,将这个差异值求导并采用梯度下降的方法更新模型中的参数,从而使得模型的输出逼近真实的标签值,在这种监督学习模型中,一个很核心的模块就是损失函数模块,而在GAN的训练模型中并没有损失函数模块,这是GAN训练模型与监督学习模型最大的不同,GAN训练模型输入的是随机数,通过 G G G模块输出了“伪数据”,但这里并不会构造损失函数去比较输出的“伪数据”与真实数据之间的差异,这是毫无意义的,“伪数据”与真实数据之间的差异通过 D D D模型获得,这里的 D D D模块就充当了监督学习中的损失函数模块的角色,得到差异值后,类似的对 G G G中的参数进行更新,从而使得“伪数据”逼近于真实的训练数据,需要注意的是这里的“逼近”并不是数值上的逼近,而是分布上的逼近
所谓DCGAN,就是利用卷积网络,实现GAN,即 D D D模块, G G G模块都使用卷积神经网络实现
从上图中可以看出,输入的是一个长度为100的张量,但是在Pytorch中的必须理解为一个四维张量——(1,100,1,1),最终得到的是一个 3 × 64 × 64 3\times64\times64 3×64×64的RGB图像
从上图中可以看出,输入的是一个 3 × 64 × 64 3\times64\times64 3×64×64的RGB图像,输出的是一个长度为2的向量,用于判断正负样本
Pytorch需要对Generator和Discriminator进行实现:
class Generator(nn.Module):
def __init__(self, nz=100, ngf=128, nc=3):
super(Generator, self).__init__()
self.main = nn.Sequential(
# input is Z, going into a convolution
nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
nn.BatchNorm2d(ngf * 8),
nn.ReLU(True),
# state size. (ngf*8) x 4 x 4
nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
nn.BatchNorm2d(ngf * 4),
nn.ReLU(True),
# state size. (ngf*4) x 8 x 8
nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(ngf * 2),
nn.ReLU(True),
# state size. (ngf*2) x 16 x 16
nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
nn.BatchNorm2d(ngf),
nn.ReLU(True),
# state size. (ngf) x 32 x 32
nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
nn.Tanh()
# state size. (nc) x 64 x 64
)
class Discriminator(nn.Module):
def __init__(self, nc=3, ndf=128):
super(Discriminator, self).__init__()
self.main = nn.Sequential(
# input is (nc) x 64 x 64
nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
nn.LeakyReLU(0.2, inplace=True),
# state size. (ndf) x 32 x 32
nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(ndf * 2),
nn.LeakyReLU(0.2, inplace=True),
# state size. (ndf*2) x 16 x 16
nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
nn.BatchNorm2d(ndf * 4),
nn.LeakyReLU(0.2, inplace=True),
# state size. (ndf*4) x 8 x 8
nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
nn.BatchNorm2d(ndf * 8),
nn.LeakyReLU(0.2, inplace=True),
# state size. (ndf*8) x 4 x 4
nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
nn.Sigmoid()
)
GAN应用范围非常广,可以参看此链接