图像生成——使用DCGAN生成卡通肖像

一、需要注意的几点:

1、生成器的网络和判别器的网络均不含池化层。

2、判别器的最后一层网络输出使用sigmoid激活,生成器的最后一层网络输出使用tanh激活。

3、生成器和判别器的网络结构呈对称形式。例如:生成器第一层与判别器最后一层的卷积核大小、步长相同,且生成器该层的输出通道数等于判别器该层的输入通道数;中间各层同理,生成器某层的输入/输出通道数与判别器对应层的输出/输入通道数一一对应。

图像生成——使用DCGAN生成卡通肖像_第1张图片

(上图所示为生成器的结构;判别器的网络与之刚好对称,相当于把生成器从后往前倒过来。)

4、卷积核使用偶数大小(如4×4)的效果通常比使用奇数大小的卷积核更好。

5、使用转置卷积进行上采样。

6、训练时可以每训练两轮生成器再训练一次判别器(原因是判别器的能力通常强于生成器)。注意:下面的代码为简单起见,每个批次中判别器和生成器各更新一次。

二、代码部分:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader,Dataset
from torchvision.utils import save_image
import numpy as np
import os
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt

class Sampling_data(Dataset):
    """Dataset over every regular file in one directory, yielding normalized image tensors.

    Each item is a float tensor of shape (3, H, W) with values in [-1, 1]:
    ToTensor maps pixels to [0, 1] and Normalize(mean=0.5, std=0.5) maps that
    to [-1, 1], the same range as a Tanh generator output. Images are NOT
    resized here, so H and W are those of the source files. The networks in
    this script are sized for 96x96 inputs, so the folder is presumably
    pre-resized. TODO confirm.
    """

    def __init__(self, img_path):
        """
        Args:
            img_path: directory containing the images. Only regular files
                directly inside it are used; subdirectories are skipped.
                Names are sorted so the index -> file mapping is stable
                across runs and platforms.
        """
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

        candidates = (os.path.join(img_path, name) for name in sorted(os.listdir(img_path)))
        self.file_names = [path for path in candidates if os.path.isfile(path)]

    def __len__(self):
        """Return the number of image files found."""
        return len(self.file_names)

    def __getitem__(self, item):
        """Load image number ``item`` and return it as a (3, H, W) tensor in [-1, 1]."""
        file = self.file_names[item]
        # convert("RGB") forces exactly 3 channels. Grayscale, RGBA or palette
        # files would otherwise break the 3-channel Normalize and the 3-channel
        # discriminator input. It also loads the pixel data, so the file handle
        # can be released when the with-block exits.
        with Image.open(file) as img:
            rgb = img.convert("RGB")
        return self.transform(rgb)

class Dnet(nn.Module):
    """DCGAN discriminator: image batch -> per-image real/fake score in (0, 1).

    For a 96x96 input the spatial size shrinks 96 -> 32 -> 16 -> 8 -> 4 -> 1,
    so the output has shape (N, 1, 1, 1). The five stages stay registered as
    ``conv1`` .. ``conv5`` (each an nn.Sequential), which keeps state_dict
    keys such as ``conv2.1.running_mean`` stable for saved checkpoints.
    """

    @staticmethod
    def _stage(in_ch, out_ch, kernel, stride, padding, *, batch_norm=True, last=False):
        """Build Conv2d (no bias) [+ BatchNorm2d] + LeakyReLU(0.2), or + Sigmoid when ``last``."""
        layers = [
            nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=kernel,
                      stride=stride, padding=padding, bias=False)
        ]
        if batch_norm:
            layers.append(nn.BatchNorm2d(out_ch))
        layers.append(nn.Sigmoid() if last else nn.LeakyReLU(0.2, inplace=True))
        return nn.Sequential(*layers)

    def __init__(self):
        super().__init__()
        # The first stage sees the raw image and intentionally has no BatchNorm.
        self.conv1 = self._stage(3, 64, 5, 3, 1, batch_norm=False)
        self.conv2 = self._stage(64, 128, 4, 2, 1)
        self.conv3 = self._stage(128, 256, 4, 2, 1)
        self.conv4 = self._stage(256, 512, 4, 2, 1)
        # Final 4x4 valid convolution collapses the map to one Sigmoid score.
        self.conv5 = self._stage(512, 1, 4, 1, 0, batch_norm=False, last=True)

    def forward(self, x):
        """Run the five stages in order and return the sigmoid scores."""
        out = x
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
            out = stage(out)
        return out

class Gnet(nn.Module):
    """DCGAN generator: latent tensor (N, 128, 1, 1) -> image batch (N, 3, 96, 96).

    Transposed convolutions upsample 1 -> 4 -> 8 -> 16 -> 32 -> 96, and the
    final Tanh keeps pixel values in [-1, 1]. The five stages stay registered
    as ``conv1`` .. ``conv5`` (each an nn.Sequential) so checkpoint keys are
    unchanged.
    """

    @staticmethod
    def _stage(in_ch, out_ch, kernel, stride, padding, *, last=False):
        """Build ConvTranspose2d (no bias) + BatchNorm2d + ReLU, or + Tanh when ``last``."""
        deconv = nn.ConvTranspose2d(in_channels=in_ch, out_channels=out_ch, kernel_size=kernel,
                                    stride=stride, padding=padding, bias=False)
        if last:
            return nn.Sequential(deconv, nn.Tanh())
        return nn.Sequential(deconv, nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True))

    def __init__(self):
        super().__init__()
        # 1x1 latent -> 4x4 feature map (valid 4x4 transposed conv).
        self.conv1 = self._stage(128, 512, 4, 1, 0)
        self.conv2 = self._stage(512, 256, 4, 2, 1)
        self.conv3 = self._stage(256, 128, 4, 2, 1)
        self.conv4 = self._stage(128, 64, 4, 2, 1)
        # 32x32 -> 96x96 RGB output: (32 - 1) * 3 - 2 * 1 + 5 = 96.
        self.conv5 = self._stage(64, 3, 5, 3, 1, last=True)

    def forward(self, x):
        """Run the five upsampling stages in order and return the generated images."""
        out = x
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
            out = stage(out)
        return out



