DeepLearning—CV系列(二十二)——DCGAN生成动漫卡通人脸的Pytorch实现

文章目录

  • 一、cfg.py
  • 二、D_Net.py
  • 三、G_Net.py
  • 四、Mydataset.py
  • 五、Train.py
  • 六、效果展示

完整代码获取:
链接:https://pan.baidu.com/s/1NfqTyxYw6FjNsrYCTj8n4w
提取码:3y2o

代码目录:
DeepLearing—CV系列(二十二)——DCGAN生成动漫卡通人脸的Pytorch实现_第1张图片

一、cfg.py(注意:后文代码均通过 `import opt` 读取这些参数,因此该文件需保存为 opt.py,或将导入语句改为 `import cfg as opt`)

# opt parameters: hyper-parameters read by the networks and the training script.

# --- model dimensions ---
ngf = 96        # base channel width read by the generator
ndf = 96        # base channel width read by the discriminator
nz = 100        # channel count of the latent input fed to the generator
img_size = 96

# --- data loading ---
data_path = r"F:\GAN"
batch_size = 100
num_workers = 4

# --- optimisation ---
lr1 = 2e-4
lr2 = 2e-4      # second learning-rate slot; confirm which optimiser should read it
beta1 = 0.5     # first Adam beta
epochs = 200

# --- scheduling knobs ---
d_every = 1
g_every = 5
save_every = 20

二、D_Net.py

import opt
import torch.nn as nn


