代码来源 github
依据文章 beta-VAE: Learning Basic Visual Concepts with a Constrained Variational Framework
主程序如下:运行环境gpu,使用torch框架
主程序主要用于定义程序后续用到的变量以及设置运行环境,方便后续修改调试,一目了然。
import argparse
import numpy as np
import torch
from solver import Solver
from utils import str2bool
# Enable cuDNN and its autotuner: `benchmark` lets cuDNN pick the fastest
# convolution algorithms, which is safe here because input sizes are fixed (64x64).
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
def main(args):
    """Seed all RNGs for reproducibility, then train or run latent traversal.

    Args:
        args: parsed command-line namespace (see the argparse setup below).
    """
    seed = args.seed
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)

    net = Solver(args)
    if args.train:
        net.train()
    else:
        # BUG FIX: Solver defines viz_traverse(), not traverse(); the
        # original `net.traverse()` raised AttributeError in traversal mode.
        net.viz_traverse()
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Beta-VAE')

    # Run mode / reproducibility / hardware.
    parser.add_argument('--train', default=True, type=str2bool, help='train or traverse')
    parser.add_argument('--seed', default=1, type=int, help='random seed')
    parser.add_argument('--cuda', default=True, type=str2bool, help='enable cuda')
    parser.add_argument('--max_iter', default=1e6, type=float, help='maximum training iteration')
    parser.add_argument('--batch_size', default=64, type=int, help='batch size')

    # Model / objective hyper-parameters.
    parser.add_argument('--z_dim', default=10, type=int, help='dimension of the representation z')
    parser.add_argument('--beta', default=4, type=float, help='beta parameter for KL-term in original beta-VAE')
    parser.add_argument('--objective', default='H', type=str, help='beta-vae objective proposed in Higgins et al. or Burgess et al. H/B')
    parser.add_argument('--model', default='H', type=str, help='model proposed in Higgins et al. or Burgess et al. H/B')
    parser.add_argument('--gamma', default=1000, type=float, help='gamma parameter for KL-term in understanding beta-VAE')
    parser.add_argument('--C_max', default=25, type=float, help='capacity parameter(C) of bottleneck channel')
    parser.add_argument('--C_stop_iter', default=1e5, type=float, help='when to stop increasing the capacity')

    # Optimizer.
    parser.add_argument('--lr', default=1e-4, type=float, help='learning rate')
    parser.add_argument('--beta1', default=0.9, type=float, help='Adam optimizer beta1')
    parser.add_argument('--beta2', default=0.999, type=float, help='Adam optimizer beta2')

    # Data.
    parser.add_argument('--dset_dir', default='data', type=str, help='dataset directory')
    parser.add_argument('--dataset', default='CelebA', type=str, help='dataset name')
    parser.add_argument('--image_size', default=64, type=int, help='image size. now only (64,64) is supported')
    parser.add_argument('--num_workers', default=2, type=int, help='dataloader num_workers')

    # Visualization / output.
    parser.add_argument('--viz_on', default=True, type=str2bool, help='enable visdom visualization')
    parser.add_argument('--viz_name', default='main', type=str, help='visdom env name')
    # BUG FIX: the port is numeric; type=str left args.viz_port as a string.
    parser.add_argument('--viz_port', default=8097, type=int, help='visdom port number')
    parser.add_argument('--save_output', default=True, type=str2bool, help='save traverse images and gif')
    parser.add_argument('--output_dir', default='outputs', type=str, help='output directory')

    # Logging / checkpointing cadence.
    parser.add_argument('--gather_step', default=1000, type=int, help='numer of iterations after which data is gathered for visdom')
    parser.add_argument('--display_step', default=10000, type=int, help='number of iterations after which loss data is printed and visdom is updated')
    parser.add_argument('--save_step', default=10000, type=int, help='number of iterations after which a checkpoint is saved')
    parser.add_argument('--ckpt_dir', default='checkpoints', type=str, help='checkpoint directory')
    parser.add_argument('--ckpt_name', default='last', type=str, help='load previous checkpoint. insert checkpoint filename')

    args = parser.parse_args()
    main(args)
dataset
该程序中主要用到CelebA数据集,另一数据集为测试模型泛化能力的试验数据集。CelebA是香港中文大学的开放数据集,包含10177个名人身份的202599张图片,并且都做好了特征标记,这对人脸相关的训练是非常好用的数据集。
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import ImageFolder
from torchvision import transforms
def is_power_of_2(num):
    """Return True iff num is a positive power of two (bit-trick check)."""
    return num != 0 and (num & (num - 1)) == 0
class CustomImageFolder(ImageFolder):
    """ImageFolder variant whose __getitem__ yields only the image, no label."""

    def __init__(self, root, transform=None):
        super(CustomImageFolder, self).__init__(root, transform)

    def __getitem__(self, index):
        """Load the image at *index*, apply the transform if any, drop the label."""
        path, _ = self.imgs[index]
        img = self.loader(path)
        if self.transform is None:
            return img
        return self.transform(img)
