notations
Encoder: approximate posterior q(z | x, y), mapping an image x and its one-hot label y to (mu, sigma)
Decoder: likelihood p(x | z, y), reconstructing an image from a latent sample z and the label y
PriorEncoder: conditional prior p(z | y), mapping the label y alone to (prior_mu, prior_sigma)
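The concrete definitions of these three modules are elided in the listing below (Encoder(...), Decoder(...), PriorEncoder(...)). As a point of reference, here is a minimal sketch of the interfaces they must satisfy, assuming plain MLPs over flattened 784-dim MNIST images, 10-dim one-hot labels, and a latent size of 16; the conv_encoder flag in the config suggests the real encoder may be convolutional, so everything in this sketch is illustrative rather than the original architecture.

import torch
import torch.nn as nn

class Encoder(nn.Module):
    """Approximate posterior q(z | x, y): flattened image + one-hot label -> (mu, sigma)."""
    def __init__(self, x_dim=784, y_dim=10, h_dim=256, z_dim=16):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(x_dim + y_dim, h_dim), nn.ReLU())
        self.mu = nn.Linear(h_dim, z_dim)
        self.log_sigma = nn.Linear(h_dim, z_dim)

    def forward(self, x, y):
        h = self.net(torch.cat([x, y], dim=-1))
        return self.mu(h), torch.exp(self.log_sigma(h))  # exp keeps sigma positive

class PriorEncoder(nn.Module):
    """Conditional prior p(z | y): one-hot label -> (prior_mu, prior_sigma)."""
    def __init__(self, y_dim=10, h_dim=64, z_dim=16):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(y_dim, h_dim), nn.ReLU())
        self.mu = nn.Linear(h_dim, z_dim)
        self.log_sigma = nn.Linear(h_dim, z_dim)

    def forward(self, y):
        h = self.net(y)
        return self.mu(h), torch.exp(self.log_sigma(h))

class Decoder(nn.Module):
    """Likelihood p(x | z, y): latent sample + one-hot label -> flattened image."""
    def __init__(self, z_dim=16, y_dim=10, h_dim=256, x_dim=784):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(z_dim + y_dim, h_dim), nn.ReLU(),
            nn.Linear(h_dim, x_dim), nn.Sigmoid(),
        )

    def forward(self, z, y):
        return self.net(torch.cat([z, y], dim=-1))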
model structure
class CVAE(nn.Module):
    def __init__(self, config):
        super(CVAE, self).__init__()
        self.encoder = Encoder(...)
        self.decoder = Decoder(...)
        self.priorEncoder = PriorEncoder(...)

    def forward(self, x, y):
        x = x.reshape((-1, 784))  # flatten MNIST images
        mu, sigma = self.encoder(x, y)
        prior_mu, prior_sigma = self.priorEncoder(y)
        # reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I)
        z = torch.randn_like(mu)
        z = z * sigma + mu
        reconstructed_x = self.decoder(z, y)
        reconstructed_x = reconstructed_x.reshape((-1, 28, 28))
        return reconstructed_x, mu, sigma, prior_mu, prior_sigma

    def infer(self, y):
        # at inference time, sample from the conditional prior p(z | y) instead of the posterior
        prior_mu, prior_sigma = self.priorEncoder(y)
        z = torch.randn_like(prior_mu)
        z = z * prior_sigma + prior_mu
        reconstructed_x = self.decoder(z, y)
        return reconstructed_x
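The forward pass returns exactly what the standard conditional-VAE objective needs: the conditional ELBO, with the encoder output defining q_phi(z | x, y) = N(mu, diag(sigma^2)) and the prior encoder output defining p_theta(z | y) = N(prior_mu, diag(prior_sigma^2)):

$$
\log p_\theta(x \mid y) \;\ge\; \mathbb{E}_{q_\phi(z \mid x, y)}\big[\log p_\theta(x \mid z, y)\big] \;-\; \mathrm{KL}\big(q_\phi(z \mid x, y)\,\|\,p_\theta(z \mid y)\big)
$$

In this implementation the reconstruction term is approximated by the MSE of a single reparameterized sample, and the KL term is down-weighted, as computed in Loss below.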
#
class Loss(nn.Module):
    def __init__(self):
        super(Loss, self).__init__()
        self.loss_fn = nn.MSELoss(reduction='mean')
        self.kld_loss_weight = 1e-5

    def forward(self, x, reconstructed_x, mu, sigma, prior_mu, prior_sigma):
        mse_loss = self.loss_fn(reconstructed_x, x)
        # closed-form KL( N(mu, sigma^2) || N(prior_mu, prior_sigma^2) ), per latent dimension
        kld_loss = torch.log(prior_sigma / sigma) + (sigma**2 + (mu - prior_mu)**2) / (2 * prior_sigma**2) - 0.5
        # sum over latent dimensions, average over the batch
        kld_loss = torch.sum(kld_loss) / x.shape[0]
        loss = mse_loss + self.kld_loss_weight * kld_loss
        return loss
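The kld_loss expression is the closed-form KL divergence between two univariate Gaussians, evaluated independently for each latent dimension (writing mu_p, sigma_p for prior_mu, prior_sigma):

$$
\mathrm{KL}\big(\mathcal{N}(\mu, \sigma^2)\,\|\,\mathcal{N}(\mu_p, \sigma_p^2)\big) \;=\; \log\frac{\sigma_p}{\sigma} \;+\; \frac{\sigma^2 + (\mu - \mu_p)^2}{2\sigma_p^2} \;-\; \frac{1}{2}
$$

The per-dimension values are summed and averaged over the batch, then scaled by kld_loss_weight before being added to the reconstruction loss.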
#
def train(model, criterion, optimizer, data_loader, config):
    train_task_time_str = time_str()
    for epoch in range(config.num_epoch):
        loss_seq = []
        for step, (x, y) in tqdm(enumerate(data_loader)):
            # -------------------- data --------------------
            x = x.to(device)
            y = y.to(device)
            # -------------------- forward --------------------
            reconstructed_x, mu, sigma, prior_mu, prior_sigma = model(x, y)
            loss = criterion(x, reconstructed_x, mu, sigma, prior_mu, prior_sigma)
            # -------------------- log --------------------
            loss_seq.append(loss.item())
            # -------------------- backward --------------------
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # -------------------- end --------------------
        logging.info(f'epoch {epoch:^5d} loss {sum(loss_seq[-config.batch_size:]) / config.batch_size:.5f}')
        with torch.no_grad():
            # -------------------- file --------------------
            path = f'{config.save_fig_path}/{train_task_time_str}'  # type(model).__name__
            if not os.path.exists(path):
                os.makedirs(path)
            path += f'/epoch{epoch:04d}.png'
            # -------------------- figure --------------------
            plt.close()
            fig, axs = plt.subplots(nrows=1, ncols=10, figsize=(10, 2), dpi=512)
            fig.suptitle(f'epoch {epoch} loss {sum(loss_seq[-config.batch_size:]) / config.batch_size:.5f}')
            # -------------------- infer --------------------
            # generate one sample per class by conditioning on each one-hot label
            y = torch.Tensor(list(range(config.num_class)))
            y = y.to(dtype=torch.int64)
            y = nn.functional.one_hot(y, num_classes=config.num_class)
            y = y.to(dtype=torch.float)
            y = y.to(device)
            x = model.infer(y)
            x = x.cpu()
            x = x.numpy()
            # rescale to [0, 255] for plotting
            x -= x.min()
            x /= x.max()
            x *= 255
            x = x.astype(np.uint8)
            # -------------------- plot --------------------
            for idx, ax, arr in zip(range(config.num_class), axs, x):
                ax.set_title(str(idx))
                ax.axis('off')
                ax.imshow(arr.reshape((28, 28)), cmap='BuGn')
            # -------------------- save --------------------
            # plt.show()
            plt.savefig(path)
            # -------------------- end --------------------
#
dynamics
kld_loss_weight = 1e-5
{'batch_size': 25,
 'conv_encoder': True,
 'learning_rate': 1e-05,
 'num_class': 10,
 'num_epoch': 16,
 'save_fig_path': './figs',
 'use_cuda': True}
This set of hyperparameters converges fairly quickly to good model parameters. In experiments, a larger batch_size converged to worse model parameters, and a smaller learning_rate made convergence very slow.
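For completeness, a minimal driver that wires the pieces above together with this config might look like the sketch below. The torchvision MNIST pipeline, the one-hot target transform, the channel squeeze, and the Adam optimizer are assumptions, not part of the original; device is defined here, while time_str and the logging setup used inside train are assumed to be defined elsewhere, as in the original.

import torch
import torch.nn as nn
from types import SimpleNamespace
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

config = SimpleNamespace(batch_size=25, conv_encoder=True, learning_rate=1e-5,
                         num_class=10, num_epoch=16, save_fig_path='./figs',
                         use_cuda=True)
device = torch.device('cuda' if config.use_cuda and torch.cuda.is_available() else 'cpu')

# targets arrive as int class indices; the model and the inference code expect float one-hot vectors
def one_hot_label(y):
    return nn.functional.one_hot(torch.tensor(y), num_classes=config.num_class).to(torch.float)

dataset = datasets.MNIST(
    './data', train=True, download=True,
    # squeeze the channel dim so each x has shape (28, 28), matching the decoder output (-1, 28, 28)
    transform=transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda t: t.squeeze(0))]),
    target_transform=one_hot_label)
data_loader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True)

model = CVAE(config).to(device)
criterion = Loss()
optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
train(model, criterion, optimizer, data_loader, config)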