使用SAGAN生成二次元人物头像(GAN生成对抗网络)--pytorch实现

        这是训练250epoch左右的成果。

使用SAGAN生成二次元人物头像(GAN生成对抗网络)--pytorch实现_第1张图片

之前的文章里面,我们使用了残差网络的形式实现生成器与辨别器,它理论上可以实现很不错的效果,但有一个很致命的缺点,就是训练太慢,很难见到成果。

        这一次,我们实现了一个利用自注意力机制制作的对抗生成网络。自注意力机制是我们在深度学习道路上,除了RNN,CNN以外,不得不了解的一种模块。非常有意思。简而言之,这个模块,相比于之前单纯使用卷积网络的GAN,它更加能注重上下文,举个例子,在生成人物眼睛的时候,它会注意到鼻子,头发等其他部位,从而将眼睛放在合适的位置,总之,能更好的学习到整体特征。

以下是生成器的代码。注意attn1与attn2层都是我们的自注意力模块,其他是我们所熟悉的DCGAN中使用过的反卷积。默认生成64*64像素的图片,如果想修改图片大小,请修改image_size,以及,后面层次中的channel大小以及selfattention的参数。

class Generator(nn.Module):
  def __init__(self, image_size = 64, z_dim = 100, conv_dim =64):
    super().__init__()
    repeat_num = int(np.log2(image_size)) - 3
    mult = 2 ** repeat_num

    self.l1 = nn.Sequential(
        spectral_norm(nn.ConvTranspose2d(in_channels = z_dim, out_channels = conv_dim * mult, kernel_size = 4)),
        nn.LayerNorm([512, 4, 4]),
        nn.ReLU()
    )

    curr_dim = conv_dim * mult
    self.l2 = nn.Sequential(
        spectral_norm(nn.ConvTranspose2d(curr_dim, curr_dim // 2, 4, 2, 1)),
        nn.LayerNorm([256, 8, 8]),

        nn.ReLU()
    )

    curr_dim = curr_dim // 2
    self.l3 = nn.Sequential(
        spectral_norm(nn.ConvTranspose2d(curr_dim, curr_dim // 2, 4, 2, 1)),
        nn.LayerNorm([128, 16, 16]),
        nn.ReLU()
    )

    curr_dim = curr_dim // 2
    self.l4 = nn.Sequential(
        spectral_norm(nn.ConvTranspose2d(curr_dim, curr_dim // 2, 4, 2, 1)),
        nn.LayerNorm([64, 32, 32]),
        nn.ReLU()

    )
    self.last = nn.Sequential(
        nn.ConvTranspose2d(64, 3, 4, 2, 1),
        nn.Tanh()
        )
    self.attn1 = selfattention(128)
    self.attn2 = selfattention(64)
  def forward(self, input):
    input = input.view(input.size(0), input.size(1), 1, 1)
    out = self.l1(input)

    out = self.l2(out)

    out = self.l3(out)
    out = self.attn1(out)
    out = self.l4(out)
    out = self.attn2(out)
    out = self.last(out)
    return out

以下是辨别器的代码,同样,如果上面修改了数据集图片读入大小以及生成器生成的图片大小,不要忘记修改辨别器中的image_size,同时可能还需要修改每个卷积层的核大小以及步长。

class Discriminator(nn.Module):
    def __init__(self, in_channels = 3, image_size = 256, ndf =64):
        super().__init__()
        def conv_2d(in_channels, out_channels, kernel_size, stride = 1, padding = 0):
            return nn.Sequential(
                spectral_norm(nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)),
                nn.LeakyReLU(0.1)
            )
        self.block_1 = conv_2d(in_channels, ndf, 4, 2, 1)
        current_dim = ndf
        self.block_2 = conv_2d(current_dim, current_dim * 2, 4, 2, 1)
        current_dim *= 2
        self.block_3 = conv_2d(current_dim, current_dim * 2, 4, 2, 1)
        current_dim *= 2
        #self.block_5 = conv_2d(current_dim, current_dim * 2, 4, 2, 1)
       # current_dim *= 2
        #self.block_6 = conv_2d(current_dim, current_dim * 2, 4, 2, 1)
        #current_dim *= 2
        self.attn_layer_1 = selfattention(current_dim)
        self.block_4 = conv_2d(current_dim, current_dim * 2, 4, 2, 1)
        current_dim *= 2
        self.attn_layer_2 = selfattention(current_dim)

        self.last_layer = nn.Sequential(nn.Conv2d(current_dim, 1, 4, stride= 1),
                                        )

    def forward(self, input):
        all_layers = [self.block_1, self.block_2, self.block_3, self.attn_layer_1,
                          self.block_4, self.attn_layer_2,self.last_layer]
        out = reduce(lambda x, layer: layer(x), all_layers, input)  #套娃  clock3(block2(block1(x)))......返回结果

        return out

以下贴出自注意力机制代码,有疑问的同学可以参考之前的博客:

class selfattention(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels
        self.query = nn.Conv2d(in_channels, in_channels // 8, kernel_size = 1, stride = 1)
        self.key   = nn.Conv2d(in_channels, in_channels // 8, kernel_size = 1, stride = 1)
        self.value = nn.Conv2d(in_channels, in_channels, kernel_size = 1, stride = 1)
        self.gamma = nn.Parameter(torch.zeros(1))  #gamma为一个衰减参数,由torch.zero生成,nn.Parameter的作用是将其转化成为可以训练的参数.
        self.softmax = nn.Softmax(dim = -1)
    def forward(self, input):
        batch_size, channels, height, width = input.shape
        # input: B, C, H, W -> q: B, H * W, C // 8
        q = self.query(input).view(batch_size, -1, height * width).permute(0, 2, 1)
        #input: B, C, H, W -> k: B, C // 8, H * W
        k = self.key(input).view(batch_size, -1, height * width)
        #input: B, C, H, W -> v: B, C, H * W
        v = self.value(input).view(batch_size, -1, height * width)
        #q: B, H * W, C // 8 x k: B, C // 8, H * W -> attn_matrix: B, H * W, H * W
        attn_matrix = torch.bmm(q, k)  #torch.bmm进行tensor矩阵乘法,q与k相乘得到的值为attn_matrix.
        attn_matrix = self.softmax(attn_matrix)#经过一个softmax进行缩放权重大小.
        out = torch.bmm(v, attn_matrix.permute(0, 2, 1))  #tensor.permute将矩阵的指定维进行换位.这里将1于2进行换位。
        out = out.view(*input.shape)

        return self.gamma * out + input

训练代码以及损失函数等全部源代码,请上我的GitHub获取,之后会上传预训练模型,您也可以自己进行训练~~~///(^v^)\\\~~~

数据集来自kaggle,可自己搜索:"anime"然后下载。

Github链接:https://github.com/rabbitdeng/anime-sagan-pytorch/tree/main

你可能感兴趣的:(深度学习,GAN,人工智能,深度学习,pytorch,计算机视觉,生成对抗网络)