[31] Deep Learning with PyTorch: Generative Adversarial Networks (GAN)

0. Previous Posts

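[1] Deep Learning with PyTorch: Tensor Definition and Creation

[2] Deep Learning with PyTorch: Tensor Operations (Concatenation, Splitting, Indexing and Transformation)

[3] Deep Learning with PyTorch: Tensor Math Operations

[4] Deep Learning with PyTorch: Linear Regression

[5] Deep Learning with PyTorch: Computational Graphs and the Dynamic Graph Mechanism

[6] Deep Learning with PyTorch: autograd and Logistic Regression

[7] Deep Learning with PyTorch: DataLoader and Dataset (with an RMB Banknote Binary-Classification Project)

[8] Deep Learning with PyTorch: Image Preprocessing with transforms

[9] Deep Learning with PyTorch: Image Augmentation with transforms (Cropping, Flipping, Rotation)

[10] Deep Learning with PyTorch: transforms Image Operations and Custom Methods

[11] Deep Learning with PyTorch: Model Creation and nn.Module

[12] Deep Learning with PyTorch: Model Containers and Building AlexNet

[13] Deep Learning with PyTorch: Convolutional Layers (1D/2D/3D Convolution, nn.Conv2d, Transposed Convolution nn.ConvTranspose)

[14] Deep Learning with PyTorch: Pooling, Linear and Activation Layers

[15] Deep Learning with PyTorch: Weight Initialization

[16] Deep Learning with PyTorch: 18 Loss Functions

[17] Deep Learning with PyTorch: Optimizers

[18] Deep Learning with PyTorch: Learning Rate Scheduling Strategies

[19] Deep Learning with PyTorch: The TensorBoard Visualization Tool

[20] Deep Learning with PyTorch: Hook Functions and the CAM Algorithm

[21] Deep Learning with PyTorch: Regularization with Weight Decay

[22] Deep Learning with PyTorch: Regularization with Dropout

[23] Deep Learning with PyTorch: Batch Normalization

[24] Deep Learning with PyTorch: BN, LN (Layer Normalization), IN (Instance Normalization), GN (Group Normalization)

[25] Deep Learning with PyTorch: Saving and Loading Models

[26] Deep Learning with PyTorch: Model Finetuning

[27] Deep Learning with PyTorch: Using the GPU

[28] Deep Learning with PyTorch: Image Classification with ResNet18

[29] Deep Learning with PyTorch: Image Segmentation with U-Net

[30] Deep Learning with PyTorch: Object Detection with Faster RCNN

[31] Deep Learning with PyTorch: Generative Adversarial Networks (GAN)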

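Deep Learning with PyTorch: Generative Adversarial Networks (GAN)

  • 0. Previous Posts
  • 1. Definition of Generative Adversarial Networks (GAN)
  • 2. How to Train a GAN?
  • 3. Training a DCGAN for Face Generation
  • 4. Full Code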

1. Definition of Generative Adversarial Networks (GAN)

[Figure 1]
[Figure 2]

2. How to Train a GAN?

[Figure 3]
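Training the generator does not mean making each generated sample numerically close to a particular real sample; it means making the distribution of generated samples approach the distribution of the real data.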

[Figure 4]
[Figure 5]
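Matching distributions rather than individual values is what the adversarial BCE loss in the training script encodes. Under the standard GAN value function, min_G max_D E_x[log D(x)] + E_z[log(1 - D(G(z)))], labelling real images 1 and generated images 0 makes nn.BCELoss equal to the negative of the discriminator's objective, and labelling generated images 1 when updating G (as gan_demo.py does with label_for_train_g = real_label) gives the commonly used non-saturating generator loss -log D(G(z)). The minimal sketch below checks this numerically; the discriminator outputs are made-up probabilities, not results from the model.

```python
import torch
import torch.nn as nn

# Made-up discriminator outputs (probabilities), for illustration only
d_real = torch.tensor([0.9, 0.8])   # D(x) on two real images
d_fake = torch.tensor([0.2, 0.3])   # D(G(z)) on two generated images

criterion = nn.BCELoss()            # default reduction="mean"

# Discriminator side: real images labelled 1, generated images labelled 0
loss_d_real = criterion(d_real, torch.ones_like(d_real))
loss_d_fake = criterion(d_fake, torch.zeros_like(d_fake))
manual_d_real = -torch.log(d_real).mean()        # -E[log D(x)]
manual_d_fake = -torch.log(1 - d_fake).mean()    # -E[log(1 - D(G(z)))]

# Generator side (non-saturating form): generated images labelled 1
loss_g = criterion(d_fake, torch.ones_like(d_fake))
manual_g = -torch.log(d_fake).mean()             # -E[log D(G(z))]

print(loss_d_real.item(), manual_d_real.item())
print(loss_d_fake.item(), manual_d_fake.item())
print(loss_g.item(), manual_g.item())
```

Each pair printed should agree. Using -log D(G(z)) for G instead of minimizing log(1 - D(G(z))) is usually motivated by its larger gradients early in training, when D rejects generated images easily.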

3. Training a DCGAN for Face Generation

[Figures 6-18]

4. Full Code

gan_demo.py

# -*- coding: utf-8 -*-
"""
# @file name  : gan_demo.py
# @brief      : gan demo
"""
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import imageio
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
from tools.common_tools import set_seed
from torch.utils.data import DataLoader
from tools.my_dataset import CelebADataset
from tools.dcgan import Discriminator, Generator
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

set_seed(1)  # set the random seed

# config

data_dir = os.path.join(BASE_DIR, "..", "..", "data", "img_align_celeba_2k")  # folder of CelebA .jpg images; adjust to your setup
out_dir = os.path.join(BASE_DIR, "..", "..", "log_gan")
if not os.path.exists(out_dir):
    os.makedirs(out_dir)

ngpu = 0    # Number of GPUs available. Use 0 for CPU mode.
IS_PARALLEL = True if ngpu > 1 else False
checkpoint_interval = 10

image_size = 64
nc = 3
nz = 100
ngf = 128  # 64
ndf = 128   # 64
num_epochs = 20
fixed_noise = torch.randn(64, nz, 1, 1, device=device)

real_idx = 1    # 0.9: training with smoothed labels may work better
fake_idx = 0    # 0.1

lr = 0.0002
batch_size = 64
beta1 = 0.5

d_transforms = transforms.Compose([transforms.Resize(image_size),
                   transforms.CenterCrop(image_size),
                   transforms.ToTensor(),
                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # scale pixel values to [-1, 1]
               ])
if __name__ == '__main__':
    # step 1: data

    train_set = CelebADataset(data_dir=data_dir, transforms=d_transforms)
    train_loader = DataLoader(train_set, batch_size=batch_size, num_workers=2, shuffle=True)

    # show train img
    flag = 0
    # flag = 1
    if flag:
        img_bchw = next(iter(train_loader))
        plt.title("Training Images")
        plt.imshow(np.transpose(vutils.make_grid(img_bchw.to(device)[:64], padding=2, normalize=True).cpu(), (1, 2, 0)))
        plt.show()
        plt.close()

    # step 2: model
    net_g = Generator(nz=nz, ngf=ngf, nc=nc)  # nz: length of the input noise vector; ngf: base channel count of the generator's feature maps; nc: channels of the output image
    net_g.initialize_weights()

    net_d = Discriminator(nc=nc, ndf=ndf)
    net_d.initialize_weights()

    net_g.to(device)
    net_d.to(device)

    if IS_PARALLEL and torch.cuda.device_count() > 1:
        net_g = nn.DataParallel(net_g)
        net_d = nn.DataParallel(net_d)

    # step 3: loss
    criterion = nn.BCELoss()

    # step 4: optimizer
    # Setup Adam optimizers for both G and D
    optimizerD = optim.Adam(net_d.parameters(), lr=lr, betas=(beta1, 0.999))
    optimizerG = optim.Adam(net_g.parameters(), lr=lr, betas=(beta1, 0.999))

    lr_scheduler_d = torch.optim.lr_scheduler.StepLR(optimizerD, step_size=8, gamma=0.1)
    lr_scheduler_g = torch.optim.lr_scheduler.StepLR(optimizerG, step_size=8, gamma=0.1)

    # step 5: iteration
    img_list = []
    G_losses = []
    D_losses = []
    iters = 0

    for epoch in range(num_epochs):
        for i, data in enumerate(train_loader):

            ############################
            # (1) Update D network
            ###########################

            net_d.zero_grad()

            # create training data
            real_img = data.to(device)
            b_size = real_img.size(0)
            real_label = torch.full((b_size,), real_idx, dtype=torch.float, device=device)  # BCELoss needs float targets

            noise = torch.randn(b_size, nz, 1, 1, device=device)
            fake_img = net_g(noise)
            fake_label = torch.full((b_size,), fake_idx, dtype=torch.float, device=device)

            # train D with real img
            out_d_real = net_d(real_img)
            loss_d_real = criterion(out_d_real.view(-1), real_label)

            # train D with fake img
            out_d_fake = net_d(fake_img.detach())
            loss_d_fake = criterion(out_d_fake.view(-1), fake_label)

            # backward
            loss_d_real.backward()
            loss_d_fake.backward()
            loss_d = loss_d_real + loss_d_fake

            # Update D
            optimizerD.step()

            # record probability
            d_x = out_d_real.mean().item()      # D(x)
            d_g_z1 = out_d_fake.mean().item()   # D(G(z1))

            ############################
            # (2) Update G network
            ###########################
            net_g.zero_grad()

            label_for_train_g = real_label  # 1
            out_d_fake_2 = net_d(fake_img)

            loss_g = criterion(out_d_fake_2.view(-1), label_for_train_g)
            loss_g.backward()
            optimizerG.step() 

            # record probability
            d_g_z2 = out_d_fake_2.mean().item()  # D(G(z2))

            # Output training stats
            if i % 10 == 0:
                print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                      % (epoch, num_epochs, i, len(train_loader),
                         loss_d.item(), loss_g.item(), d_x, d_g_z1, d_g_z2))

            # Save Losses for plotting later
            G_losses.append(loss_g.item())
            D_losses.append(loss_d.item())

        lr_scheduler_d.step()
        lr_scheduler_g.step()

        # Check how the generator is doing by saving G's output on fixed_noise
        with torch.no_grad():
            fake = net_g(fixed_noise).detach().cpu()
        img_grid = vutils.make_grid(fake, padding=2, normalize=True).numpy()
        img_grid = np.transpose(img_grid, (1, 2, 0))
        plt.imshow(img_grid)
        plt.title("Epoch:{}".format(epoch))
        # plt.show()
        plt.savefig(os.path.join(out_dir, "{}_epoch.png".format(epoch)))

        # checkpoint
        if (epoch+1) % checkpoint_interval == 0:

            checkpoint = {"g_model_state_dict": net_g.state_dict(),
                          "d_model_state_dict": net_d.state_dict(),
                          "epoch": epoch}
            path_checkpoint = os.path.join(out_dir, "checkpoint_{}_epoch.pkl".format(epoch))
            torch.save(checkpoint, path_checkpoint)

    # plot loss
    plt.figure(figsize=(10, 5))
    plt.title("Generator and Discriminator Loss During Training")
    plt.plot(G_losses, label="G")
    plt.plot(D_losses, label="D")
    plt.xlabel("iterations")
    plt.ylabel("Loss")
    plt.legend()
    # plt.show()
    plt.savefig(os.path.join(out_dir, "loss.png"))

    # save gif
    imgs_epoch = [int(name.split("_")[0]) for name in list(filter(lambda x: x.endswith("epoch.png"), os.listdir(out_dir)))]
    imgs_epoch = sorted(imgs_epoch)

    imgs = list()
    for i in range(len(imgs_epoch)):
        img_name = os.path.join(out_dir, "{}_epoch.png".format(imgs_epoch[i]))
        imgs.append(imageio.imread(img_name))

    imageio.mimsave(os.path.join(out_dir, "generation_animation.gif"), imgs, fps=2)

    print("done")
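In the D step the script calls net_d(fake_img.detach()), and in the G step it calls net_d(fake_img) with the "real" labels. The minimal sketch below uses toy fully connected networks (hypothetical stand-ins for Generator and Discriminator) to show what this does to the gradients: the detached pass leaves the generator's gradients empty, while the second pass reaches them. It also builds the labels from float fill values; with an integer fill value, recent PyTorch versions make torch.full return an int64 tensor, which nn.BCELoss does not accept as a target.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
nz, b_size = 8, 4
# Toy stand-ins for Generator and Discriminator (hypothetical, fully connected)
net_g = nn.Sequential(nn.Linear(nz, 16), nn.Tanh())
net_d = nn.Sequential(nn.Linear(16, 1), nn.Sigmoid())
criterion = nn.BCELoss()

# Float fill values -> float32 labels, which BCELoss accepts as targets
real_label = torch.full((b_size,), 1.0)
fake_label = torch.full((b_size,), 0.0)

noise = torch.randn(b_size, nz)
fake_img = net_g(noise)

# (1) D step: detach() cuts the graph, so this backward pass never reaches net_g
loss_d_fake = criterion(net_d(fake_img.detach()).view(-1), fake_label)
loss_d_fake.backward()
g_grad_after_d_step = net_g[0].weight.grad          # still None
d_has_grad = net_d[0].weight.grad is not None        # True

# (2) G step: no detach, and generated images are labelled "real" (1)
net_d.zero_grad()
loss_g = criterion(net_d(fake_img).view(-1), real_label)
loss_g.backward()
g_has_grad = net_g[0].weight.grad is not None        # True

print(g_grad_after_d_step, d_has_grad, g_has_grad)   # None True True
```

Because the D step only backpropagates through the detached tensor, the generator's part of the graph is still intact for the G step, which is why the script can reuse the same fake_img.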

gan_inference.py

# -*- coding: utf-8 -*-
"""
# @file name  : gan_inference.py
# @brief      : gan inference
"""
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import imageio
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
from tools.common_tools import set_seed
from torch.utils.data import DataLoader
from tools.my_dataset import CelebADataset
from tools.dcgan import Discriminator, Generator
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def remove_module(state_dict_g):
    # remove module.
    from collections import OrderedDict

    new_state_dict = OrderedDict()
    for k, v in state_dict_g.items():
        namekey = k[7:] if k.startswith('module.') else k
        new_state_dict[namekey] = v

    return new_state_dict

set_seed(1)  # set the random seed

# config
path_checkpoint = os.path.join(BASE_DIR, "checkpoint_14_epoch.pkl")  # gan_demo.py saves checkpoint_{epoch}_epoch.pkl into log_gan every checkpoint_interval epochs
image_size = 64
num_img = 64
nc = 3
nz = 100
ngf = 128
ndf = 128

d_transforms = transforms.Compose([transforms.Resize(image_size),
                   transforms.CenterCrop(image_size),
                   transforms.ToTensor(),
                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
               ])

# step 1: data
fixed_noise = torch.randn(num_img, nz, 1, 1, device=device)

flag = 0
# flag = 1
if flag:
    # vary a single latent dimension (z_idx) in small steps, keeping all other dimensions fixed
    z_idx = 0
    single_noise = torch.randn(1, nz, 1, 1, device=device)
    for i in range(num_img):
        add_noise = single_noise.clone()
        add_noise[0, z_idx, 0, 0] += i * 0.01
        fixed_noise[i, ...] = add_noise[0]


# step 2: model
net_g = Generator(nz=nz, ngf=ngf, nc=nc)
# net_d = Discriminator(nc=nc, ndf=ndf)  # the discriminator is not needed at inference time
checkpoint = torch.load(path_checkpoint, map_location="cpu")

state_dict_g = checkpoint["g_model_state_dict"]
state_dict_g = remove_module(state_dict_g)
net_g.load_state_dict(state_dict_g)
net_g.to(device)
# net_d.load_state_dict(checkpoint["d_model_state_dict"])
# net_d.to(device)

# step3: inference
with torch.no_grad():
    fake_data = net_g(fixed_noise).detach().cpu()
img_grid = vutils.make_grid(fake_data, padding=2, normalize=True).numpy()
img_grid = np.transpose(img_grid, (1, 2, 0))
plt.imshow(img_grid)
plt.show()
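remove_module is needed because nn.DataParallel stores the wrapped network as an attribute named module, so every key in its state_dict gains a "module." prefix, and a plain Generator refuses such keys under the default strict loading. The minimal sketch below reproduces this with a toy nn.Sequential (a hypothetical stand-in for the Generator); on a CPU-only machine nn.DataParallel simply wraps the module.

```python
from collections import OrderedDict

import torch.nn as nn


def remove_module(state_dict):
    # Drop the "module." prefix that nn.DataParallel adds to every key
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        new_state_dict[k[len("module."):] if k.startswith("module.") else k] = v
    return new_state_dict


net = nn.Sequential(nn.Linear(4, 2))     # hypothetical stand-in for the Generator
wrapped = nn.DataParallel(net)           # how gan_demo.py wraps models when IS_PARALLEL is set
keys_wrapped = list(wrapped.state_dict().keys())
cleaned = remove_module(wrapped.state_dict())

fresh = nn.Sequential(nn.Linear(4, 2))
fresh.load_state_dict(cleaned)           # strict loading succeeds once the prefix is gone
print(keys_wrapped)                      # ['module.0.weight', 'module.0.bias']
print(list(cleaned.keys()))              # ['0.weight', '0.bias']
```

Saving net_g.module.state_dict() when the model is wrapped is another way to avoid the prefix in the first place.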

tools/dcgan.py

# -*- coding: utf-8 -*-
"""
# @file name  : dcgan.py
# @brief      : deep convolutional generative adversarial networks
"""


from collections import OrderedDict
import torch
import torch.nn as nn


class Generator(nn.Module):
    def __init__(self, nz=100, ngf=128, nc=3):  # nz: length of the input noise vector; ngf: base channel count of the generator's feature maps; nc: channels of the output image
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        return self.main(input)

    def initialize_weights(self, w_mean=0., w_std=0.02, b_mean=1, b_std=0.02):
        for m in self.modules():
            classname = m.__class__.__name__
            if classname.find('Conv') != -1:
                nn.init.normal_(m.weight.data, w_mean, w_std)
            elif classname.find('BatchNorm') != -1:
                nn.init.normal_(m.weight.data, b_mean, b_std)
                nn.init.constant_(m.bias.data, 0)


class Discriminator(nn.Module):
    def __init__(self, nc=3, ndf=128):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input)

    def initialize_weights(self, w_mean=0., w_std=0.02, b_mean=1, b_std=0.02):
        for m in self.modules():
            classname = m.__class__.__name__
            if classname.find('Conv') != -1:
                nn.init.normal_(m.weight.data, w_mean, w_std)
            elif classname.find('BatchNorm') != -1:
                nn.init.normal_(m.weight.data, b_mean, b_std)
                nn.init.constant_(m.bias.data, 0)
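The "# state size" comments in Generator and Discriminator follow from the output-size formulas for nn.ConvTranspose2d, H_out = (H_in - 1) * stride - 2 * padding + kernel_size, and for nn.Conv2d, H_out = floor((H_in + 2 * padding - kernel_size) / stride) + 1 (dilation 1, and output_padding 0 for the transposed case). The sketch walks a 1 x 1 noise map through the five (kernel, stride, padding) settings used above and cross-checks one layer pair against PyTorch; the small channel counts in the cross-check are arbitrary.

```python
import torch
import torch.nn as nn


def convtranspose_out(h, k, s, p):
    # Output size of nn.ConvTranspose2d (dilation=1, output_padding=0)
    return (h - 1) * s - 2 * p + k


def conv_out(h, k, s, p):
    # Output size of nn.Conv2d (dilation=1)
    return (h + 2 * p - k) // s + 1


# Generator: (kernel, stride, padding) of each ConvTranspose2d, starting from a 1 x 1 noise map
g_sizes = [1]
for k, s, p in [(4, 1, 0), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1)]:
    g_sizes.append(convtranspose_out(g_sizes[-1], k, s, p))

# Discriminator: the mirror image, starting from a 64 x 64 image
d_sizes = [64]
for k, s, p in [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 0)]:
    d_sizes.append(conv_out(d_sizes[-1], k, s, p))

print(g_sizes)  # [1, 4, 8, 16, 32, 64]
print(d_sizes)  # [64, 32, 16, 8, 4, 1]

# Cross-check one stride-2 layer of each kind against PyTorch itself
up = nn.ConvTranspose2d(8, 4, 4, 2, 1, bias=False)
down = nn.Conv2d(4, 8, 4, 2, 1, bias=False)
x = torch.randn(1, 8, 16, 16)
print(up(x).shape)        # torch.Size([1, 4, 32, 32])
print(down(up(x)).shape)  # torch.Size([1, 8, 16, 16])
```

With kernel 4, stride 2 and padding 1, each transposed convolution exactly doubles the spatial size and each convolution exactly halves it, which is why the two networks mirror each other.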

tools/my_dataset.py

# -*- coding: utf-8 -*-
"""
# @file name  : my_dataset.py
# @brief      : Dataset definitions for the datasets used in this series
"""
import numpy as np
import torch
import os
import random
from PIL import Image
from torch.utils.data import Dataset

random.seed(1)
rmb_label = {"1": 0, "100": 1}


class RMBDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        """
        Dataset for the RMB banknote denomination classification task
        :param data_dir: str, path to the dataset
        :param transform: torch.transform, data preprocessing
        """
        self.label_name = {"1": 0, "100": 1}
        self.data_info = self.get_img_info(data_dir)  # data_info stores every image path and label; DataLoader reads samples from it by index
        self.transform = transform

    def __getitem__(self, index):
        path_img, label = self.data_info[index]
        img = Image.open(path_img).convert('RGB')     # 0~255

        if self.transform is not None:
            img = self.transform(img)   # apply the transform here, e.g. conversion to a tensor

        return img, label

    def __len__(self):
        return len(self.data_info)

    @staticmethod
    def get_img_info(data_dir):
        data_info = list()
        for root, dirs, _ in os.walk(data_dir):
            # iterate over classes
            for sub_dir in dirs:
                img_names = os.listdir(os.path.join(root, sub_dir))
                img_names = list(filter(lambda x: x.endswith('.jpg'), img_names))

                # iterate over images
                for i in range(len(img_names)):
                    img_name = img_names[i]
                    path_img = os.path.join(root, sub_dir, img_name)
                    label = rmb_label[sub_dir]
                    data_info.append((path_img, int(label)))

        return data_info


class AntsDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        self.label_name = {"ants": 0, "bees": 1}
        self.data_info = self.get_img_info(data_dir)
        self.transform = transform

    def __getitem__(self, index):
        path_img, label = self.data_info[index]
        img = Image.open(path_img).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        return img, label

    def __len__(self):
        return len(self.data_info)

    def get_img_info(self, data_dir):
        data_info = list()
        for root, dirs, _ in os.walk(data_dir):
            # iterate over classes
            for sub_dir in dirs:
                img_names = os.listdir(os.path.join(root, sub_dir))
                img_names = list(filter(lambda x: x.endswith('.jpg'), img_names))

                # iterate over images
                for i in range(len(img_names)):
                    img_name = img_names[i]
                    path_img = os.path.join(root, sub_dir, img_name)
                    label = self.label_name[sub_dir]
                    data_info.append((path_img, int(label)))

        if len(data_info) == 0:
            raise Exception("\ndata_dir:{} is an empty dir! Please check your path to images!".format(data_dir))
        return data_info


class PortraitDataset(Dataset):
    def __init__(self, data_dir, transform=None, in_size = 224):
        super(PortraitDataset, self).__init__()
        self.data_dir = data_dir
        self.transform = transform
        self.label_path_list = list()
        self.in_size = in_size

        # collect the mask paths
        self._get_img_path()

    def __getitem__(self, index):

        path_label = self.label_path_list[index]
        path_img = path_label[:-10] + ".png"

        img_pil = Image.open(path_img).convert('RGB')
        img_pil = img_pil.resize((self.in_size, self.in_size), Image.BILINEAR)
        img_hwc = np.array(img_pil)
        img_chw = img_hwc.transpose((2, 0, 1))

        label_pil = Image.open(path_label).convert('L')
        label_pil = label_pil.resize((self.in_size, self.in_size), Image.NEAREST)
        label_hw = np.array(label_pil)
        label_chw = label_hw[np.newaxis, :, :]
        label_hw[label_hw != 0] = 1

        if self.transform is not None:
            # img_chw and label_chw are already NumPy arrays; the transform is expected to return arrays too
            img_chw_tensor = torch.from_numpy(self.transform(img_chw)).float()
            label_chw_tensor = torch.from_numpy(self.transform(label_chw)).float()
        else:
            img_chw_tensor = torch.from_numpy(img_chw).float()
            label_chw_tensor = torch.from_numpy(label_chw).float()

        return img_chw_tensor, label_chw_tensor

    def __len__(self):
        return len(self.label_path_list)

    def _get_img_path(self):
        file_list = os.listdir(self.data_dir)
        file_list = list(filter(lambda x: x.endswith("_matte.png"), file_list))
        path_list = [os.path.join(self.data_dir, name) for name in file_list]
        random.shuffle(path_list)
        if len(path_list) == 0:
            raise Exception("\ndata_dir:{} is an empty dir! Please check your path to images!".format(self.data_dir))
        self.label_path_list = path_list


class PennFudanDataset(object):
    def __init__(self, data_dir, transforms):

        self.data_dir = data_dir
        self.transforms = transforms
        self.img_dir = os.path.join(data_dir, "PNGImages")
        self.txt_dir = os.path.join(data_dir, "Annotation")
        self.names = [name[:-4] for name in list(filter(lambda x: x.endswith(".png"), os.listdir(self.img_dir)))]

    def __getitem__(self, index):
        """
        Return img and target
        :param index:
        :return:
        """

        name = self.names[index]
        path_img = os.path.join(self.img_dir, name + ".png")
        path_txt = os.path.join(self.txt_dir, name + ".txt")

        # load img
        img = Image.open(path_img).convert("RGB")

        # load boxes and label
        import re
        with open(path_txt, "r") as f:
            points = [re.findall(r"\d+", line) for line in f.readlines() if "Xmin" in line]
        boxes_list = list()
        for point in points:
            box = [int(p) for p in point]
            boxes_list.append(box[-4:])
        boxes = torch.tensor(boxes_list, dtype=torch.float)
        labels = torch.ones((boxes.shape[0],), dtype=torch.long)

        # iscrowd = torch.zeros((num_objs,), dtype=torch.int64)
        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        # target["iscrowd"] = iscrowd

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        if len(self.names) == 0:
            raise Exception("\ndata_dir:{} is an empty dir! Please check your path to images!".format(self.data_dir))
        return len(self.names)


class CelebADataset(object):
    def __init__(self, data_dir, transforms):

        self.data_dir = data_dir
        self.transform = transforms
        self.img_names = [name for name in list(filter(lambda x: x.endswith(".jpg"), os.listdir(self.data_dir)))]

    def __getitem__(self, index):
        path_img = os.path.join(self.data_dir, self.img_names[index])
        img = Image.open(path_img).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        return img

    def __len__(self):
        if len(self.img_names) == 0:
            raise Exception("\ndata_dir:{} is an empty dir! Please check your path to images!".format(self.data_dir))
        return len(self.img_names)
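CelebADataset returns only the transformed image, and in gan_demo.py that transform ends with transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)). Per channel this computes (x - 0.5) / 0.5, mapping the [0, 1] output of ToTensor() onto [-1, 1], the same range as the Generator's final nn.Tanh(), so real and generated images reach the discriminator on the same scale. The minimal sketch below writes that arithmetic out with plain tensor operations (no torchvision needed) and shows the inverse mapping one could use to view generator output; note that vutils.make_grid(..., normalize=True) in the scripts instead rescales by the grid's own minimum and maximum.

```python
import torch

mean, std = 0.5, 0.5                      # the values passed to transforms.Normalize in gan_demo.py
x = torch.tensor([0.0, 0.25, 0.5, 1.0])   # pixel values as produced by ToTensor(), in [0, 1]
x_norm = (x - mean) / std                 # per-channel arithmetic that Normalize performs
print(x_norm.tolist())                    # [-1.0, -0.5, 0.0, 1.0]

# The Generator ends in nn.Tanh(), so its outputs lie in (-1, 1) as well
g_out = torch.tanh(torch.tensor([-3.0, 0.0, 3.0]))
# Inverse of the normalization, mapping generator output back into [0, 1] for display
to_display = g_out * std + mean
print(to_display.tolist())
```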
