DCGAN (Deep Convolutional Generative Adversarial Network)

Date: 2021-05-04
Author: 知道许多的橘子
Implementation: DCGAN_on_MNIST

[Figure 1: DCGAN]

import torch
from torch import nn
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from torch import optim
import os

# Hyperparameters
batch_size = 64
learning_rate = 0.0002
epochsize = 60
sample_dir = "images_3"

# Create the output directory for generated images
if not os.path.exists(sample_dir):
    os.makedirs(sample_dir)


# Generator architecture
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        # Fully connected layer: project the 100-dim noise vector to 128 * 7 * 7 features
        self.fc_layer = nn.Sequential(nn.Linear(100, 128 * 7 * 7))  # torch.Size([64, 6272])

        # Convolutional layers: upsample 7x7 -> 14x14 -> 28x28 and refine into a 1-channel image
        self.conv_layer = nn.Sequential(

            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),  # torch.Size([64, 128, 14, 14])

            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, momentum=0.8),  # keyword makes the intent explicit; a bare positional 0.8 would set eps instead
            
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),  # torch.Size([64, 128, 28, 28])

            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, momentum=0.8),
            nn.LeakyReLU(0.2, inplace=True),  # torch.Size([64, 64, 28, 28])

            nn.Conv2d(64, 1, 3, stride=1, padding=1),
            nn.Tanh()  # torch.Size([64, 1, 28, 28])
        )

    def forward(self, input):
        x = self.fc_layer(input)  # torch.Size([64, 6272])
        x = x.view(input.shape[0], 128, 7, 7)  # torch.Size([64, 128, 7, 7])
        x = self.conv_layer(x)
        return x

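# Optional sanity check (added; not part of the original post): confirm that the generator
# turns a batch of 100-dim noise vectors into 1x28x28 images before starting a long training run.
# _g = Generator()
# _z = torch.randn(4, 100)
# print(_g(_z).shape)  # expected: torch.Size([4, 1, 28, 28])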

# Discriminator architecture
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        # Convolutional layers: 28x28 -> 14x14 -> 7x7 -> 4x4 -> 2x2
        self.conv_layer = nn.Sequential(

            # The first conv block has no BatchNorm; the remaining three blocks use it
            nn.Conv2d(1, 16, 3, 2, 1),  # torch.Size([64, 16, 14, 14])
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(0.25),

            nn.Conv2d(16, 32, 3, 2, 1),  # torch.Size([64, 32, 7, 7])
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(0.25),
            nn.BatchNorm2d(32, momentum=0.8),

            nn.Conv2d(32, 64, 3, 2, 1),  # torch.Size([64, 64, 4, 4])
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(0.25),
            nn.BatchNorm2d(64, momentum=0.8),

            nn.Conv2d(64, 128, 3, 2, 1),  # torch.Size([64, 128, 2, 2])
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(0.25),
            nn.BatchNorm2d(128, momentum=0.8)
        )

        # Fully connected layer: 128 * 2 * 2 = 512 features -> one real/fake probability
        self.fc_layer = nn.Sequential(
            nn.Linear(512, 1),
            nn.Sigmoid()
        )

    def forward(self, input):
        x = self.conv_layer(input)  # torch.Size([64, 128, 2, 2])
        x = x.view(input.shape[0], -1)  # torch.Size([64, 512])
        x = self.fc_layer(x)  # torch.Size([64, 1])

        return x

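# Optional sanity check (added; not part of the original post): the discriminator maps a batch
# of 28x28 grayscale images to one probability per image, i.e. [N, 1, 28, 28] -> [N, 1].
# _d = Discriminator()
# _x = torch.randn(4, 1, 28, 28)
# print(_d(_x).shape)  # expected: torch.Size([4, 1])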

# Load the MNIST training set (download=False assumes the data already exists at this path)
mnist_traindata = datasets.MNIST('/home/megstudio/dataset/dataset-2105/file-1258/mnist', train=True, transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5],
                         std=[0.5])
]), download=False)
mnist_train = DataLoader(mnist_traindata, batch_size=batch_size, shuffle=True)
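# Added note: Normalize(mean=[0.5], std=[0.5]) maps ToTensor's [0, 1] pixel range to [-1, 1],
# which matches the Tanh output range of the generator, so real and generated images share the
# same scale when fed to the discriminator. Quick check (optional):
# _img, _ = mnist_traindata[0]
# print(_img.min().item(), _img.max().item())  # roughly -1.0 and 1.0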

# Device selection (falls back to CPU if no GPU is available)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

G = Generator().to(device)
D = Discriminator().to(device)

# Optionally resume from previously saved weights
# G.load_state_dict(torch.load('G_plus.ckpt'))
# D.load_state_dict(torch.load('D_plus.ckpt'))

# Loss function and optimizers. BCELoss suits this binary real/fake task; BCEWithLogitsLoss
# would include the Sigmoid itself (the Discriminator's final nn.Sigmoid would then be dropped).
# criteon = nn.BCEWithLogitsLoss()
criteon = nn.BCELoss()
G_optimizer = optim.Adam(G.parameters(), lr=learning_rate)
D_optimizer = optim.Adam(D.parameters(), lr=learning_rate)
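# Added note: if you switch to BCEWithLogitsLoss, feed it raw logits (remove the final
# nn.Sigmoid from the Discriminator); the two formulations are numerically equivalent:
# _logits = torch.randn(8, 1)
# _targets = torch.ones(8, 1)
# _a = nn.BCEWithLogitsLoss()(_logits, _targets)
# _b = nn.BCELoss()(torch.sigmoid(_logits), _targets)
# print(torch.allclose(_a, _b))  # True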

# Target labels (now built per batch inside the loop so they match the last, smaller batch)
# realimage_label = torch.ones(batch_size, 1).to(device)   # all ones, shape [batch_size, 1]
# fakeimage_label = torch.zeros(batch_size, 1).to(device)  # all zeros, shape [batch_size, 1]

# Start training
print("start training")
for epoch in range(epochsize):

    D_loss_total = 0
    G_loss_total = 0
    total_num = 0

    # The MNIST digit labels are not used here (discarded as _)
    for batchidx, (realimage, _) in enumerate(mnist_train):

        realimage = realimage.to(device)
        realimage_label = torch.ones(realimage.size(0), 1).to(device)   # all ones, shape [batch, 1]
        fakeimage_label = torch.zeros(realimage.size(0), 1).to(device)  # all zeros, shape [batch, 1]

        # Sample random noise vectors
        z = torch.randn(realimage.size(0), 100).to(device)

        #==================================================
        # Train the discriminator: its loss is the sum of the real-image and fake-image terms
        d_realimage_loss = criteon(D(realimage), realimage_label)
        d_fakeimage_loss = criteon(D(G(z).detach()), fakeimage_label)  # detach so this step does not backprop into G
        D_loss = d_realimage_loss + d_fakeimage_loss

        # Standard three steps: zero gradients, backpropagate, update
        D_optimizer.zero_grad()
        D_loss.backward()
        D_optimizer.step()

        # Accumulate the epoch loss (.item() keeps a plain float instead of a tensor with history)
        D_loss_total += D_loss.item()

        #==================================================
        # Train the generator: a single loss term. Generated images are scored against the
        # *real* label, i.e. G is trained to make D output 1 (the non-saturating GAN loss).
        G_loss = criteon(D(G(z)), realimage_label)

        # Standard three steps: zero gradients, backpropagate, update
        G_optimizer.zero_grad()
        G_loss.backward()
        G_optimizer.step()

        # Accumulate the epoch loss
        G_loss_total += G_loss.item()

        # Count of images processed so far this epoch
        total_num += realimage.size(0)

        # Print running losses every 300 batches
        if batchidx % 300 == 0:
            print("batchidx:{}/{}, D_loss:{}, G_loss:{}, total_num:{}".format(
                batchidx, len(mnist_train), D_loss.item(), G_loss.item(), total_num))

    # Print the average losses for this epoch
    print('Epoch:{}/{}, D_loss:{}, G_loss:{}, total_num:{}'.format(
        epoch, epochsize, D_loss_total / len(mnist_train), G_loss_total / len(mnist_train), total_num))

    # Save a grid of generated images for this epoch
    z = torch.randn(batch_size, 100).to(device)
    save_image(G(z).data[:64], os.path.join(sample_dir, 'fake_images-{}.png'.format(epoch + 14)), nrow=8,
               normalize=True)  # the +14 in the filename presumably continues the numbering of an earlier run

    # Save the network weights (state_dict)
    torch.save(G.state_dict(), 'G_plus.ckpt')
    torch.save(D.state_dict(), 'D_plus.ckpt')

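After training, the saved generator weights can be reloaded to draw new samples without rerunning the loop. A minimal sketch, assuming the Generator class defined above is available in the same file (or imported) and that G_plus.ckpt is in the working directory; the output filename is just illustrative:

import torch
from torchvision.utils import save_image

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
G = Generator().to(device)
G.load_state_dict(torch.load('G_plus.ckpt', map_location=device))
G.eval()

with torch.no_grad():
    z = torch.randn(64, 100, device=device)  # 64 noise vectors, 100-dim each
    samples = G(z)
save_image(samples, 'samples_from_checkpoint.png', nrow=8, normalize=True)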