if __name__ == '__main__':
    # Training script: trains the DCGAN on a folder of cartoon faces, and after
    # every epoch saves a grid of generated/real samples plus both checkpoints.
    save_params_path = r"params"
    save_img_path = r"./img"
    batchsize = 100
    img_data = r"E:\Learnn\cartoonfaces"            # directory holding the cartoon face images
    num_epoch = 500
    random_num = 128  # latent size; must equal Gnet's first ConvTranspose2d in_channels

    save_real_img_path = os.path.join(save_img_path, "real_img")
    save_fake_img_path = os.path.join(save_img_path, "fake_img")

    save_dparam_path = os.path.join(save_params_path, "d_self_net.pth")
    save_gparam_path = os.path.join(save_params_path, "g_self_net.pth")

    # makedirs creates missing parents and tolerates directories that already exist.
    for path in [save_img_path, save_params_path, save_real_img_path, save_fake_img_path]:
        os.makedirs(path, exist_ok=True)

    data_loader = DataLoader(Sampling_data(img_data), batch_size=batchsize, shuffle=True, num_workers=4, drop_last=True)
    # With drop_last=True a dataset smaller than one batch yields no batches at all,
    # and the per-epoch sample saving below would then hit undefined names.
    if len(data_loader) == 0:
        raise RuntimeError(
            "Need at least {} images in {!r} to form one batch.".format(batchsize, img_data))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    g_net = Gnet().to(device)
    d_net = Dnet().to(device)
    g_net.train()
    d_net.train()

    # Resume only when BOTH checkpoints exist. Note that `exists(a and b)` would
    # test only `b`, because `and` returns its second operand.
    # map_location lets checkpoints saved on a GPU load on a CPU-only machine.
    if os.path.exists(save_dparam_path) and os.path.exists(save_gparam_path):
        d_net.load_state_dict(torch.load(save_dparam_path, map_location=device))
        g_net.load_state_dict(torch.load(save_gparam_path, map_location=device))
        print("两个参数已经加载成功!!!")
    else:
        print("NO Params!!!")

    loss_fn = nn.BCELoss()
    d_optimizer = torch.optim.Adam(d_net.parameters(), lr=0.0002, betas=(0.5, 0.999))
    g_optimizer = torch.optim.Adam(g_net.parameters(), lr=0.0002, betas=(0.5, 0.999))

    # Every batch has exactly `batchsize` items (drop_last=True), and Dnet outputs
    # (N, 1, 1, 1), so the targets can be built once.
    real_label = torch.ones(batchsize, 1, 1, 1, device=device)
    fake_label = torch.zeros(batchsize, 1, 1, 1, device=device)

    for epoch in range(num_epoch):
        for i, img in enumerate(data_loader):
            real_img = img.to(device)

            # ---- Discriminator step: push D(real) -> 1 and D(fake) -> 0 ----
            real_out = d_net(real_img)
            d_loss_real = loss_fn(real_out, real_label)

            rand_n = torch.randn(batchsize, random_num, 1, 1, device=device)
            # detach(): the discriminator loss must not backpropagate into G.
            fake_img = g_net(rand_n).detach()
            fake_out = d_net(fake_img)
            d_loss_fake = loss_fn(fake_out, fake_label)

            d_loss = d_loss_real + d_loss_fake
            d_optimizer.zero_grad()
            d_loss.backward()
            d_optimizer.step()

            # ---- Generator step: make D label fresh fakes as real ----
            rand_n1 = torch.randn(batchsize, random_num, 1, 1, device=device)
            fake_img = g_net(rand_n1)
            output = d_net(fake_img)
            g_loss = loss_fn(output, real_label)

            # zero_grad clears stale generator grads before this backward; the D
            # grads this backward also produces are cleared at the next D step.
            g_optimizer.zero_grad()
            g_loss.backward()
            g_optimizer.step()

            if i % 10 == 0:
                print(real_out.data.mean(), fake_out.data.mean())

        # Map the [-1, 1] range (Tanh output / Normalize(0.5, 0.5) input) back to
        # [0, 1]. The tensors are already in display range, so save_image must not
        # re-normalize them: normalize=True with scale_each=True would stretch each
        # image's own min..max to 0..1 and distort colours and contrast.
        fake_imgs = (0.5 * (fake_img.detach().cpu() + 1)).clamp(0, 1)
        real_imgs = (0.5 * (real_img.cpu() + 1)).clamp(0, 1)

        save_image(fake_imgs, os.path.join(save_fake_img_path, "{}_fake_imgs.jpg".format(epoch + 1)), nrow=10)
        save_image(real_imgs, os.path.join(save_real_img_path, "{}_real_imgs.jpg".format(epoch + 1)), nrow=10)

        torch.save(g_net.state_dict(), save_gparam_path)
        torch.save(d_net.state_dict(), save_dparam_path)

三、效果展示:

1)生成的图像:

图像生成——使用DCGAN生成卡通肖像_第2张图片

2)原始图像:

图像生成——使用DCGAN生成卡通肖像_第3张图片

 

你可能感兴趣的:(深度学习,神经网络,pytorch,卷积神经网络)