时间:2021-05-04
作者:知道许多的橘子
实现:生成对抗网络DCGAN_on_MNIST
如果感觉算力不够用了,或者心疼自己电脑了!
可以用我实验室的算力,试试呢!
害,谁叫我的算力都用不完呢!
支持所有框架!实际上框架都配置好了!
傻瓜式云计算!
Tesla v100 1卡,2卡,4卡,8卡
内存16-128G
cpu:8-24核
想要?加个微信:15615634293
欢迎打扰!
import torch
from torch import nn
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from torch import optim
import os
# Training hyper-parameters.
batch_size = 64
learning_rate = 0.0002
epochsize = 60  # number of passes over the training set
# Directory that receives one grid of generated samples per epoch.
sample_dir = "images_3"
# exist_ok=True replaces the exists()/makedirs() pair, which could race
# (FileExistsError) if the directory appeared between the check and the call.
os.makedirs(sample_dir, exist_ok=True)
class Generator(nn.Module):
    """DCGAN generator: maps latent vectors to single-channel 28x28 images.

    A linear layer projects each latent vector to a 128x7x7 feature map.
    Two nearest-neighbour upsamplings (7 -> 14 -> 28), each followed by a
    3x3 convolution, then bring it to 28x28. A final 1-channel
    convolution + Tanh puts pixel values in [-1, 1].

    Args:
        latent_dim: length of the input noise vector (default 100, the
            value the original hard-coded).
    """

    def __init__(self, latent_dim=100):
        super(Generator, self).__init__()
        self.latent_dim = latent_dim
        self.fc_layer = nn.Sequential(nn.Linear(latent_dim, 128 * 7 * 7))
        # NOTE: the 2nd positional argument of nn.BatchNorm2d is `eps`, not
        # `momentum`; the original `BatchNorm2d(128, 0.8)` silently used
        # eps=0.8. Momentum is now passed by keyword, matching the
        # Discriminator. eps is not stored in state_dict, so old checkpoints
        # still load (pass eps=0.8 explicitly to reproduce old outputs).
        self.conv_layer = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),  # 7x7 -> 14x14
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, momentum=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),  # 14x14 -> 28x28
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, momentum=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 1, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, input):
        """Map latent codes of shape (N, latent_dim) to images of shape (N, 1, 28, 28)."""
        x = self.fc_layer(input)
        x = x.view(input.shape[0], 128, 7, 7)
        return self.conv_layer(x)
class Discriminator(nn.Module):
    """DCGAN discriminator: scores single-channel images with a value in [0, 1].

    Four stride-2 3x3 convolutions widen the channels 1 -> 16 -> 32 -> 64 -> 128
    and shrink a 28x28 input to 14, 7, 4 and then 2 pixels per side. Each
    stage is Conv -> LeakyReLU -> Dropout2d, and every stage after the first
    also ends in BatchNorm. The 128 * 2 * 2 = 512 resulting features feed a
    single sigmoid unit.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        widths = [1, 16, 32, 64, 128]
        stages = []
        for idx, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            stages += [
                nn.Conv2d(c_in, c_out, 3, 2, 1),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Dropout2d(0.25),
            ]
            # The first stage is left un-normalised; later stages end in BatchNorm.
            if idx > 0:
                stages.append(nn.BatchNorm2d(c_out, momentum=0.8))
        self.conv_layer = nn.Sequential(*stages)
        self.fc_layer = nn.Sequential(
            nn.Linear(128 * 2 * 2, 1),
            nn.Sigmoid(),
        )

    def forward(self, input):
        """Return an (N, 1) tensor of sigmoid scores for an (N, 1, H, W) batch.

        H and W must reduce to 2 after four stride-2 convolutions (e.g. 28x28).
        """
        features = self.conv_layer(input)
        flat = features.view(input.shape[0], -1)
        return self.fc_layer(flat)
# MNIST training split. ToTensor yields values in [0, 1]; Normalize(0.5, 0.5)
# maps them to [-1, 1], the same range as the Generator's Tanh output.
mnist_traindata = datasets.MNIST(
    '/home/megstudio/dataset/dataset-2105/file-1258/mnist',
    train=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]),
    download=False,
)
mnist_train = DataLoader(mnist_traindata, batch_size=batch_size, shuffle=True)

# Fall back to CPU so the script still runs on machines without CUDA;
# with a hard-coded 'cuda' device every .to(device) call fails there.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
G = Generator().to(device)
D = Discriminator().to(device)
# BCELoss takes probabilities, which the Discriminator's final Sigmoid produces.
criteon = nn.BCELoss()
# NOTE(review): the DCGAN paper uses Adam betas=(0.5, 0.999); PyTorch's default
# (0.9, 0.999) is used here - confirm this is intended.
G_optimizer = optim.Adam(G.parameters(), lr=learning_rate)
D_optimizer = optim.Adam(D.parameters(), lr=learning_rate)
print("start training")
for epoch in range(epochsize):
    # Per-epoch running sums, kept as Python floats (see .item() below).
    D_loss_total = 0.0
    G_loss_total = 0.0
    total_num = 0
    for batchidx, (realimage, _) in enumerate(mnist_train):
        realimage = realimage.to(device)
        n = realimage.size(0)  # the last batch may be smaller than batch_size
        # BCE targets: 1 = real, 0 = fake.
        realimage_label = torch.ones(n, 1, device=device)
        fakeimage_label = torch.zeros(n, 1, device=device)
        z = torch.randn(n, 100, device=device)
        # Generate the fakes once and reuse them for both the D and the G step.
        fakeimage = G(z)

        # Discriminator step. detach() keeps D_loss.backward() from
        # back-propagating into G, which only wasted work and filled G's grads.
        d_realimage_loss = criteon(D(realimage), realimage_label)
        d_fakeimage_loss = criteon(D(fakeimage.detach()), fakeimage_label)
        D_loss = d_realimage_loss + d_fakeimage_loss
        D_optimizer.zero_grad()
        D_loss.backward()
        D_optimizer.step()

        # Generator step: push the (just updated) D to label the fakes as real.
        G_loss = criteon(D(fakeimage), realimage_label)
        G_optimizer.zero_grad()
        G_loss.backward()
        G_optimizer.step()

        # .item() converts to Python floats, so the running sums no longer keep
        # every batch's autograd history alive for the whole epoch.
        D_loss_total += D_loss.item()
        G_loss_total += G_loss.item()
        total_num += n
        if batchidx % 300 == 0:
            print("batchidx:{}/{}, D_loss:{}, G_loss:{}, total_num:{},".format(
                batchidx, len(mnist_train), D_loss.item(), G_loss.item(), total_num))
    # Averages are per batch (divided by the number of batches).
    print('Epoch:{}/{}, D_loss:{}, G_loss:{}, total_num:{}'.format(
        epoch, epochsize, D_loss_total / len(mnist_train),
        G_loss_total / len(mnist_train), total_num))

    # Save an 8x8 grid of samples; no_grad avoids building a graph just to write images.
    # NOTE(review): the `epoch + 14` offset looks like a leftover from resuming an
    # earlier run (this script does not load a checkpoint) - confirm.
    with torch.no_grad():
        z = torch.randn(batch_size, 100, device=device)
        save_image(G(z)[:64], os.path.join(sample_dir, 'fake_images-{}.png'.format(epoch + 14)),
                   nrow=8, normalize=True)
    torch.save(G.state_dict(), 'G_plus.ckpt')
    torch.save(D.state_dict(), 'D_plus.ckpt')