class CustomTensorDataset(Dataset):
    """Dataset over a single tensor; each item is one slice along dimension 0."""

    def __init__(self, data_tensor):
        # Tensor of shape (N, ...); indexing yields the (...)-shaped item.
        self.data_tensor = data_tensor

    def __getitem__(self, index):
        """Return the index-th slice of the wrapped tensor."""
        return self.data_tensor[index]

    def __len__(self):
        """Number of items, i.e. the size of dimension 0."""
        return self.data_tensor.size(0)
def return_data(args):
    """Build the training DataLoader selected by args.dataset.

    Supported datasets: '3DChairs' and 'CelebA' (image folders) and
    'dsprites' (a single .npz archive, downloaded via a shell script when
    missing).

    Args:
        args: namespace providing dataset, dset_dir, batch_size,
            num_workers and image_size.

    Returns:
        A shuffled torch.utils.data.DataLoader with drop_last=True.

    Raises:
        NotImplementedError: for an unrecognized dataset name.
    """
    name = args.dataset
    dset_dir = args.dset_dir
    batch_size = args.batch_size
    num_workers = args.num_workers
    image_size = args.image_size
    assert image_size == 64, 'currently only image size of 64 is supported'

    # The two image-folder datasets share the same pipeline; only the
    # on-disk folder name differs, so fold the duplicated branches together.
    folder_names = {'3dchairs': '3DChairs', 'celeba': 'CelebA'}
    key = name.lower()
    if key in folder_names:
        root = os.path.join(dset_dir, folder_names[key])
        transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
        ])
        train_kwargs = {'root': root, 'transform': transform}
        dset = CustomImageFolder
    elif key == 'dsprites':
        root = os.path.join(dset_dir, 'dsprites-dataset/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz')
        if not os.path.exists(root):
            import subprocess
            print('Now download dsprites-dataset')
            subprocess.call(['./download_dsprites.sh'])
            print('Finished')
        # Context manager closes the npz file handle once the tensor is
        # built (the original leaked it).
        with np.load(root, encoding='bytes') as npz:
            data = torch.from_numpy(npz['imgs']).unsqueeze(1).float()
        train_kwargs = {'data_tensor': data}
        dset = CustomTensorDataset
    else:
        raise NotImplementedError('unknown dataset: {}'.format(name))

    train_data = dset(**train_kwargs)
    train_loader = DataLoader(train_data,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=num_workers,
                              pin_memory=True,
                              drop_last=True)

    data_loader = train_loader
    return data_loader
if __name__ == '__main__':
    # Smoke test: load one CelebA batch and drop into a debugger.
    transform = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),])
    dset = CustomImageFolder('data/CelebA', transform)
    loader = DataLoader(dset,
                        batch_size=32,
                        shuffle=True,
                        num_workers=1,
                        pin_memory=False,
                        drop_last=True)
    # BUG FIX: Python 3 iterators have no .next() method; use next(...).
    images1 = next(iter(loader))
    import ipdb; ipdb.set_trace()  # NOTE(review): interactive debug hook left in on purpose?
model
import torch
import torch.nn as nn
#import torch.nn.functional as F
import torch.nn.init as init
from torch.autograd import Variable
def reparametrize(mu, logvar):
    """Sample z ~ N(mu, exp(logvar)) via the reparameterization trick.

    Args:
        mu: posterior mean, shape (B, z_dim).
        logvar: posterior log-variance, same shape as mu.

    Returns:
        A differentiable sample mu + std * eps with eps ~ N(0, I).
    """
    std = logvar.div(2).exp()  # std = exp(logvar / 2)
    # randn_like replaces the deprecated Variable/.data.new(...).normal_()
    # idiom while drawing the same standard-normal noise.
    eps = torch.randn_like(std)
    return mu + std*eps
class View(nn.Module):
    """nn.Module wrapper around Tensor.view, usable inside nn.Sequential."""

    def __init__(self, size):
        super(View, self).__init__()
        # Target shape passed straight to Tensor.view, e.g. (-1, 256).
        self.size = size

    def forward(self, tensor):
        """Reshape the input to the configured size."""
        return tensor.view(self.size)