class NetD(nn.Module):
    """Discriminator: strided convolutions that reduce an image to one score.

    Channel widths are multiples of ``opt.ndf``. For a 3x96x96 input the
    spatial size shrinks 96 -> 32 -> 16 -> 8 -> 4 -> 1, so the output has
    shape (N, 1, 1, 1) with values in (0, 1) from the final sigmoid.
    """

    def __init__(self):
        super().__init__()
        width = opt.ndf
        slope = 0.2
        # The layers are built in the same order as they run, so the indices
        # inside the Sequential (and therefore the state_dict keys) are fixed.
        stages = [
            nn.Conv2d(in_channels=3, out_channels=width,
                      kernel_size=5, stride=3, padding=1, bias=False),
            nn.LeakyReLU(negative_slope=slope, inplace=True),

            nn.Conv2d(in_channels=width, out_channels=width * 2,
                      kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(width * 2),
            nn.LeakyReLU(negative_slope=slope, inplace=True),

            nn.Conv2d(in_channels=width * 2, out_channels=width * 4,
                      kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(width * 4),
            nn.LeakyReLU(negative_slope=slope, inplace=True),

            nn.Conv2d(in_channels=width * 4, out_channels=width * 8,
                      kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(negative_slope=slope, inplace=True),

            # Final 4x4 convolution collapses the remaining 4x4 map to 1x1.
            nn.Conv2d(in_channels=width * 8, out_channels=1,
                      kernel_size=4, stride=1, padding=0, bias=True),
            nn.Sigmoid(),
        ]
        self.conv_layer = nn.Sequential(*stages)

    def forward(self, x):
        """Return the per-image score produced by the convolution stack."""
        scores = self.conv_layer(x)
        return scores

三、G_Net.py

import opt
import torch.nn as nn


class NetG(nn.Module):
    """Generator: transposed convolutions that grow a latent vector to an image.

    The input is expected as (N, opt.nz, 1, 1). The spatial size grows
    1 -> 4 -> 8 -> 16 -> 32 -> 96 and the final Tanh bounds the three output
    channels to (-1, 1).
    """

    def __init__(self):
        super().__init__()
        width = opt.ngf
        # The layers are built in the same order as they run, so the indices
        # inside the Sequential (and therefore the state_dict keys) are fixed.
        stages = [
            nn.ConvTranspose2d(in_channels=opt.nz, out_channels=width * 8,
                               kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(width * 8),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(in_channels=width * 8, out_channels=width * 4,
                               kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(width * 4),
            nn.ReLU(inplace=True),

            # This is the only hidden layer with a bias term; it is kept so
            # that saved checkpoints continue to load.
            nn.ConvTranspose2d(in_channels=width * 4, out_channels=width * 2,
                               kernel_size=4, stride=2, padding=1, bias=True),
            nn.BatchNorm2d(width * 2),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(in_channels=width * 2, out_channels=width,
                               kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(width),
            nn.ReLU(inplace=True),

            # Kernel 5 / stride 3 / padding 1 maps 32x32 to 96x96.
            nn.ConvTranspose2d(in_channels=width, out_channels=3,
                               kernel_size=5, stride=3, padding=1, bias=False),
            nn.Tanh(),
        ]
        self.conv_layer = nn.Sequential(*stages)

    def forward(self, x):
        """Return the image batch generated from the latent batch ``x``."""
        generated = self.conv_layer(x)
        return generated

四、Mydataset.py

from torch.utils.data import Dataset,DataLoader
import os
from PIL import Image
import numpy as np
import torch

class MyDataset(Dataset):
    """Dataset over the image files stored directly inside one directory.

    Each item is a ``float32`` tensor laid out as (channels, height, width):
    pixels are scaled to [0, 1] and then standardised per channel with the
    class-level ``mean`` and ``std``.
    """

    # Per-channel statistics applied in __getitem__ (RGB order).
    mean = [0.6712, 0.5770, 0.5549]
    std = [0.2835, 0.2785, 0.2641]

    def __init__(self, path):
        """Remember the directory and take a listing of its entries.

        Args:
            path: directory whose entries are treated as image files.
        """
        self.path = path
        # os.listdir returns entries in arbitrary order; sort so that a given
        # index always refers to the same file.
        self.dataset = sorted(os.listdir(path))

    def __len__(self):
        """Return the number of directory entries."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Load, scale and standardise the image at position ``index``.

        Returns:
            A float32 tensor of shape (3, H, W).
        """
        name = self.dataset[index]
        # Force three channels: grayscale or RGBA files would otherwise break
        # the per-channel normalisation and the permute below. The context
        # manager closes the underlying file once the pixels are copied.
        with Image.open(os.path.join(self.path, name)) as handle:
            img = handle.convert("RGB")
        img = np.array(img) / 255.
        img = (img - MyDataset.mean) / MyDataset.std
        # HWC -> CHW, the layout torch convolution layers expect.
        img = torch.tensor(img, dtype=torch.float32).permute(2, 0, 1)
        return img

if __name__ == '__main__':
    # Recompute the per-channel mean/std of the face images.
    data = MyDataset(r"F:\GAN\faces")
    if len(data) == 0:
        raise SystemExit("no images found")

    # Pull the whole dataset as a single batch so the statistics cover every
    # image; use the real dataset length instead of a hard-coded count.
    loader = DataLoader(dataset=data, batch_size=len(data), shuffle=False)
    batch = next(iter(loader))

    # MyDataset already standardises each item with its class constants.
    # Undo that here so the printed numbers describe the raw [0, 1] pixels.
    applied_mean = torch.tensor(MyDataset.mean).view(1, 3, 1, 1)
    applied_std = torch.tensor(MyDataset.std).view(1, 3, 1, 1)
    raw = batch * applied_std + applied_mean

    # Reduce over batch, height and width, leaving one value per channel.
    mean = torch.mean(raw, dim=(0, 2, 3))
    std = torch.std(raw, dim=(0, 2, 3))
    print(mean, std)

五、Train.py

import torch
from torchvision import transforms
import torchvision
import opt
import torch.utils.data as data
from D_Net import NetD
from G_Net import NetG
import torch.nn as nn
import os
from torchvision.utils import save_image

class Trainer:
    """Adversarial training loop for the NetD / NetG pair.

    Alternates one discriminator update and one generator update per batch,
    and every 100 iterations logs the losses, writes sample image grids to
    ``./dcgan_img`` and saves both networks' weights to ``dcgan_params``.
    """

    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # ToTensor yields [0, 1]; Normalize(0.5, 0.5) maps that to [-1, 1],
        # the same range as the generator's Tanh output.
        self.trans = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ])
        self.loss_fn = nn.BCELoss()

    def _load_params(self, net, path, label):
        """Load weights for ``net`` from ``path`` when that file exists.

        The check is on the checkpoint file itself (not just its directory),
        and ``map_location`` lets a checkpoint saved on GPU load on a CPU host.
        """
        if os.path.isfile(path):
            net.load_state_dict(torch.load(path, map_location=self.device))
        else:
            print("NO {} Param".format(label))

    def train(self):
        """Run ``opt.epochs`` epochs of training over ``opt.data_path``."""
        dataset = torchvision.datasets.ImageFolder(opt.data_path, transform=self.trans)
        dataloader = data.DataLoader(dataset=dataset, batch_size=opt.batch_size, shuffle=True)
        d_net = NetD().to(self.device)
        g_net = NetG().to(self.device)
        self._load_params(d_net, "dcgan_params/d_net.pth", "d_net")
        self._load_params(g_net, "dcgan_params/g_net.pth", "g_net")

        # Create the output directories once, instead of re-checking in the loop.
        os.makedirs("./dcgan_img", exist_ok=True)
        os.makedirs("./dcgan_params", exist_ok=True)

        # NOTE(review): both optimisers read opt.lr1 and opt.lr2 is never used;
        # confirm whether one of them was meant to use lr2.
        D_optimizer = torch.optim.Adam(d_net.parameters(), lr=opt.lr1, betas=(opt.beta1, 0.999))
        G_optimizer = torch.optim.Adam(g_net.parameters(), lr=opt.lr1, betas=(opt.beta1, 0.999))
        NUM_EPOHS = opt.epochs
        for epoh in range(NUM_EPOHS):
            for i, (images, _) in enumerate(dataloader):
                N = images.size(0)
                images = images.to(self.device)
                real_labels = torch.ones(N, 1, 1, 1, device=self.device)
                fake_labels = torch.zeros(N, 1, 1, 1, device=self.device)

                # ---- discriminator step: real -> 1, generated -> 0 ----
                real_out = d_net(images)
                d_real_loss = self.loss_fn(real_out, real_labels)

                # Latent size comes from the config so it always matches NetG.
                z = torch.randn(N, opt.nz, 1, 1, device=self.device)
                # The generator is not updated in this step, so skip building
                # its autograd graph.
                with torch.no_grad():
                    fake_img = g_net(z)
                fake_out = d_net(fake_img)
                d_fake_loss = self.loss_fn(fake_out, fake_labels)
                d_loss = d_fake_loss + d_real_loss

                D_optimizer.zero_grad()
                d_loss.backward()
                D_optimizer.step()

                # ---- generator step: make the discriminator answer 1 ----
                z = torch.randn(N, opt.nz, 1, 1, device=self.device)
                fake_img = g_net(z)
                fake_out = d_net(fake_img)
                g_loss = self.loss_fn(fake_out, real_labels)
                G_optimizer.zero_grad()
                g_loss.backward()
                G_optimizer.step()

                if i % 100 == 0:
                    # d_fake reports the scores from the generator step above.
                    print("Epoch:{}/{},d_loss:{:.3f},g_loss:{:.3f},"
                          "d_real:{:.3f},d_fake:{:.3f}".
                          format(epoh, NUM_EPOHS, d_loss.item(), g_loss.item(),
                                 real_out.mean().item(), fake_out.mean().item()))
                    real_image = images.detach().cpu()
                    save_image(real_image, "./dcgan_img/epoch{}-iteration{}-real_img.jpg".
                               format(epoh, i), nrow=10, normalize=True, scale_each=True)
                    fake_image = fake_img.detach().cpu()
                    save_image(fake_image, "./dcgan_img/epoch{}-iteration{}-fake_img.jpg".
                               format(epoh, i), nrow=10, normalize=True, scale_each=True)
                    torch.save(d_net.state_dict(), "dcgan_params/d_net.pth")
                    torch.save(g_net.state_dict(), "dcgan_params/g_net.pth")
if __name__ == '__main__':
    # Script entry point: build a trainer and run the whole training loop.
    Trainer().train()

六、效果展示

50轮左右的效果:
DeepLearing—CV系列(二十二)——DCGAN生成动漫卡通人脸的Pytorch实现_第2张图片
DeepLearing—CV系列(二十二)——DCGAN生成动漫卡通人脸的Pytorch实现_第3张图片

你可能感兴趣的:(深度学习,AI,GAN,cv,pytorch,神经网络,深度学习,对抗式生成神经网络)