2022.2.19 !!!This post is still a work in progress!!!
For rough reference only
Paper: https://arxiv.org/abs/1904.00370
Code: https://github.com/sinhasam/vaal
Published at: ICCV'19
VAAL consists of a VAE, a Discriminator, and the training procedure that ties them together. We introduce each part in turn below.
The VAE lives in the VAE class in model.py:
class VAE(nn.Module):
    """Encoder-Decoder architecture for both WAE-MMD and WAE-GAN."""
    def __init__(self, z_dim=32, nc=3):
        super(VAE, self).__init__()
        self.z_dim = z_dim
        self.nc = nc
        # shape comments assume 32x32 inputs (e.g. CIFAR-10)
        self.encoder = nn.Sequential(
            nn.Conv2d(nc, 128, 4, 2, 1, bias=False),    # B, 128, 16, 16
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.Conv2d(128, 256, 4, 2, 1, bias=False),   # B, 256, 8, 8
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.Conv2d(256, 512, 4, 2, 1, bias=False),   # B, 512, 4, 4
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.Conv2d(512, 1024, 4, 2, 1, bias=False),  # B, 1024, 2, 2
            nn.BatchNorm2d(1024),
            nn.ReLU(True),
            View((-1, 1024*2*2)),                       # B, 1024*2*2
        )
        self.fc_mu = nn.Linear(1024*2*2, z_dim)      # B, z_dim
        self.fc_logvar = nn.Linear(1024*2*2, z_dim)  # B, z_dim
        self.decoder = nn.Sequential(
            nn.Linear(z_dim, 1024*4*4),                          # B, 1024*4*4
            View((-1, 1024, 4, 4)),                              # B, 1024, 4, 4
            nn.ConvTranspose2d(1024, 512, 4, 2, 1, bias=False),  # B, 512, 8, 8
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),   # B, 256, 16, 16
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),   # B, 128, 32, 32
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, nc, 1),                      # B, nc, 32, 32
        )
        self.weight_init()
    def weight_init(self):
        for block in self._modules:
            try:
                for m in self._modules[block]:
                    kaiming_init(m)
            except TypeError:
                # non-iterable submodules (fc_mu, fc_logvar) are initialized directly
                kaiming_init(self._modules[block])
    def forward(self, x):
        z = self._encode(x)
        mu, logvar = self.fc_mu(z), self.fc_logvar(z)
        z = self.reparameterize(mu, logvar)
        x_recon = self._decode(z)
        return x_recon, z, mu, logvar

    def reparameterize(self, mu, logvar):
        stds = (0.5 * logvar).exp()
        epsilon = torch.randn(*mu.size())
        if mu.is_cuda:
            stds, epsilon = stds.cuda(), epsilon.cuda()
        latents = epsilon * stds + mu
        return latents

    def _encode(self, x):
        return self.encoder(x)

    def _decode(self, z):
        return self.decoder(z)
Note that the VAE's forward pass has four return values (a quick shape check follows the list):
x_recon: the reconstruction produced by the full encode-decode pipeline.
z: the latent vector sampled from (mu, logvar) via the reparameterization trick.
mu: the mean of the latent Gaussian predicted by the encoder.
logvar: the log-variance of the latent Gaussian predicted by the encoder.
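As a sanity check on the shapes discussed above, a minimal sketch (assuming model.py, with its View and kaiming_init helpers, is importable):

import torch
from model import VAE  # the View and kaiming_init helpers also live in model.py

vae = VAE(z_dim=32, nc=3)
x = torch.randn(8, 3, 32, 32)            # a batch of CIFAR-sized images
x_recon, z, mu, logvar = vae(x)
print(x_recon.shape)                     # torch.Size([8, 3, 32, 32])
print(z.shape, mu.shape, logvar.shape)   # each torch.Size([8, 32])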
The Discriminator lives in the Discriminator class in model.py. As with most discriminators, three fully connected layers are enough to do the job.
class Discriminator(nn.Module):
    """Adversary architecture(Discriminator) for WAE-GAN."""
    def __init__(self, z_dim=10):
        super(Discriminator, self).__init__()
        self.z_dim = z_dim
        self.net = nn.Sequential(
            nn.Linear(z_dim, 512),
            nn.ReLU(True),
            nn.Linear(512, 512),
            nn.ReLU(True),
            nn.Linear(512, 1),
            nn.Sigmoid()
        )
        self.weight_init()

    def weight_init(self):
        # same Kaiming initialization scheme as the VAE
        for block in self._modules:
            for m in self._modules[block]:
                kaiming_init(m)

    def forward(self, z):
        return self.net(z)
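Its output is a per-sample score in (0, 1), interpreted as the probability that the sample comes from the labeled pool. A minimal usage sketch:

import torch
from model import Discriminator

d = Discriminator(z_dim=32)
mu = torch.randn(8, 32)  # latent means from the VAE encoder
p = d(mu)                # shape [8, 1], each entry in (0, 1)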
The sampler lives in the AdversarySampler class in sampler.py. Its procedure is as follows. First, each batch of data (never mind for now where it comes from) is encoded by the VAE, and the resulting mean vectors are fed to the discriminator:
all_preds = []
all_indices = []

for images, _, indices in data:
    # ...
    with torch.no_grad():
        _, _, mu, _ = vae(images)
        preds = discriminator(mu)

    preds = preds.cpu().data
    all_preds.extend(preds)
    all_indices.extend(indices)
Here all_preds stores all of the discriminator's predictions, and all_indices the dataset indices those predictions correspond to.
The higher the discriminator's score, the more likely the sample is already labeled; a low score means the sample is likely unlabeled, and those are exactly the samples worth annotating. So we sort all the scores and pick the samples with the lowest scores (how many depends on the budget):
all_preds = torch.stack(all_preds)
all_preds = all_preds.view(-1)
all_preds *= -1
_, querry_indices = torch.topk(all_preds, int(self.budget))
Negating the scores lets torch.topk return the positions of the smallest values (equivalently, torch.topk(..., largest=False) would select them directly); these positions are then mapped back to the original dataset indices and returned:
querry_pool_indices = np.asarray(all_indices)[querry_indices]
The complete code:
class AdversarySampler:
    def __init__(self, budget):
        self.budget = budget

    def sample(self, vae, discriminator, data, cuda):
        all_preds = []
        all_indices = []

        for images, _, indices in data:
            if cuda:
                images = images.cuda()

            with torch.no_grad():
                _, _, mu, _ = vae(images)
                preds = discriminator(mu)

            preds = preds.cpu().data
            all_preds.extend(preds)
            all_indices.extend(indices)

        all_preds = torch.stack(all_preds)
        all_preds = all_preds.view(-1)
        # need to multiply by -1 to be able to use torch.topk
        all_preds *= -1

        # select the points which the discriminator thinks are the most likely to be unlabeled
        _, querry_indices = torch.topk(all_preds, int(self.budget))
        querry_pool_indices = np.asarray(all_indices)[querry_indices]

        return querry_pool_indices
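A usage sketch (the budget value is illustrative; in main.py it comes from args.budget):

sampler = AdversarySampler(budget=2500)
# vae / discriminator: the models trained this round;
# unlabeled_dataloader iterates over the currently unlabeled pool
new_indices = sampler.sample(vae, discriminator, unlabeled_dataloader, cuda=True)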
The Solver plays the role of a main routine: it defines the training, testing, and validation procedures. Before walking through them, look at how data is read. Instead of the usual pattern of building a dataloader and iterating over it with a for loop, the code wraps the dataloader in an infinite generator:
def read_data(self, dataloader, labels=True):
    if labels:
        while True:
            for img, label, _ in dataloader:
                yield img, label
    else:
        while True:
            for img, _, _ in dataloader:
                yield img
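Because the while True restarts the dataloader whenever it is exhausted, next() can be called indefinitely without ever raising StopIteration; epoch boundaries simply disappear. For example:

labeled_data = self.read_data(querry_dataloader)
for _ in range(100000):                # far more steps than one epoch
    imgs, labels = next(labeled_data)  # silently wraps around the dataset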
Now for the training. The whole training loop can be simplified to:

for iter_count in range(max_iterations):
    # fetch a batch of data
    # update the task model
    # update the VAE
    # update the discriminator
First, some basic configuration:
self.args.train_iterations = (self.args.num_images * self.args.train_epochs) // self.args.batch_size
lr_change = self.args.train_iterations // 4
labeled_data = self.read_data(querry_dataloader)
unlabeled_data = self.read_data(unlabeled_dataloader, labels=False)
optim_vae = optim.Adam(vae.parameters(), lr=5e-4)
optim_task_model = optim.SGD(task_model.parameters(), lr=0.01, weight_decay=5e-4, momentum=0.9)
optim_discriminator = optim.Adam(discriminator.parameters(), lr=5e-4)
vae.train()
discriminator.train()
task_model.train()
Note the interesting parameter train_iterations: it converts the maximum number of training epochs into a maximum number of iterations, by multiplying the number of training images by the number of epochs and dividing by batch_size. This conversion is needed because, as we will see, the outer training loop counts iterations rather than epochs.
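For instance, with the hypothetical values num_images = 50000, train_epochs = 100, and batch_size = 128:

num_images, train_epochs, batch_size = 50000, 100, 128        # hypothetical values
train_iterations = (num_images * train_epochs) // batch_size  # 39062
lr_change = train_iterations // 4                             # LR divided by 10 every 9765 iterations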
Then the iteration begins:
for iter_count in range(self.args.train_iterations):
First come the learning-rate schedule and the data reads:
if iter_count != 0 and iter_count % lr_change == 0:
    for param in optim_task_model.param_groups:
        param['lr'] = param['lr'] / 10

labeled_imgs, labels = next(labeled_data)
unlabeled_imgs = next(unlabeled_data)
The training procedure proper has three steps per iteration: train the task model, train the VAE, and train the discriminator. We go through them one by one.
First, the task model. Here we simply feed the labeled data into the VGG and optimize:
preds = task_model(labeled_imgs)
task_loss = self.ce_loss(preds, labels)
optim_task_model.zero_grad()
task_loss.backward()
optim_task_model.step()
Next, the VAE:
for count in range(self.args.num_vae_steps):
    recon, z, mu, logvar = vae(labeled_imgs)
    unsup_loss = self.vae_loss(labeled_imgs, recon, mu, logvar, self.args.beta)

    unlab_recon, unlab_z, unlab_mu, unlab_logvar = vae(unlabeled_imgs)
    transductive_loss = self.vae_loss(unlabeled_imgs,
            unlab_recon, unlab_mu, unlab_logvar, self.args.beta)

    labeled_preds = discriminator(mu)
    unlabeled_preds = discriminator(unlab_mu)

    # the VAE tries to fool the discriminator: both labeled and
    # unlabeled latents are pushed toward the "labeled" target (1)
    lab_real_preds = torch.ones(labeled_imgs.size(0))
    unlab_real_preds = torch.ones(unlabeled_imgs.size(0))

    if self.args.cuda:
        lab_real_preds = lab_real_preds.cuda()
        unlab_real_preds = unlab_real_preds.cuda()

    dsc_loss = self.bce_loss(labeled_preds, lab_real_preds) + \
               self.bce_loss(unlabeled_preds, unlab_real_preds)

    total_vae_loss = unsup_loss + transductive_loss + self.args.adversary_param * dsc_loss

    optim_vae.zero_grad()
    total_vae_loss.backward()
    optim_vae.step()

    # sample new batch if needed to train the adversarial network
    if count < (self.args.num_vae_steps - 1):
        labeled_imgs, _ = next(labeled_data)
        unlabeled_imgs = next(unlabeled_data)

        if self.args.cuda:
            labeled_imgs = labeled_imgs.cuda()
            unlabeled_imgs = unlabeled_imgs.cuda()
            labels = labels.cuda()
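The vae_loss used above combines a pixel reconstruction term with a beta-weighted KL divergence. A minimal sketch of such a loss (the actual implementation lives in the Solver in solver.py; the exact reduction here is an assumption):

import torch
import torch.nn as nn

def vae_loss(x, recon, mu, logvar, beta):
    mse = nn.MSELoss()(recon, x)                                   # reconstruction term
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())  # KL(q(z|x) || N(0, I))
    return mse + beta * kld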
Finally, the discriminator. Note that, unlike in the VAE step, the targets now distinguish the two pools: labeled latents get a target of one and unlabeled latents a target of zero:
for count in range(self.args.num_adv_steps):
    with torch.no_grad():
        _, _, mu, _ = vae(labeled_imgs)
        _, _, unlab_mu, _ = vae(unlabeled_imgs)

    labeled_preds = discriminator(mu)
    unlabeled_preds = discriminator(unlab_mu)

    lab_real_preds = torch.ones(labeled_imgs.size(0))
    unlab_fake_preds = torch.zeros(unlabeled_imgs.size(0))

    if self.args.cuda:
        lab_real_preds = lab_real_preds.cuda()
        unlab_fake_preds = unlab_fake_preds.cuda()

    dsc_loss = self.bce_loss(labeled_preds, lab_real_preds) + \
               self.bce_loss(unlabeled_preds, unlab_fake_preds)

    optim_discriminator.zero_grad()
    dsc_loss.backward()
    optim_discriminator.step()

    # sample new batch if needed to train the adversarial network
    if count < (self.args.num_adv_steps - 1):
        labeled_imgs, _ = next(labeled_data)
        unlabeled_imgs = next(unlabeled_data)

        if self.args.cuda:
            labeled_imgs = labeled_imgs.cuda()
            unlabeled_imgs = unlabeled_imgs.cuda()
            labels = labels.cuda()
Finally, the driver script (main.py). It first sets the parameters according to the chosen dataset:
if args.dataset == 'cifar10':
    test_dataloader = data.DataLoader(
            datasets.CIFAR10(args.data_path, download=True, transform=cifar_transformer(), train=False),
            batch_size=args.batch_size, drop_last=False)
    train_dataset = CIFAR10(args.data_path)

    args.num_images = 50000
    args.num_val = 5000
    args.budget = 2500
    args.initial_budget = 5000
    args.num_classes = 10
elif args.dataset == 'cifar100':
    ...  # analogous settings for CIFAR-100
elif args.dataset == 'imagenet':
    ...  # analogous settings for ImageNet
else:
    raise NotImplementedError
Take cifar10 as the example. First, download and load the CIFAR test set into test_dataloader; likewise, build a training dataset, but without wrapping it in a dataloader just yet. Then the hyperparameters:
num_images: the number of training images. CIFAR-10's training set contains 50000 images, so this is 50000.
num_val: the number of validation images. Here 10% of the training set is held out for validation, hence 5000.
budget: the number of new samples queried for annotation in each active-learning round; 2500 here, i.e. 5% of the training set, matching the 0.05 steps in the splits list below.
initial_budget: the number of initially labeled samples. Of the 50000 training images, only 5000 (10%) start out labeled and usable for training; after also excluding the 5000 validation images, the remaining 40000 are unlabeled.
num_classes: (for image classification) the number of classes. CIFAR-10 has 10 classes, so this is 10.
Next, carve the validation set out of the training set. After this, all_indices is the training pool and val_indices the validation set. Furthermore, since the training images are not all labeled up front, we also draw the initial labeled set initial_indices from the training pool:
all_indices = set(np.arange(args.num_images))
val_indices = random.sample(all_indices, args.num_val)
all_indices = np.setdiff1d(list(all_indices), val_indices)
initial_indices = random.sample(list(all_indices), args.initial_budget)
The indices start out in a set rather than a list, which makes the random.sample draws and the set-difference operations below convenient. Next, the samplers. A SubsetRandomSampler forces a dataloader to draw only the "labeled" images from the training set, with no need to modify the dataset itself:
sampler = data.sampler.SubsetRandomSampler(initial_indices)
val_sampler = data.sampler.SubsetRandomSampler(val_indices)
Then the dataloaders for training and validation:
querry_dataloader = data.DataLoader(train_dataset, sampler=sampler,
batch_size=args.batch_size, drop_last=True)
val_dataloader = data.DataLoader(train_dataset, sampler=val_sampler,
batch_size=args.batch_size, drop_last=False)
Then the Solver is constructed:
solver = Solver(args, test_dataloader)
And the split schedule, i.e. the fraction of the training set that is labeled at each active-learning round:
splits = [0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4]
There is also current_indices, which holds the indices of the currently labeled samples (a list of indices, not a count):
current_indices = list(initial_indices)
Now the main active-learning loop begins:
for split in splits:
# ...
Each round starts by initializing a fresh task model (VGG16 in this work), VAE, and discriminator:
task_model = vgg.vgg16_bn(num_classes=args.num_classes)
vae = model.VAE(args.latent_dim)
discriminator = model.Discriminator(args.latent_dim)
Then identify the unlabeled data. Since current_indices holds the currently "labeled" indices, the unlabeled indices are simply the set difference between all indices and the labeled ones:
unlabeled_indices = np.setdiff1d(list(all_indices), current_indices)
With these, build the dataloader over the unlabeled pool:
unlabeled_sampler = data.sampler.SubsetRandomSampler(unlabeled_indices)
unlabeled_dataloader = data.DataLoader(train_dataset,
sampler=unlabeled_sampler, batch_size=args.batch_size, drop_last=False)
Now call the Solver's train to start training:
acc, vae, discriminator = solver.train(querry_dataloader,
val_dataloader,
task_model,
vae,
discriminator,
unlabeled_dataloader)
train returns three things: the accuracy achieved this round, plus the trained VAE and discriminator. Next comes the core step of active learning: selecting new samples from the unlabeled pool to be annotated:
sampled_indices = solver.sample_for_labeling(vae, discriminator, unlabeled_dataloader)
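sample_for_labeling is essentially a thin wrapper around the AdversarySampler shown earlier; a sketch of what it plausibly looks like (assuming the Solver stores the sampler as self.sampler):

def sample_for_labeling(self, vae, discriminator, unlabeled_dataloader):
    # delegate the query selection to the adversarial sampler
    querry_indices = self.sampler.sample(vae, discriminator,
                                         unlabeled_dataloader,
                                         self.args.cuda)
    return querry_indices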
The labeled pool for the next round is then the previous labeled data plus the newly sampled data:
current_indices = list(current_indices) + list(sampled_indices)
Finally, rebuild the sampler and the training dataloader around the enlarged labeled set:
sampler = data.sampler.SubsetRandomSampler(current_indices)
querry_dataloader = data.DataLoader(train_dataset, sampler=sampler,
batch_size=args.batch_size, drop_last=True)
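With the rebuilt querry_dataloader, the next pass of the split loop retrains a fresh task model, VAE, and discriminator on the enlarged labeled pool, and the cycle repeats until every split in splits has been consumed.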