class BetaVAE_H(nn.Module):
    """Model proposed in original beta-VAE paper(Higgins et al, ICLR, 2017)."""

    def __init__(self, z_dim=10, nc=3):
        # z_dim: latent vector size; nc: number of input image channels.
        super(BetaVAE_H, self).__init__()
        self.z_dim = z_dim
        self.nc = nc
        # Encoder: (B, nc, 64, 64) image -> (B, 2*z_dim) vector holding the
        # posterior mean and log-variance side by side.
        self.encoder = nn.Sequential(
            nn.Conv2d(nc, 32, 4, 2, 1),          # B,  32, 32, 32
            nn.ReLU(True),
            nn.Conv2d(32, 32, 4, 2, 1),          # B,  32, 16, 16
            nn.ReLU(True),
            nn.Conv2d(32, 64, 4, 2, 1),          # B,  64,  8,  8
            nn.ReLU(True),
            nn.Conv2d(64, 64, 4, 2, 1),          # B,  64,  4,  4
            nn.ReLU(True),
            nn.Conv2d(64, 256, 4, 1),            # B, 256,  1,  1
            nn.ReLU(True),
            View((-1, 256*1*1)),                 # B, 256
            nn.Linear(256, z_dim*2),             # B, z_dim*2
        )
        # Decoder mirrors the encoder: z -> (B, nc, 64, 64) logits
        # (no final activation; callers apply sigmoid where needed).
        self.decoder = nn.Sequential(
            nn.Linear(z_dim, 256),               # B, 256
            View((-1, 256, 1, 1)),               # B, 256, 1, 1
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 64, 4),      # B,  64,  4,  4
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 64, 4, 2, 1), # B,  64,  8,  8
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 32, 4, 2, 1), # B,  32, 16, 16
            nn.ReLU(True),
            nn.ConvTranspose2d(32, 32, 4, 2, 1), # B,  32, 32, 32
            nn.ReLU(True),
            nn.ConvTranspose2d(32, nc, 4, 2, 1), # B,  nc, 64, 64
        )
        self.weight_init()

    def weight_init(self):
        # Kaiming-init every layer of both sub-networks (see kaiming_init).
        for block in self._modules:
            for m in self._modules[block]:
                kaiming_init(m)

    def forward(self, x):
        """Encode x, sample z via the reparameterization trick, decode.

        Returns (x_recon, mu, logvar).
        """
        distributions = self._encode(x)
        mu = distributions[:, :self.z_dim]      # first half: posterior mean
        logvar = distributions[:, self.z_dim:]  # second half: posterior log-variance
        z = reparametrize(mu, logvar)
        x_recon = self._decode(z)
        return x_recon, mu, logvar

    def _encode(self, x):
        return self.encoder(x)

    def _decode(self, z):
        return self.decoder(z)
class BetaVAE_B(BetaVAE_H):
    """Model proposed in understanding beta-VAE paper(Burgess et al, arxiv:1804.03599, 2018)."""

    def __init__(self, z_dim=10, nc=1):
        # NOTE(review): the parent __init__ first builds BetaVAE_H's default
        # encoder/decoder, which are immediately replaced below — harmless,
        # but it wastes construction time.
        super(BetaVAE_B, self).__init__()
        self.nc = nc
        self.z_dim = z_dim
        # Encoder: (B, nc, 64, 64) -> (B, 2*z_dim) [mean | log-variance].
        self.encoder = nn.Sequential(
            nn.Conv2d(nc, 32, 4, 2, 1),          # B,  32, 32, 32
            nn.ReLU(True),
            nn.Conv2d(32, 32, 4, 2, 1),          # B,  32, 16, 16
            nn.ReLU(True),
            nn.Conv2d(32, 32, 4, 2, 1),          # B,  32,  8,  8
            nn.ReLU(True),
            nn.Conv2d(32, 32, 4, 2, 1),          # B,  32,  4,  4
            nn.ReLU(True),
            View((-1, 32*4*4)),                  # B, 512
            nn.Linear(32*4*4, 256),              # B, 256
            nn.ReLU(True),
            nn.Linear(256, 256),                 # B, 256
            nn.ReLU(True),
            nn.Linear(256, z_dim*2),             # B, z_dim*2
        )
        # Decoder: z -> (B, nc, 64, 64) logits (sigmoid applied by callers).
        self.decoder = nn.Sequential(
            nn.Linear(z_dim, 256),               # B, 256
            nn.ReLU(True),
            nn.Linear(256, 256),                 # B, 256
            nn.ReLU(True),
            nn.Linear(256, 32*4*4),              # B, 512
            nn.ReLU(True),
            View((-1, 32, 4, 4)),                # B,  32,  4,  4
            nn.ConvTranspose2d(32, 32, 4, 2, 1), # B,  32,  8,  8
            nn.ReLU(True),
            nn.ConvTranspose2d(32, 32, 4, 2, 1), # B,  32, 16, 16
            nn.ReLU(True),
            nn.ConvTranspose2d(32, 32, 4, 2, 1), # B,  32, 32, 32
            nn.ReLU(True),
            nn.ConvTranspose2d(32, nc, 4, 2, 1), # B,  nc, 64, 64
        )
        self.weight_init()

    def weight_init(self):
        # Same Kaiming scheme as BetaVAE_H.
        for block in self._modules:
            for m in self._modules[block]:
                kaiming_init(m)

    def forward(self, x):
        """Encode, reparameterize, decode; reconstruction reshaped to x's shape."""
        distributions = self._encode(x)
        mu = distributions[:, :self.z_dim]
        logvar = distributions[:, self.z_dim:]
        z = reparametrize(mu, logvar)
        x_recon = self._decode(z).view(x.size())
        return x_recon, mu, logvar

    def _encode(self, x):
        return self.encoder(x)

    def _decode(self, z):
        return self.decoder(z)
