In the previous post, Adversarial Variational Bayes (AVB) (Adversarial Variational Bayes: Unifying Variational Autoencoders and ...) (Part 1), I shared the GitHub code, its annotations, and the computation flow graph. That code, however, targets the "Hello World!" of image recognition, the MNIST dataset, so I have since adapted it to the needs of my own experiments:
import torch
import torch.nn.functional as F  # renamed from the original "nn" alias, which shadowed torch.nn
import torch.optim as optim
import numpy as np
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import save_image
from PIL import Image
# Load a custom dataset from a self-made txt file listing the images
def default_loader(path):
    return Image.open(path).convert('RGB')

class MyDataset(Dataset):
    def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
        imgs = []
        with open(txt, 'r') as fh:
            for line in fh:
                line = line.strip('\n').rstrip()
                words = line.split('\t')       # each line: image path <TAB> label
                imgs.append((words[0], words[1]))
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        fn, label = self.imgs[index]
        img = self.loader(fn)
        if self.transform is not None:
            img = self.transform(img)
        else:
            img = torch.from_numpy(np.array(img))  # fixed: the original "Tensor.from_numpy(img)" is not valid
        return img, label

    def __len__(self):
        return len(self.imgs)
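For reference, the image-list txt file read by MyDataset is assumed to contain one "image path<TAB>label" pair per line; the paths and labels below are purely hypothetical:

/data/faces/img_0001.jpg	0
/data/faces/img_0002.jpg	0
/data/faces/img_0003.jpg	1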
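The hyperparameters are never defined in the listing; a plausible set is assumed here, following the wiseodd reference code and the 150x150 RGB inputs used below:

mb_size = 32           # minibatch size (assumed)
z_dim = 100            # latent dimension (assumed)
eps_dim = 4            # auxiliary noise dimension (assumed, as in the reference code)
X_dim = 150 * 150 * 3  # a flattened 150x150 RGB image
h_dim = 128            # hidden layer width (assumed)
lr = 1e-3              # Adam learning rate (assumed)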
transform = transforms.Compose([transforms.Resize((150, 150)), transforms.ToTensor()])  # resize and convert to a tensor (Resize replaces the deprecated Scale)
train_txt_path = '.txt'  # path to your own dataset's image-list txt file
trainset = MyDataset(txt=train_txt_path, transform=transform)
train_loader = DataLoader(dataset=trainset, batch_size=mb_size, shuffle=True)
test_loader = DataLoader(dataset=trainset, batch_size=1, shuffle=False)  # note: the "test" loader reuses the training set, one image at a time
Here is the GitHub code link again: https://github.com/wiseodd/generative-models
def log(x):
    return torch.log(x + 1e-8)  # numerically stable log: the epsilon avoids log(0)
# Encoder: q(z|x, eps)
Q = torch.nn.Sequential(
    torch.nn.Linear(X_dim + eps_dim, h_dim),  # fully connected layer
    torch.nn.ReLU(),
    torch.nn.Linear(h_dim, z_dim)             # fully connected layer
)

# Decoder: p(x|z)
P = torch.nn.Sequential(
    torch.nn.Linear(z_dim, h_dim),   # fully connected layer
    torch.nn.ReLU(),
    torch.nn.Linear(h_dim, X_dim),   # fully connected layer
    torch.nn.Sigmoid()
)

# Discriminator: T(X, z)
T = torch.nn.Sequential(
    torch.nn.Linear(X_dim + z_dim, h_dim),  # fully connected layer
    torch.nn.ReLU(),
    torch.nn.Linear(h_dim, 1)               # fully connected layer; outputs a single scalar
)
Q.cuda()
P.cuda()
T.cuda()
def reset_grad():  # zero the gradients of all three modules
    Q.zero_grad()
    P.zero_grad()
    T.zero_grad()

# one Adam optimizer per module
Q_solver = optim.Adam(Q.parameters(), lr=lr)
P_solver = optim.Adam(P.parameters(), lr=lr)
T_solver = optim.Adam(T.parameters(), lr=lr)
for it in range(1000000):  # training iterations
    print(it)
    for X, _ in train_loader:  # each X is a minibatch sampled from the training set
        X = X.view(-1, 150 * 150 * 3)  # flatten each 150x150 RGB image
        X = Variable(X)
        batch = X.size(0)  # the last batch of an epoch can be smaller than mb_size
        eps = Variable(torch.randn(batch, eps_dim))
        z = Variable(torch.randn(batch, z_dim))  # sampled from the standard normal prior (mean 0, variance 1)

        # Optimize the VAE (encoder Q and decoder P)
        z_sample = Q(torch.cat([X, eps], 1).cuda())  # concatenate along columns; the row counts must match
        X_sample = P(z_sample)
        T_sample = T(torch.cat([X.cuda(), z_sample], 1))

        disc = torch.mean(-T_sample)  # mean of the negated discriminator outputs
        loglike = -F.binary_cross_entropy(X_sample, X.cuda(), size_average=False) / batch
        # cross-entropy: minimizing it is equivalent to maximizing the log-likelihood,
        # pushing the reconstruction as close as possible to the original input
        elbo = -(disc + loglike)  # negative evidence lower bound, as used in variational inference
        elbo.backward()  # backpropagate to optimize the encoder and decoder
        Q_solver.step()
        P_solver.step()
        reset_grad()  # reset the gradients to zero

        # Optimize the discriminator T(X, z)
        z_sample = Q(torch.cat([X, eps], 1).cuda())  # re-encode: the graph from the VAE step has been freed
        T_q = torch.sigmoid(T(torch.cat([X.cuda(), z_sample], 1)))  # T on (x, z ~ q(z|x))
        T_prior = torch.sigmoid(T(torch.cat([X, z], 1).cuda()))     # T on (x, z ~ p(z))
        T_loss = -torch.mean(log(T_q) + log(1. - T_prior))
        T_loss.backward()
        T_solver.step()
        reset_grad()  # reset the gradients to zero
    if (it + 1) % 10 == 0:
        print('Iter-{}; ELBO: {:.4}; T_loss: {:.4}'
              .format(it, -elbo.data[0], -T_loss.data[0]))  # .data[0] reads out a scalar in pre-0.4 PyTorch
    # Print and plot every now and then
    if (it + 1) % 10 == 0:
        for k, (X, _) in enumerate(test_loader):
            if k < 3:  # only the first three test images
                X = X.view(-1, 150 * 150 * 3)
                X = Variable(X)
                eps = Variable(torch.randn(1, eps_dim))
                z = Variable(torch.randn(1, z_dim))  # sampled from the standard normal prior
                z_sample = Q(torch.cat([X, eps], 1).cuda())  # concatenate along columns
                X_random = P(z.cuda()).data.cpu().numpy()  # decode a random prior sample
                reconst = P(z_sample).data.cpu().numpy()   # reconstruction of the test input
                X = X.data.numpy()  # fixed: Variable needs .data before .numpy() in pre-0.4 PyTorch

                # reshape the flat vectors back into 3x150x150 CHW images
                # (equivalent to the original per-channel reshape-and-copy loops)
                reconst = reconst.reshape(3, 150, 150)
                xx = X.reshape(3, 150, 150)

                # save the input and its reconstruction side by side
                # (the try/ output directory must already exist)
                image_show = np.concatenate((xx, reconst), axis=2)
                image_show = image_show[np.newaxis, :, :, :]
                save_image(torch.from_numpy(image_show), 'try/compare_' + str(it) + '_' + str(k) + '.png')

                # save the image decoded from the random prior sample
                random_image = X_random.reshape(3, 150, 150)
                save_image(torch.from_numpy(random_image), 'try/random_' + str(it) + '_' + str(k) + '.png')
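For reference, the two objectives the loop alternates between are those of the AVB paper ($\sigma$ denotes the logistic sigmoid). The VAE step maximizes

$$\mathbb{E}_{p_{\mathcal{D}}(x)}\,\mathbb{E}_{q_\phi(z\mid x)}\!\left[-T(x,z)+\log p_\theta(x\mid z)\right],$$

whose negated minibatch estimate is elbo in the code, while the discriminator step maximizes

$$\mathbb{E}_{p_{\mathcal{D}}(x)}\!\left[\mathbb{E}_{q_\phi(z\mid x)}\log\sigma(T(x,z))+\mathbb{E}_{p(z)}\log\!\big(1-\sigma(T(x,z))\big)\right],$$

whose negated minibatch estimate is T_loss.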
The output of running the code looks like this:
compare_{it}_{k}.png (it = 0, 9, 19, ...; k = 0, 1, 2): each file shows a test input (left) next to its reconstruction (right).
random_{it}_{k}.png (it = 0, 9, 19, ...; k = 0, 1, 2): each file shows an image decoded from a random prior sample.
(The images themselves are omitted here.)
As the samples show, the quality of the generated images improves as the number of iterations grows.
To generate new face images, the code draws a fresh Gaussian sample in place of the latent z_sample. Is that reasonable for an adversarial variational autoencoder? (AVB's latent layer has no explicit probability distribution; the encoder is a black-box model.) Is there a better sampling scheme for generating new faces with AVB? One related option is sketched below.
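Since q(z|x) in AVB is implicit, there is no closed-form posterior to sample from, but drawing posterior samples is still easy: feed the same input x through the encoder with fresh auxiliary noise eps each time, then decode. A minimal sketch of this idea, reusing the Q, P and eps_dim defined above (the function name and n_samples are my own):

def sample_from_implicit_posterior(x_flat, n_samples=5):
    # x_flat: a (1, X_dim) flattened input image, as a CPU Variable
    x_rep = x_flat.repeat(n_samples, 1)              # replicate x n_samples times
    eps = Variable(torch.randn(n_samples, eps_dim))  # fresh auxiliary noise per sample
    z = Q(torch.cat([x_rep, eps], 1).cuda())         # n_samples draws from q(z|x)
    return P(z)                                      # decoded variations of x

Decoding prior samples z ~ N(0, I), as the code above does, remains the standard way to generate entirely new faces; this sketch instead produces variations around a given face.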
That is all for Adversarial Variational Bayes (AVB) (Adversarial Variational Bayes: Unifying Variational Autoencoders and ...) (Part 2). If anything is missing, inappropriate, or wrong, please point it out promptly. The code also needs further optimization; if you have better modifications or ideas, please share them.
Link to the previous post, Adversarial Variational Bayes (AVB) (Adversarial Variational Bayes: Unifying Variational Autoencoders and ...) (Part 1): https://blog.csdn.net/S20144144/article/details/99467235