Paper reading: CPS for semi-supervised image segmentation (CVPR 2021)
Reference: [CVPR 2021] CPS: Semi-Supervised Semantic Segmentation with Cross Pseudo Supervision - Zhihu (zhihu.com)
Paper: https://arxiv.org/abs/2106.01226v2
Code: https://github.com/charlesCXK/TorchSemiSeg
Figure 1. (a) The proposed method (CPS); (b) CPC; (c) Mean Teacher; (d) FixMatch.
The paper proposes a very simple yet effective algorithm for semi-supervised semantic segmentation: cross pseudo supervision (CPS). During training, two networks with the same architecture but different initializations are used, and a constraint encourages the two networks to produce similar outputs for the same sample. Concretely, the one-hot pseudo label produced by one network serves as the target for the other network's prediction, and this cross supervision is enforced with a cross entropy loss.
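A minimal sketch of this cross supervision (not the authors' code; logits1 and logits2 are placeholders for the two networks' outputs on the same image, with shape N x C x H x W):

import torch.nn.functional as F

def cps_loss(logits1, logits2):
    # Hard one-hot pseudo labels from each branch (argmax carries no gradient)
    pseudo1 = logits1.argmax(dim=1).detach()  # supervises the second branch
    pseudo2 = logits2.argmax(dim=1).detach()  # supervises the first branch
    # Each branch is trained with cross entropy against the other branch's pseudo labels
    return F.cross_entropy(logits1, pseudo2) + F.cross_entropy(logits2, pseudo1)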
Prior work on semi-supervised segmentation can be roughly divided into two lines: self-training and consistency learning. In general, self-training is an offline process, while consistency learning is performed online.
(1) Self-training
Self-training mainly consists of three steps: first, train a model on the labeled data; second, use this model to predict pseudo labels for the unlabeled data; third, retrain the model on the labeled data together with the pseudo-labeled data, as sketched below. Because the pseudo labels are generated between training rounds, the process is typically offline.
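A schematic outline of this loop (train and predict_labels are hypothetical helpers, only meant to illustrate the three steps):

# Schematic self-training loop; train() and predict_labels() are hypothetical helpers.
def self_training(model, labeled_set, unlabeled_set, rounds=1):
    train(model, labeled_set)                              # step 1: supervised training
    for _ in range(rounds):
        pseudo_set = predict_labels(model, unlabeled_set)  # step 2: offline pseudo labels
        train(model, labeled_set + pseudo_set)             # step 3: retrain on the union
    return model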
(2) Consistency learning
The core idea of consistency learning is to encourage the model to produce similar outputs for the same sample under different transformations, where "transformations" include Gaussian noise, random rotation, color changes, and so on.
Consistency learning rests on two assumptions: the smoothness assumption (samples that are close in input space should have similar outputs) and the cluster assumption (the decision boundary should lie in low-density regions).
Currently, consistency learning mainly follows three representative approaches: Mean Teacher, CPC, and PseudoSeg.
Mean Teacher is a model proposed in 2017. Given an input image X, two different Gaussian noises are added to obtain X1 and X2. X1 is fed into the network f(θ) to get the prediction P1; an EMA (exponential moving average) of f(θ) gives a second, teacher network, and feeding X2 into this EMA model yields another output P2. Finally, P2 is used as the target for P1, with an MSE loss as the constraint.
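A minimal sketch of the EMA update that produces the teacher from the student (alpha is the EMA decay; the names are placeholders, not the original implementation):

import torch

@torch.no_grad()
def ema_update(teacher, student, alpha=0.99):
    # Teacher weights are an exponential moving average of the student weights
    for t_param, s_param in zip(teacher.parameters(), student.parameters()):
        t_param.mul_(alpha).add_(s_param, alpha=1 - alpha)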
PseudoSeg is a work from Google published at ICLR 2021. The input image X is augmented twice in different ways: a "weak" augmentation (random crop/resize/flip) and a "strong" augmentation (color jittering). Both augmented images are fed into the same network f(θ), producing two different outputs. Because training on the weakly augmented view is more stable, its prediction is used as the target for the strongly augmented view.
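A rough FixMatch-style sketch of this weak-to-strong supervision (not PseudoSeg's exact pipeline, which further refines the pseudo label; here the strong view is simply the weak view plus color jittering so the two predictions stay spatially aligned, and color_jitter is a placeholder):

import torch
import torch.nn.functional as F

def weak_to_strong_loss(model, weak_view, color_jitter):
    # Pseudo label from the more stable, weakly augmented view (no gradient)
    with torch.no_grad():
        pseudo = model(weak_view).argmax(dim=1)
    # The strongly augmented view is trained to match that pseudo label
    strong_view = color_jitter(weak_view)
    return F.cross_entropy(model(strong_view), pseudo)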
CPC (cross probability consistency) is a simplified version of the ECCV 2020 work "Guided Collaborative Training for Pixel-wise Semi-Supervised Learning". The same image is fed into two different networks, and the two networks' outputs are constrained to be similar.
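For contrast with CPS, a sketch of this probability-consistency constraint (here an MSE between the two softmax outputs; the exact consistency term varies between works):

import torch.nn.functional as F

def cpc_loss(logits1, logits2):
    # Consistency on soft probabilities, unlike CPS which uses hard pseudo labels
    p1 = F.softmax(logits1, dim=1)
    p2 = F.softmax(logits2, dim=1)
    return F.mse_loss(p1, p2)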
To briefly summarize the above: the paper finds that self-training is remarkably strong when the amount of labeled data is not that small. Combining the two paradigms, the authors propose CPS: cross pseudo supervision.
As shown in Figure 1(a), CPS is simple by design. During training, two networks f(θ1) and f(θ2) are used. For the same input image X, they produce two different outputs P1 and P2, and applying argmax to each yields the corresponding one-hot labels Y1 and Y2. Similar to the practice in self-training, these two pseudo labels are used as supervision signals: Y2 supervises P1 and Y1 supervises P2, with a cross entropy loss as the constraint.
The two networks share the same architecture but have different initializations: they are randomly initialized twice with kaiming_normal from the PyTorch framework, without any particular constraint on the initialization distribution.
At test time, only one of the two networks is used for inference, so there is no additional test/deployment overhead.
A simple representation:

    X → f(θ_1) → P_1 → Y_1
    X → f(θ_2) → P_2 → Y_2

where P_1, P_2 denote the prediction results and Y_1, Y_2 the pseudo labels obtained by argmax.
The training objective consists of two losses: a supervised loss L_s and a cross pseudo supervision loss L_cps.
The supervised loss uses cross entropy: both branches are supervised with the ground-truth labels of the labeled images.
The cross pseudo supervision loss is bidirectional: one direction uses Y_1 to supervise P_2, and the other uses Y_2 to supervise P_1. In the implementation it is computed on both labeled and unlabeled images, as written out below.
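Written out in the notation above (a sketch of the objective's form, omitting the paper's exact per-pixel normalization; λ is the trade-off weight, cps_weight in the code below):

    L = L_s + λ · L_cps
    L_s   = CE(P_1, Y*) + CE(P_2, Y*)    (Y* is the ground truth; labeled images only)
    L_cps = CE(P_1, Y_2) + CE(P_2, Y_1)  (labeled and unlabeled images)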
The paper also applies the CutMix image augmentation method for data augmentation.
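As a rough illustration, CutMix pastes a rectangular region of one image into another and mixes the (pseudo) labels in the same way; this generic sketch is not the repository's implementation:

import torch

def cutmix(img_a, img_b, label_a, label_b, cut_ratio=0.5):
    # img_*: C x H x W tensors; label_*: H x W tensors (ground-truth or pseudo labels)
    _, H, W = img_a.shape
    ch, cw = int(H * cut_ratio), int(W * cut_ratio)
    y = torch.randint(0, H - ch + 1, (1,)).item()
    x = torch.randint(0, W - cw + 1, (1,)).item()
    mixed_img, mixed_lbl = img_a.clone(), label_a.clone()
    # Paste the box from image b (and its label) into image a
    mixed_img[:, y:y + ch, x:x + cw] = img_b[:, y:y + ch, x:x + cw]
    mixed_lbl[y:y + ch, x:x + cw] = label_b[y:y + ch, x:x + cw]
    return mixed_img, mixed_lbl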
The method achieves SOTA on both PASCAL VOC and Cityscapes under several different amounts of labeled data, and it also performs well with different backbones.
Using the same architecture but different initializations, the method improves over the baseline in every setting.
Ablations indicate that the proposed loss combination compares favorably with other loss combinations.
Compared with several other semi-supervised methods, the proposed approach performs better.
Combining the method with self-training benefits both.
The method is also compared against self-training: it outperforms self-training because it encourages the model to learn a more compact feature embedding.
Qualitative segmentation examples.
# https://github.com/charlesCXK/TorchSemiSeg/tree/main/exp.city/city8.res50v3%2B.CPS
import torch.nn as nn

# Two networks with the same DeepLabv3+ architecture but different initial weights.
# SingleNetwork is the DeepLabv3+ model defined in the repository.
class Network(nn.Module):
    def __init__(self, num_classes, criterion, norm_layer, pretrained_model=None):
        super(Network, self).__init__()
        self.branch1 = SingleNetwork(num_classes, criterion, norm_layer, pretrained_model)
        self.branch2 = SingleNetwork(num_classes, criterion, norm_layer, pretrained_model)

    def forward(self, data, step=1):
        # Inference: only branch1 is used, so there is no extra deployment cost
        if not self.training:
            pred1 = self.branch1(data)
            return pred1
        # Training: 'step' selects which branch runs on this forward pass
        if step == 1:
            return self.branch1(data)
        elif step == 2:
            return self.branch2(data)
# train.py
parser = argparse.ArgumentParser()
os.environ['MASTER_PORT'] = '169711'

with Engine(custom_parser=parser) as engine:
    args = parser.parse_args()
    cudnn.benchmark = True
    seed = config.seed
    if engine.distributed:
        seed = engine.local_rank
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    # data loader + unsupervised data loader
    train_loader, train_sampler = get_train_loader(engine, CityScape,
                                                   train_source=config.train_source,
                                                   unsupervised=False)
    unsupervised_train_loader, unsupervised_train_sampler = get_train_loader(engine, CityScape,
                                                   train_source=config.unsup_source,
                                                   unsupervised=True)

    if engine.distributed and (engine.local_rank == 0):
        tb_dir = config.tb_dir + '/{}'.format(time.strftime("%b%d_%d-%H-%M", time.localtime()))
        generate_tb_dir = config.tb_dir + '/tb'
        logger = SummaryWriter(log_dir=tb_dir)
        engine.link_tb(tb_dir, generate_tb_dir)

    # config network and criterion
    pixel_num = 50000 * config.batch_size // engine.world_size
    criterion = ProbOhemCrossEntropy2d(ignore_label=255, thresh=0.7,
                                       min_kept=pixel_num, use_weight=False)
    criterion_cps = nn.CrossEntropyLoss(reduction='mean', ignore_index=255)

    if engine.distributed:
        BatchNorm2d = SyncBatchNorm

    # define and init the model
    model = Network(config.num_classes, criterion=criterion,
                    pretrained_model=config.pretrained_model,
                    norm_layer=BatchNorm2d)
    init_weight(model.branch1.business_layer, nn.init.kaiming_normal_,
                BatchNorm2d, config.bn_eps, config.bn_momentum,
                mode='fan_in', nonlinearity='relu')
    init_weight(model.branch2.business_layer, nn.init.kaiming_normal_,
                BatchNorm2d, config.bn_eps, config.bn_momentum,
                mode='fan_in', nonlinearity='relu')

    # define the learning rate
    base_lr = config.lr
    if engine.distributed:
        base_lr = config.lr

    params_list_l = []
    params_list_l = group_weight(params_list_l, model.branch1.backbone,
                                 BatchNorm2d, base_lr)
    for module in model.branch1.business_layer:
        params_list_l = group_weight(params_list_l, module, BatchNorm2d,
                                     base_lr)
    optimizer_l = torch.optim.SGD(params_list_l,
                                  lr=base_lr,
                                  momentum=config.momentum,
                                  weight_decay=config.weight_decay)

    params_list_r = []
    params_list_r = group_weight(params_list_r, model.branch2.backbone,
                                 BatchNorm2d, base_lr)
    for module in model.branch2.business_layer:
        params_list_r = group_weight(params_list_r, module, BatchNorm2d,
                                     base_lr)  # head lr * 10
    optimizer_r = torch.optim.SGD(params_list_r,
                                  lr=base_lr,
                                  momentum=config.momentum,
                                  weight_decay=config.weight_decay)

    # config lr policy
    total_iteration = config.nepochs * config.niters_per_epoch
    lr_policy = WarmUpPolyLR(base_lr, config.lr_power, total_iteration,
                             config.niters_per_epoch * config.warm_up_epoch)

    if engine.distributed:
        print('distributed !!')
        if torch.cuda.is_available():
            model.cuda()
            model = DistributedDataParallel(model)
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = DataParallelModel(model, device_ids=engine.devices)
        model.to(device)

    engine.register_state(dataloader=train_loader, model=model,
                          optimizer_l=optimizer_l, optimizer_r=optimizer_r)
    if engine.continue_state_object:
        engine.restore_checkpoint()  # it will change the state dict of optimizer also

    model.train()
    print('begin train')
    for epoch in range(engine.state.epoch, config.nepochs):
        if engine.distributed:
            train_sampler.set_epoch(epoch)
        bar_format = '{desc}[{elapsed}<{remaining},{rate_fmt}]'
        if is_debug:
            pbar = tqdm(range(10), file=sys.stdout, bar_format=bar_format)
        else:
            pbar = tqdm(range(config.niters_per_epoch), file=sys.stdout, bar_format=bar_format)

        dataloader = iter(train_loader)
        unsupervised_dataloader = iter(unsupervised_train_loader)

        sum_loss_sup = 0
        sum_loss_sup_r = 0
        sum_cps_loss = 0

        ''' supervised part '''
        for idx in pbar:
            start_time = time.time()
            optimizer_l.zero_grad()
            optimizer_r.zero_grad()
            engine.update_iteration(epoch, idx)

            # one labeled batch and one unlabeled batch per iteration
            minibatch = next(dataloader)
            unsup_minibatch = next(unsupervised_dataloader)
            imgs = minibatch['data']
            gts = minibatch['label']
            unsup_imgs = unsup_minibatch['data']
            imgs = imgs.cuda(non_blocking=True)
            unsup_imgs = unsup_imgs.cuda(non_blocking=True)
            gts = gts.cuda(non_blocking=True)

            # forward both batches through both branches
            _, pred_sup_l = model(imgs, step=1)
            _, pred_sup_r = model(imgs, step=2)
            _, pred_unsup_l = model(unsup_imgs, step=1)
            _, pred_unsup_r = model(unsup_imgs, step=2)

            # the CPS loss is computed on labeled and unlabeled images together
            pred_l = torch.cat([pred_sup_l, pred_unsup_l], dim=0)
            pred_r = torch.cat([pred_sup_r, pred_unsup_r], dim=0)
            # hard pseudo labels via per-pixel argmax (torch.max returns values and indices)
            _, max_l = torch.max(pred_l, dim=1)
            _, max_r = torch.max(pred_r, dim=1)
            max_l = max_l.long()
            max_r = max_r.long()
            # cross pseudo supervision: each branch fits the other branch's pseudo labels
            cps_loss = criterion_cps(pred_l, max_r) + criterion_cps(pred_r, max_l)
            dist.all_reduce(cps_loss, dist.ReduceOp.SUM)
            cps_loss = cps_loss / engine.world_size
            cps_loss = cps_loss * config.cps_weight

            # supervised (OHEM cross entropy) loss for each branch on the labeled batch
            loss_sup = criterion(pred_sup_l, gts)
            dist.all_reduce(loss_sup, dist.ReduceOp.SUM)
            loss_sup = loss_sup / engine.world_size
            loss_sup_r = criterion(pred_sup_r, gts)
            dist.all_reduce(loss_sup_r, dist.ReduceOp.SUM)
            loss_sup_r = loss_sup_r / engine.world_size

            # poly learning rate schedule, applied to every param group of both optimizers
            current_idx = epoch * config.niters_per_epoch + idx
            lr = lr_policy.get_lr(current_idx)
            optimizer_l.param_groups[0]['lr'] = lr
            optimizer_l.param_groups[1]['lr'] = lr
            for i in range(2, len(optimizer_l.param_groups)):
                optimizer_l.param_groups[i]['lr'] = lr
            optimizer_r.param_groups[0]['lr'] = lr
            optimizer_r.param_groups[1]['lr'] = lr
            for i in range(2, len(optimizer_r.param_groups)):
                optimizer_r.param_groups[i]['lr'] = lr

            # unsup_weight = config.unsup_weight
            loss = loss_sup + loss_sup_r + cps_loss
            loss.backward()
            optimizer_l.step()
            optimizer_r.step()

            print_str = 'Epoch{}/{}'.format(epoch, config.nepochs) \
                        + ' Iter{}/{}:'.format(idx + 1, config.niters_per_epoch) \
                        + ' lr=%.2e' % lr \
                        + ' loss_sup=%.2f' % loss_sup.item() \
                        + ' loss_sup_r=%.2f' % loss_sup_r.item() \
                        + ' loss_cps=%.4f' % cps_loss.item()

            sum_loss_sup += loss_sup.item()
            sum_loss_sup_r += loss_sup_r.item()
            sum_cps_loss += cps_loss.item()
            pbar.set_description(print_str, refresh=False)

            end_time = time.time()

        if engine.distributed and (engine.local_rank == 0):
            logger.add_scalar('train_loss_sup', sum_loss_sup / len(pbar), epoch)
            logger.add_scalar('train_loss_sup_r', sum_loss_sup_r / len(pbar), epoch)
            logger.add_scalar('train_loss_cps', sum_cps_loss / len(pbar), epoch)

        if azure and engine.local_rank == 0:
            run.log(name='Supervised Training Loss', value=sum_loss_sup / len(pbar))
            run.log(name='Supervised Training Loss right', value=sum_loss_sup_r / len(pbar))
            run.log(name='Supervised Training Loss CPS', value=sum_cps_loss / len(pbar))

        if (epoch > config.nepochs // 3) and (epoch % config.snapshot_iter == 0) or (epoch == config.nepochs - 1):
            if engine.distributed and (engine.local_rank == 0):
                engine.save_and_link_checkpoint(config.snapshot_dir,
                                                config.log_dir,
                                                config.log_dir_link)
            elif not engine.distributed:
                engine.save_and_link_checkpoint(config.snapshot_dir,
                                                config.log_dir,
                                                config.log_dir_link)