def kaiming_init(m):
    """Initialize one module in place.

    Conv/linear layers get Kaiming-normal weights and zero bias;
    batch-norm layers get unit weight and zero bias. Other module
    types are left untouched.
    """
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        # init.kaiming_normal is the deprecated alias; the in-place
        # underscore variant is the supported, behaviorally identical API.
        init.kaiming_normal_(m.weight)
        if m.bias is not None:
            m.bias.data.fill_(0)
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
        m.weight.data.fill_(1)
        if m.bias is not None:
            m.bias.data.fill_(0)
def normal_init(m, mean, std):
    """Initialize one module in place with N(mean, std) weights.

    Conv/linear layers get normal weights and zero bias; batch-norm
    layers get unit weight and zero bias.
    """
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        m.weight.data.normal_(mean, std)
        # BUG FIX: the original tested `m.bias.data is not None`, which
        # raises AttributeError when bias is None; test the parameter itself.
        if m.bias is not None:
            m.bias.data.zero_()
    elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
        m.weight.data.fill_(1)
        if m.bias is not None:
            m.bias.data.zero_()
if __name__ == '__main__':
    # This module only exports model definitions; no standalone behavior.
    pass
solver
重构损失部分有新分布函数替换对比。常用一般为高斯分布,这里有替换为伯努利分布操作,用于对比实验效果。此子程序还用于模型的训练,并保存或加载训练好的模型。
import warnings
# NOTE(review): blanket suppression hides deprecation warnings from the
# legacy torch APIs used below — consider narrowing the filter.
warnings.filterwarnings("ignore")
import os
from tqdm import tqdm
import visdom
import torch
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.utils import make_grid, save_image
from utils import cuda, grid2gif
from model import BetaVAE_H, BetaVAE_B
from dataset import return_data
def reconstruction_loss(x, x_recon, distribution):
    """Reconstruction loss summed over pixels, averaged over the batch.

    Args:
        x: target images.
        x_recon: decoder output logits (pre-sigmoid), same shape as x.
        distribution: 'bernoulli' (BCE with logits) or 'gaussian'
            (MSE on sigmoid-squashed logits).

    Returns:
        A scalar tensor, or None for an unknown distribution name.
    """
    batch_size = x.size(0)
    assert batch_size != 0

    if distribution == 'bernoulli':
        # size_average=False is deprecated; reduction='sum' is equivalent.
        recon_loss = F.binary_cross_entropy_with_logits(
            x_recon, x, reduction='sum').div(batch_size)
    elif distribution == 'gaussian':
        # F.sigmoid is deprecated; torch.sigmoid is the same function.
        x_recon = torch.sigmoid(x_recon)
        recon_loss = F.mse_loss(x_recon, x, reduction='sum').div(batch_size)
    else:
        recon_loss = None
    return recon_loss
def kl_divergence(mu, logvar):
    """KL(q(z|x) || N(0, I)) statistics for a diagonal-Gaussian posterior.

    Returns a tuple (total_kld, dimension_wise_kld, mean_kld):
    sum over latent dims then batch mean; per-dimension batch mean;
    and the batch mean of the per-sample dimension mean.
    """
    batch_size = mu.size(0)
    assert batch_size != 0
    # Flatten spurious trailing spatial dims, e.g. (B, z, 1, 1) -> (B, z).
    if mu.data.ndimension() == 4:
        mu = mu.view(mu.size(0), mu.size(1))
    if logvar.data.ndimension() == 4:
        logvar = logvar.view(logvar.size(0), logvar.size(1))

    # Elementwise KL: 0.5 * (mu^2 + e^logvar - logvar - 1).
    klds = 0.5 * (mu.pow(2) + logvar.exp() - logvar - 1)
    total_kld = klds.sum(1).mean(0, True)
    dimension_wise_kld = klds.mean(0)
    mean_kld = klds.mean(1).mean(0, True)
    return total_kld, dimension_wise_kld, mean_kld
class DataGather(object):
    """Accumulate per-iteration training statistics, keyed by name, for visdom."""

    def __init__(self):
        self.data = self.get_empty_data_dict()

    def get_empty_data_dict(self):
        """Return a fresh dict mapping each tracked quantity to an empty list."""
        keys = ('iter', 'recon_loss', 'total_kld', 'dim_wise_kld',
                'mean_kld', 'mu', 'var', 'images')
        return {key: [] for key in keys}

    def insert(self, **kwargs):
        """Append each keyword value to its quantity's list."""
        for key, value in kwargs.items():
            self.data[key].append(value)

    def flush(self):
        """Discard everything accumulated so far."""
        self.data = self.get_empty_data_dict()
