In the second, global part of training, we align the newly trained band with the already encoded knowledge.
The simplest way to avoid interference between bands is to partition the latent space of the VAE and place the representations of new data in a separate region of that space.
However, such an approach limits information sharing across separate tasks and hinders forward and backward knowledge transfer (这种方法限制了不同任务之间的信息共享，并阻碍了向前和向后的知识转移). Therefore, in Multiband VAE we propose to align the different latent spaces through an additional neural network that we call the translator. The translator maps the individual latent spaces, which are conditioned on the task ID, into a common global space where examples are stored independently of their source task, as shown in Fig. 2.
潜在空间划分(Partitioning the Latent Space):
将新数据表示放置在潜在空间的独立区域(Placing New Data Representations in a Separate Region of the Latent Space):
import torch
import torch.nn as nn
import torch.nn.functional as F
class VAE(nn.Module):
    """Fully connected variational autoencoder with one hidden layer.

    The encoder maps a flattened input to the mean and log-variance of a
    diagonal Gaussian over the latent space. The decoder maps a latent sample
    back to input space through a sigmoid, so reconstructions lie in (0, 1).

    Args:
        input_dim: Number of features in one flattened input example.
        hidden_dim: Width of the hidden layer in both encoder and decoder.
        latent_dim: Dimensionality of the latent space.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        # Encoder
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, latent_dim)  # mean
        self.fc22 = nn.Linear(hidden_dim, latent_dim)  # log variance
        # Decoder
        self.fc3 = nn.Linear(latent_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        """Return ``(mu, logvar)`` of the approximate posterior for ``x``."""
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``.

        Writing the sample this way keeps it differentiable with respect to
        ``mu`` and ``logvar`` (the reparameterization trick).
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent codes ``z`` back to input space, squashed into (0, 1)."""
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        """Encode, sample, and decode ``x``.

        ``x`` is flattened to ``(-1, input_dim)`` first, so image-shaped
        batches are accepted as long as each example has ``input_dim``
        elements.

        Returns:
            Tuple ``(reconstruction, mu, logvar)``.
        """
        # Use the configured input width instead of a hard-coded 784 so the
        # model works for any input_dim; reshape also handles non-contiguous x.
        mu, logvar = self.encode(x.reshape(-1, self.fc1.in_features))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar
class ConditionalVAE(VAE):
    """VAE whose encoder is conditioned on an integer label (e.g. a task id).

    The label is embedded to ``hidden_dim`` and added to the encoder's first
    hidden pre-activation. The decoder is inherited unchanged from ``VAE`` and
    does not see the label.

    Args:
        input_dim: Number of features in one flattened input example.
        hidden_dim: Width of the hidden layer and of the label embedding.
        latent_dim: Dimensionality of the latent space.
        num_classes: Number of rows in the label embedding; labels must lie
            in ``[0, num_classes)``.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim, num_classes):
        super().__init__(input_dim, hidden_dim, latent_dim)
        self.class_emb = nn.Embedding(num_classes, hidden_dim)

    def encode(self, x, y):
        """Return ``(mu, logvar)`` for inputs ``x`` conditioned on labels ``y``.

        NOTE(review): ``y`` is presumably an integer tensor of shape
        ``(batch,)`` so that its embedding matches ``fc1(x)`` — confirm.
        """
        h1 = F.relu(self.fc1(x) + self.class_emb(y))
        return self.fc21(h1), self.fc22(h1)

    def forward(self, x, y):
        """Encode ``x`` conditioned on ``y``, sample a latent, and decode it.

        The inherited ``VAE.forward`` calls ``encode(x)`` without labels,
        which cannot work with this class's two-argument ``encode``. This
        override threads the labels through.

        Returns:
            Tuple ``(reconstruction, mu, logvar)``.
        """
        mu, logvar = self.encode(x.reshape(-1, self.fc1.in_features), y)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar
def train(model, data_loader, optimizer, epoch, device, loss_fn=None):
    """Run one training epoch of ``model`` over ``data_loader``.

    Each batch ``(data, labels)`` is moved to ``device`` and fed to
    ``model(data, labels)``. The result is scored with the loss function, and
    one optimizer step is taken per batch.

    Args:
        model: Module called as ``model(data, labels)`` that returns
            ``(reconstruction, mu, logvar)``.
        data_loader: Iterable of ``(data, labels)`` batches with a ``dataset``
            attribute whose length is the number of examples.
        optimizer: Optimizer over ``model``'s parameters.
        epoch: Epoch number; used only in the log line.
        device: Device the batches are moved to.
        loss_fn: Optional callable ``(recon, data, mu, logvar) -> loss``.
            When omitted, the module-level ``loss_function`` is used (it is
            not defined in this block; NameError if it does not exist).

    Returns:
        Accumulated loss divided by the number of examples in the dataset,
        or 0.0 for an empty dataset.
    """
    criterion = loss_fn if loss_fn is not None else loss_function
    model.train()  # enable training-mode behavior (dropout, batch norm, ...)
    train_loss = 0.0
    for data, labels in data_loader:
        data = data.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()  # gradients accumulate across batches otherwise
        recon_batch, mu, logvar = model(data, labels)
        loss = criterion(recon_batch, data, mu, logvar)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    num_examples = len(data_loader.dataset)
    # Dividing by the dataset size (not the batch count) presumes the loss is
    # summed over each batch — TODO confirm against loss_function.
    avg_loss = train_loss / num_examples if num_examples else 0.0
    print('Epoch: {} Average loss: {:.4f}'.format(epoch, avg_loss))
    return avg_loss