生成对抗网络学习笔记-----GAN网络入门学习

今天在看gan网络,实际上很早就学了,但是现在要用到,又得重新学!哎。。。。。

首先明确一点gan网络是用来干什么的?它是用来生成一种没有的但是我们非常熟悉的东西,比如说人脸,就是用让网络自己根据大量的人脸照片,然后自己生成一个非常像人脸的照片(可以以假乱真的那种),由此生成对抗网络就出来了。

生成对抗网络学习笔记-----GAN网络入门学习_第1张图片

 gan网络结构图

废话就不说了,直接说网络输入什么?,输出又是什么?模型又是什么?损失函数大家自己去别的地方学习。上面的图看不懂直接看下面就行了。

一、生成模型

        大家想想,我们想要生成模型(就是gan网路结构图中的G)输出的是一张非常像人脸的图片,那么它的输出一定是图片的形状,那一定是[n,3,w,h],其中n表示图片的数量,3表示图片的通道数,w,h分别表示图片的宽和高度。

        然后我们再想一下,输入是什么?其实我觉得输入什么维度都是可以的,但是为了方便还是选择的是和图片一样的维度,也就是输出和输入是一样的,但是注意的是,输入的是一组正态分布的数据,如果是无规则的数据,让网络训练的就会慢。

        好了,我们输入输出都有了,那么选择什么模型呢,相信大家心里一下子就能想到一个模型,就是UNet网络,如果大家不熟悉可以自行去学习一下。它的网络结构图如下,输入输出的结果都是一样的形状。

生成对抗网络学习笔记-----GAN网络入门学习_第2张图片

import torch
import torch.nn as nn


class Block(nn.Module):
    def __init__(self, in_channels, out_channels, down=True, act="relu", use_dropout=False):
        super(Block, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 4, 2, 1, bias=False, padding_mode="reflect")
            if down
            else nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU() if act == "relu" else nn.LeakyReLU(0.2),
        )

        self.use_dropout = use_dropout
        self.dropout = nn.Dropout(0.5)
        self.down = down

    def forward(self, x):
        x = self.conv(x)
        return self.dropout(x) if self.use_dropout else x


class Generator(nn.Module):
    def __init__(self, in_channels=3, features=64):
        super().__init__()
        self.initial_down = nn.Sequential(
            nn.Conv2d(in_channels, features, 4, 2, 1, padding_mode="reflect"),
            nn.LeakyReLU(0.2),
        )
        self.down1 = Block(features, features * 2, down=True, act="leaky", use_dropout=False)
        self.down2 = Block(
            features * 2, features * 4, down=True, act="leaky", use_dropout=False
        )
        self.down3 = Block(
            features * 4, features * 8, down=True, act="leaky", use_dropout=False
        )
        self.down4 = Block(
            features * 8, features * 8, down=True, act="leaky", use_dropout=False
        )
        self.down5 = Block(
            features * 8, features * 8, down=True, act="leaky", use_dropout=False
        )
        self.down6 = Block(
            features * 8, features * 8, down=True, act="leaky", use_dropout=False
        )
        self.bottleneck = nn.Sequential(
            nn.Conv2d(features * 8, features * 8, 4, 2, 1), nn.ReLU()
        )

        self.up1 = Block(features * 8, features * 8, down=False, act="relu", use_dropout=True)
        self.up2 = Block(
            features * 8 * 2, features * 8, down=False, act="relu", use_dropout=True
        )
        self.up3 = Block(
            features * 8 * 2, features * 8, down=False, act="relu", use_dropout=True
        )
        self.up4 = Block(
            features * 8 * 2, features * 8, down=False, act="relu", use_dropout=False
        )
        self.up5 = Block(
            features * 8 * 2, features * 4, down=False, act="relu", use_dropout=False
        )
        self.up6 = Block(
            features * 4 * 2, features * 2, down=False, act="relu", use_dropout=False
        )
        self.up7 = Block(features * 2 * 2, features, down=False, act="relu", use_dropout=False)
        self.final_up = nn.Sequential(
            nn.ConvTranspose2d(features * 2, in_channels, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        d1 = self.initial_down(x)
        d2 = self.down1(d1)
        d3 = self.down2(d2)
        d4 = self.down3(d3)
        d5 = self.down4(d4)
        d6 = self.down5(d5)
        d7 = self.down6(d6)
        bottleneck = self.bottleneck(d7)
        up1 = self.up1(bottleneck)
        up2 = self.up2(torch.cat([up1, d7], 1))
        up3 = self.up3(torch.cat([up2, d6], 1))
        up4 = self.up4(torch.cat([up3, d5], 1))
        up5 = self.up5(torch.cat([up4, d4], 1))
        up6 = self.up6(torch.cat([up5, d3], 1))
        up7 = self.up7(torch.cat([up6, d2], 1))
        return self.final_up(torch.cat([up7, d1], 1))


if __name__ == "__main__":
    x = torch.randn((2, 3, 256, 256))
    model = Generator(in_channels=3, features=64)
    preds = model(x)
    print(preds.shape)

 二、判别模型

        刚才也说了,我们需要模型(就是gan网路结构图中的P)去学习大量的人脸图片,那么输入的维度肯定是 一张图片的大小形状,当然大家注意,这里面的输入可以看gan网路结构图中所示,输入模型中的图片是两种,一个是真实图片,一个是G网络生成的图片形状(我们习惯称为假图片)。        

        好了既然叫做判别模型,那么肯定有判断的效果在里面,因此我们的判别模型输出结果就是0-1,分别代表的就是假和真,也就是二分类模型,这个大家应该很清楚。现在我们就来看代码部分。

import torch
import torch.nn as nn


class CNNBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride):
        super(CNNBlock, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels, out_channels, 4, stride, 1, bias=False, padding_mode="reflect"
            ),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        return self.conv(x)


class Discriminator(nn.Module):
    def __init__(self, in_channels=3, features=None):
        super().__init__()
        if features is None:
            features = [64, 128, 256, 512]
        self.initial = nn.Sequential(
            nn.Conv2d(
                in_channels,
                features[0],
                kernel_size=4,
                stride=2,
                padding=1,
                padding_mode="reflect",
            ),
            nn.LeakyReLU(0.2),
        )

        layers = []
        in_channels = features[0]
        for feature in features[1:]:
            layers.append(
                CNNBlock(in_channels, feature, stride=1 if feature == features[-1] else 2),
            )
            in_channels = feature

        layers.append(
            nn.Conv2d(
                in_channels, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect"
            ),
        )

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        x = self.initial(x)
        x = self.model(x)
        return x


if __name__ == "__main__":
    x = torch.randn((1, 3, 256, 256))
    #y = torch.randn((1, 3, 256, 256))
    model = Discriminator(in_channels=3)
    preds = model(x)
    print(model)
    print(preds.shape)

 三、训练部分

 首先我们要将大量图片输入到判别模型中,伪代码就是out=model_panbie(image),再将这样的结果与1来进行损失计算,再将生成网络的输出结果输入到判别网络中得到的预测值和0进行损失计算,然后再进行反向传播,让生成网络不断的学习,最终得到一个图片非常像真的人脸图片,至此损失就会变的很小,模型就训练成功了。

你可能感兴趣的:(神经网络入门,生成对抗网络,学习,深度学习,pytorch)