class Solver(object):
    """Trainer and visualizer for beta-VAE.

    Owns the network, optimizer, data loader, visdom windows and
    checkpointing. `train()` runs the optimization loop; `viz_traverse()`
    renders latent traversals.
    """

    def __init__(self, args):
        # --- hardware / schedule ---
        self.use_cuda = args.cuda and torch.cuda.is_available()
        self.max_iter = args.max_iter
        self.global_iter = 0
        # --- objective hyper-parameters ---
        self.z_dim = args.z_dim
        self.beta = args.beta            # KL weight for the 'H' objective
        self.gamma = args.gamma          # capacity penalty weight for 'B'
        self.C_max = args.C_max          # final KL capacity for 'B'
        self.C_stop_iter = args.C_stop_iter
        self.objective = args.objective
        self.model = args.model
        self.lr = args.lr
        self.beta1 = args.beta1
        self.beta2 = args.beta2
        # Dataset determines channel count and decoder output distribution.
        if args.dataset.lower() == 'dsprites':
            self.nc = 1
            self.decoder_dist = 'bernoulli'
        elif args.dataset.lower() == '3dchairs':
            self.nc = 3
            self.decoder_dist = 'gaussian'
        elif args.dataset.lower() == 'celeba':
            self.nc = 3
            self.decoder_dist = 'gaussian'
        else:
            raise NotImplementedError
        if args.model == 'H':
            net = BetaVAE_H
        elif args.model == 'B':
            net = BetaVAE_B
        else:
            raise NotImplementedError('only support model H or B')
        self.net = cuda(net(self.z_dim, self.nc), self.use_cuda)
        self.optim = optim.Adam(self.net.parameters(), lr=self.lr,
                                betas=(self.beta1, self.beta2))
        # Visdom state: window handles stay None until first plotted.
        self.viz_name = args.viz_name
        self.viz_port = args.viz_port
        self.viz_on = args.viz_on
        self.win_recon = None
        self.win_kld = None
        self.win_mu = None
        self.win_var = None
        if self.viz_on:
            self.viz = visdom.Visdom(port=self.viz_port)
        # --- checkpointing / output dirs ---
        self.ckpt_dir = os.path.join(args.ckpt_dir, args.viz_name)
        if not os.path.exists(self.ckpt_dir):
            os.makedirs(self.ckpt_dir, exist_ok=True)
        self.ckpt_name = args.ckpt_name
        if self.ckpt_name is not None:
            # load_checkpoint prints a notice when the file is missing.
            self.load_checkpoint(self.ckpt_name)
        self.save_output = args.save_output
        self.output_dir = os.path.join(args.output_dir, args.viz_name)
        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir, exist_ok=True)
        # --- logging cadence / data ---
        self.gather_step = args.gather_step
        self.display_step = args.display_step
        self.save_step = args.save_step
        self.dset_dir = args.dset_dir
        self.dataset = args.dataset
        self.batch_size = args.batch_size
        self.data_loader = return_data(args)
        self.gather = DataGather()

    def train(self):
        """Run the optimization loop until max_iter minibatch updates."""
        self.net_mode(train=True)
        self.C_max = Variable(cuda(torch.FloatTensor([self.C_max]), self.use_cuda))
        out = False
        pbar = tqdm(total=self.max_iter)
        pbar.update(self.global_iter)
        while not out:
            for x in self.data_loader:
                self.global_iter += 1
                pbar.update(1)
                x = Variable(cuda(x, self.use_cuda))
                x_recon, mu, logvar = self.net(x)
                recon_loss = reconstruction_loss(x, x_recon, self.decoder_dist)
                total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)
                if self.objective == 'H':
                    # Higgins et al.: recon + beta * KL.
                    beta_vae_loss = recon_loss + self.beta*total_kld
                elif self.objective == 'B':
                    # Burgess et al.: the KL capacity C grows linearly from 0
                    # to C_max over C_stop_iter iterations.
                    C = torch.clamp(self.C_max/self.C_stop_iter*self.global_iter, 0, self.C_max.data[0])
                    beta_vae_loss = recon_loss + self.gamma*(total_kld-C).abs()
                self.optim.zero_grad()
                beta_vae_loss.backward()
                self.optim.step()
                if self.viz_on and self.global_iter%self.gather_step == 0:
                    self.gather.insert(iter=self.global_iter,
                                       mu=mu.mean(0).data, var=logvar.exp().mean(0).data,
                                       recon_loss=recon_loss.data, total_kld=total_kld.data,
                                       dim_wise_kld=dim_wise_kld.data, mean_kld=mean_kld.data)
                if self.global_iter%self.display_step == 0:
                    # NOTE(review): `.data[0]` is pre-0.4 PyTorch scalar
                    # indexing; on modern versions this needs `.item()`.
                    pbar.write('[{}] recon_loss:{:.3f} total_kld:{:.3f} mean_kld:{:.3f}'.format(
                        self.global_iter, recon_loss.data[0], total_kld.data[0], mean_kld.data[0]))
                    var = logvar.exp().mean(0).data
                    var_str = ''
                    for j, var_j in enumerate(var):
                        var_str += 'var{}:{:.4f} '.format(j+1, var_j)
                    pbar.write(var_str)
                    if self.objective == 'B':
                        pbar.write('C:{:.3f}'.format(C.data[0]))
                    if self.viz_on:
                        # Stash one real batch and one reconstruction batch
                        # for viz_reconstruction, then plot and reset.
                        self.gather.insert(images=x.data)
                        self.gather.insert(images=F.sigmoid(x_recon).data)
                        self.viz_reconstruction()
                        self.viz_lines()
                        self.gather.flush()
                    if self.viz_on or self.save_output:
                        self.viz_traverse()
                if self.global_iter%self.save_step == 0:
                    # Rolling checkpoint for resumption.
                    self.save_checkpoint('last')
                    pbar.write('Saved checkpoint(iter:{})'.format(self.global_iter))
                if self.global_iter%50000 == 0:
                    # Permanent snapshot every 50k iterations.
                    self.save_checkpoint(str(self.global_iter))
                if self.global_iter >= self.max_iter:
                    out = True
                    break
        pbar.write("[Training Finished]")
        pbar.close()

    def viz_reconstruction(self):
        """Show a grid of inputs next to their reconstructions in visdom."""
        self.net_mode(train=False)
        # gather.data['images'][0] holds inputs, [1] reconstructions
        # (inserted in that order by train()).
        x = self.gather.data['images'][0][:100]
        x = make_grid(x, normalize=True)
        x_recon = self.gather.data['images'][1][:100]
        x_recon = make_grid(x_recon, normalize=True)
        images = torch.stack([x, x_recon], dim=0).cpu()
        self.viz.images(images, env=self.viz_name+'_reconstruction',
                        opts=dict(title=str(self.global_iter)), nrow=10)
        self.net_mode(train=True)

    def viz_lines(self):
        """Plot recon-loss / KL / posterior-mean / posterior-variance curves."""
        self.net_mode(train=False)
        recon_losses = torch.stack(self.gather.data['recon_loss']).cpu()
        mus = torch.stack(self.gather.data['mu']).cpu()
        vars = torch.stack(self.gather.data['var']).cpu()
        dim_wise_klds = torch.stack(self.gather.data['dim_wise_kld'])
        mean_klds = torch.stack(self.gather.data['mean_kld'])
        total_klds = torch.stack(self.gather.data['total_kld'])
        klds = torch.cat([dim_wise_klds, mean_klds, total_klds], 1).cpu()
        iters = torch.Tensor(self.gather.data['iter'])
        # One legend entry per latent dimension, plus mean/total KL.
        legend = []
        for z_j in range(self.z_dim):
            legend.append('z_{}'.format(z_j))
        legend.append('mean')
        legend.append('total')
        # For each window: create it on the first call, append afterwards.
        if self.win_recon is None:
            self.win_recon = self.viz.line(
                X=iters,
                Y=recon_losses,
                env=self.viz_name+'_lines',
                opts=dict(
                    width=400,
                    height=400,
                    xlabel='iteration',
                    title='reconsturction loss',))
        else:
            self.win_recon = self.viz.line(
                X=iters,
                Y=recon_losses,
                env=self.viz_name+'_lines',
                win=self.win_recon,
                update='append',
                opts=dict(
                    width=400,
                    height=400,
                    xlabel='iteration',
                    title='reconsturction loss',))
        if self.win_kld is None:
            self.win_kld = self.viz.line(
                X=iters,
                Y=klds,
                env=self.viz_name+'_lines',
                opts=dict(
                    width=400,
                    height=400,
                    legend=legend,
                    xlabel='iteration',
                    title='kl divergence',))
        else:
            self.win_kld = self.viz.line(
                X=iters,
                Y=klds,
                env=self.viz_name+'_lines',
                win=self.win_kld,
                update='append',
                opts=dict(
                    width=400,
                    height=400,
                    legend=legend,
                    xlabel='iteration',
                    title='kl divergence',))
        if self.win_mu is None:
            self.win_mu = self.viz.line(
                X=iters,
                Y=mus,
                env=self.viz_name+'_lines',
                opts=dict(
                    width=400,
                    height=400,
                    legend=legend[:self.z_dim],
                    xlabel='iteration',
                    title='posterior mean',))
        else:
            # NOTE(review): this appends `vars` to the posterior-mean
            # window — looks like a copy-paste slip; probably should be
            # Y=mus. Confirm before relying on this plot.
            self.win_mu = self.viz.line(
                X=iters,
                Y=vars,
                env=self.viz_name+'_lines',
                win=self.win_mu,
                update='append',
                opts=dict(
                    width=400,
                    height=400,
                    legend=legend[:self.z_dim],
                    xlabel='iteration',
                    title='posterior mean',))
        if self.win_var is None:
            self.win_var = self.viz.line(
                X=iters,
                Y=vars,
                env=self.viz_name+'_lines',
                opts=dict(
                    width=400,
                    height=400,
                    legend=legend[:self.z_dim],
                    xlabel='iteration',
                    title='posterior variance',))
        else:
            self.win_var = self.viz.line(
                X=iters,
                Y=vars,
                env=self.viz_name+'_lines',
                win=self.win_var,
                update='append',
                opts=dict(
                    width=400,
                    height=400,
                    legend=legend[:self.z_dim],
                    xlabel='iteration',
                    title='posterior variance',))
        self.net_mode(train=True)

    def viz_traverse(self, limit=3, inter=2/3, loc=-1):
        """Render latent traversals: vary one z dimension at a time over
        [-limit, limit] in steps of `inter`, decode, and plot/save.

        Args:
            limit: traversal range per latent dimension.
            inter: step size between traversal values.
            loc: if >= 0, traverse only that latent dimension.
        """
        self.net_mode(train=False)
        import random
        decoder = self.net.decoder
        encoder = self.net.encoder
        interpolation = torch.arange(-limit, limit+0.1, inter)
        n_dsets = len(self.data_loader.dataset)
        rand_idx = random.randint(1, n_dsets-1)
        random_img = self.data_loader.dataset.__getitem__(rand_idx)
        # NOTE(review): volatile=True is the pre-0.4 no-grad flag; modern
        # PyTorch replaces it with torch.no_grad().
        random_img = Variable(cuda(random_img, self.use_cuda), volatile=True).unsqueeze(0)
        random_img_z = encoder(random_img)[:, :self.z_dim]
        random_z = Variable(cuda(torch.rand(1, self.z_dim), self.use_cuda), volatile=True)
        # NOTE(review): case-sensitive compare, unlike the .lower() checks in
        # __init__ — 'dsprites' must be passed exactly lower-case here.
        if self.dataset == 'dsprites':
            # Fixed dataset indices chosen to show one example of each shape.
            fixed_idx1 = 87040 # square
            fixed_idx2 = 332800 # ellipse
            fixed_idx3 = 578560 # heart
            fixed_img1 = self.data_loader.dataset.__getitem__(fixed_idx1)
            fixed_img1 = Variable(cuda(fixed_img1, self.use_cuda), volatile=True).unsqueeze(0)
            fixed_img_z1 = encoder(fixed_img1)[:, :self.z_dim]
            fixed_img2 = self.data_loader.dataset.__getitem__(fixed_idx2)
            fixed_img2 = Variable(cuda(fixed_img2, self.use_cuda), volatile=True).unsqueeze(0)
            fixed_img_z2 = encoder(fixed_img2)[:, :self.z_dim]
            fixed_img3 = self.data_loader.dataset.__getitem__(fixed_idx3)
            fixed_img3 = Variable(cuda(fixed_img3, self.use_cuda), volatile=True).unsqueeze(0)
            fixed_img_z3 = encoder(fixed_img3)[:, :self.z_dim]
            Z = {'fixed_square':fixed_img_z1, 'fixed_ellipse':fixed_img_z2,
                 'fixed_heart':fixed_img_z3, 'random_img':random_img_z}
        else:
            fixed_idx = 0
            fixed_img = self.data_loader.dataset.__getitem__(fixed_idx)
            fixed_img = Variable(cuda(fixed_img, self.use_cuda), volatile=True).unsqueeze(0)
            fixed_img_z = encoder(fixed_img)[:, :self.z_dim]
            Z = {'fixed_img':fixed_img_z, 'random_img':random_img_z, 'random_z':random_z}
        gifs = []
        for key in Z.keys():
            z_ori = Z[key]
            samples = []
            for row in range(self.z_dim):
                if loc != -1 and row != loc:
                    continue
                z = z_ori.clone()
                for val in interpolation:
                    # Overwrite one latent coordinate and decode.
                    z[:, row] = val
                    sample = F.sigmoid(decoder(z)).data
                    samples.append(sample)
                    gifs.append(sample)
            samples = torch.cat(samples, dim=0).cpu()
            title = '{}_latent_traversal(iter:{})'.format(key, self.global_iter)
            if self.viz_on:
                self.viz.images(samples, env=self.viz_name+'_traverse',
                                opts=dict(title=title), nrow=len(interpolation))
        if self.save_output:
            output_dir = os.path.join(self.output_dir, str(self.global_iter))
            os.makedirs(output_dir, exist_ok=True)
            # Regroup the flat sample list as (source image, traversal step,
            # latent dim, C, H, W) so each saved frame shows all dims at one step.
            gifs = torch.cat(gifs)
            gifs = gifs.view(len(Z), self.z_dim, len(interpolation), self.nc, 64, 64).transpose(1, 2)
            for i, key in enumerate(Z.keys()):
                for j, val in enumerate(interpolation):
                    # NOTE(review): newer torchvision renamed the `filename`
                    # kwarg of save_image to `fp` — confirm the installed version.
                    save_image(tensor=gifs[i][j].cpu(),
                               filename=os.path.join(output_dir, '{}_{}.jpg'.format(key, j)),
                               nrow=self.z_dim, pad_value=1)
                grid2gif(os.path.join(output_dir, key+'*.jpg'),
                         os.path.join(output_dir, key+'.gif'), delay=10)
        self.net_mode(train=True)

    def net_mode(self, train):
        """Switch the network between train() and eval() modes."""
        if not isinstance(train, bool):
            # NOTE(review): `raise('...')` raises TypeError (a str is not an
            # exception class); should be e.g. `raise ValueError(...)`.
            raise('Only bool type is supported. True or False')
        if train:
            self.net.train()
        else:
            self.net.eval()

    def save_checkpoint(self, filename, silent=True):
        """Serialize model/optimizer/visdom-window state to ckpt_dir/filename."""
        model_states = {'net':self.net.state_dict(),}
        optim_states = {'optim':self.optim.state_dict(),}
        win_states = {'recon':self.win_recon,
                      'kld':self.win_kld,
                      'mu':self.win_mu,
                      'var':self.win_var,}
        states = {'iter':self.global_iter,
                  'win_states':win_states,
                  'model_states':model_states,
                  'optim_states':optim_states}
        file_path = os.path.join(self.ckpt_dir, filename)
        with open(file_path, mode='wb+') as f:
            torch.save(states, f)
        if not silent:
            print("=> saved checkpoint '{}' (iter {})".format(file_path, self.global_iter))

    def load_checkpoint(self, filename):
        """Restore state saved by save_checkpoint; no-op with a notice if missing."""
        file_path = os.path.join(self.ckpt_dir, filename)
        if os.path.isfile(file_path):
            checkpoint = torch.load(file_path)
            self.global_iter = checkpoint['iter']
            self.win_recon = checkpoint['win_states']['recon']
            self.win_kld = checkpoint['win_states']['kld']
            self.win_var = checkpoint['win_states']['var']
            self.win_mu = checkpoint['win_states']['mu']
            self.net.load_state_dict(checkpoint['model_states']['net'])
            self.optim.load_state_dict(checkpoint['optim_states']['optim'])
            print("=> loaded checkpoint '{} (iter {})'".format(file_path, self.global_iter))
        else:
            print("=> no checkpoint found at '{}'".format(file_path))
