# Input pipeline: the MNIST training split, served in shuffled mini-batches.
batch_size = 32

# Convert each image to a tensor, then shift/scale it with mean 0.5 and
# std 0.5, i.e. (x - 0.5) / 0.5 on the single channel.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# The dataset is stored under ./mnist_data and downloaded if it is missing.
mnist_data = torchvision.datasets.MNIST(
    root="./mnist_data",
    train=True,
    transform=transform,
    download=True,
)

dataloader = torch.utils.data.DataLoader(
    mnist_data,
    batch_size=batch_size,
    shuffle=True,
)
# Discriminator: a three-layer MLP that maps a flattened image of
# `image_size` values to a single sigmoid output in (0, 1).
image_size = 784
hidden_size = 256

D = nn.Sequential(
    nn.Linear(in_features=image_size, out_features=hidden_size),
    nn.LeakyReLU(negative_slope=0.2),
    nn.Linear(in_features=hidden_size, out_features=hidden_size),
    nn.LeakyReLU(negative_slope=0.2),
    nn.Linear(in_features=hidden_size, out_features=1),
    nn.Sigmoid(),
)
# Generator: a three-layer MLP that maps a latent vector of `latent_size`
# values to `image_size` outputs. The final Tanh bounds every output to
# [-1, 1].
latent_size = 64

G = nn.Sequential(
    nn.Linear(in_features=latent_size, out_features=hidden_size),
    nn.ReLU(),
    nn.Linear(in_features=hidden_size, out_features=hidden_size),
    nn.ReLU(),
    nn.Linear(in_features=hidden_size, out_features=image_size),
    nn.Tanh(),
)
# Place both networks on the target device before the optimizers are
# created from their parameters.
D, G = D.to(device), G.to(device)

# Binary cross-entropy on the discriminator's sigmoid output.
loss_fn = nn.BCELoss()

# One Adam optimizer per network, both with the same learning rate.
d_optimizer = torch.optim.Adam(params=D.parameters(), lr=2e-4)
g_optimizer = torch.optim.Adam(params=G.parameters(), lr=2e-4)
def reset_grad():
    """Zero the gradients tracked by both optimizers.

    Calls ``zero_grad()`` on the discriminator optimizer first and the
    generator optimizer second, so a following ``backward()`` starts from
    clean gradients for both networks.
    """
    for optimizer in (d_optimizer, g_optimizer):
        optimizer.zero_grad()
# Adversarial training: every iteration performs one discriminator update
# followed by one generator update on the same mini-batch size.
total_step = len(dataloader)
print(total_step)
num_epochs = 200

for epoch in range(num_epochs):
    for i, (images, _) in enumerate(dataloader):
        # Keep the size of this particular batch under its own name so the
        # module-level `batch_size` setting is not overwritten.
        n_samples = images.size(0)
        images = images.reshape(n_samples, image_size).to(device)

        # Build the targets directly on the device rather than allocating
        # them on the CPU and copying them over on every step.
        real_labels = torch.ones(n_samples, 1, device=device)
        fake_labels = torch.zeros(n_samples, 1, device=device)

        # ---- Discriminator step ----
        # Real images should be classified as 1.
        outputs = D(images)
        d_loss_real = loss_fn(outputs, real_labels)
        real_score = outputs

        # Generated images should be classified as 0. The generator is not
        # updated here, so its forward pass runs without autograd tracking.
        z = torch.randn(n_samples, latent_size, device=device)
        with torch.no_grad():
            fake_images = G(z)
        outputs = D(fake_images)
        d_loss_fake = loss_fn(outputs, fake_labels)
        fake_score = outputs

        d_loss = d_loss_real + d_loss_fake
        reset_grad()
        d_loss.backward()
        d_optimizer.step()

        # ---- Generator step ----
        # Fresh noise; the generator is rewarded when D labels its output
        # as real, hence the loss against `real_labels`.
        z = torch.randn(n_samples, latent_size, device=device)
        fake_images = G(z)
        outputs = D(fake_images)
        g_loss = loss_fn(outputs, real_labels)

        # backward() below also fills D's gradients, but only the
        # generator optimizer steps; reset_grad() clears both beforehand.
        reset_grad()
        g_loss.backward()
        g_optimizer.step()

        # Periodic progress report: losses and mean discriminator outputs
        # on the real and generated batches.
        if i % 1000 == 0:
            print(
                "Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}".format(
                    epoch,
                    num_epochs,
                    i,
                    total_step,
                    d_loss.item(),
                    g_loss.item(),
                    real_score.mean().item(),
                    fake_score.mean().item(),
                )
            )
data:image/s3,"s3://crabby-images/7357e/7357eb748e126ae550fe7e6c29cff2fefc9038a1" alt="pytorch学习11-GAN_第1张图片"
# Sample one latent vector and display the image the generator produces.
z = torch.randn(1, latent_size, device=device)

# Inference only: under no_grad() no autograd graph is recorded, so the
# output can be moved to the CPU and converted to NumPy directly.
with torch.no_grad():
    fake_images = G(z).view(28, 28).cpu().numpy()

plt.imshow(fake_images)
data:image/s3,"s3://crabby-images/1e91f/1e91fef62b6677ec032330501db5466a15e1870b" alt="pytorch学习11-GAN_第2张图片"