utils
提供程序运行所需的通用辅助函数:张量的GPU迁移(cuda)、命令行布尔参数解析(str2bool)、类似np.where的张量按条件选择(where),以及调用ImageMagick把遍历图片合成GIF(grid2gif)。
import argparse
import subprocess
import torch
import torch.nn as nn
from torch.autograd import Variable
def cuda(tensor, uses_cuda):
    """Move *tensor* (or module) to the GPU when *uses_cuda* is truthy,
    otherwise return it unchanged."""
    if uses_cuda:
        return tensor.cuda()
    return tensor
def str2bool(v):
    """argparse type converter mapping common yes/no spellings to bool.

    codes from : https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse

    Raises:
        argparse.ArgumentTypeError: for any unrecognized spelling.
    """
    truthy = ('yes', 'true', 't', 'y', '1')
    falsy = ('no', 'false', 'f', 'n', '0')
    value = v.lower()
    if value in truthy:
        return True
    if value in falsy:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
def where(cond, x, y):
    """Do same operation as np.where
    code from:
    https://discuss.pytorch.org/t/how-can-i-do-the-operation-the-same-as-np-where/1329/8
    """
    # Blend: mask picks x where cond is nonzero, y elsewhere.
    mask = cond.float()
    return mask * x + (1 - mask) * y
def grid2gif(image_str, output_gif, delay=100):
    """Make GIF from images.
    code from:
    https://stackoverflow.com/questions/753190/programmatically-generate-video-or-animated-gif-in-python/34555939#34555939
    """
    # shell=True is required so the shell expands the glob in image_str
    # (e.g. 'fixed_img*.jpg'); image_str comes from internal paths, not
    # untrusted input.
    command = 'convert -delay {} -loop 0 {} {}'.format(delay, image_str, output_gif)
    subprocess.call(command, shell=True)
第一次正式发博文,不喜勿喷,共同进步,欢迎交流!
祝读者,程运昌盛,少些bug!