MODNet matting

一、要解决的问题

  • 无绿幕人像抠图
  • I = α ∗ F + ( 1 − α ) ∗ B I=\alpha *F+(1-\alpha)*B I=αF+(1α)B

二、创新点

  • 无绿幕、无trimap人像端到端抠图
  • SOC模型泛化迁移,OFD视频抠图增强
  • Validation Benchmark

三、具体细节

MODNet matting_第1张图片
MODNet网络结构如上图所示。主要包括三个自网络:Semantic Branch;Detail Branch;Fusion Branch。

Sematic Branch

Encoder-Decoder结构,采用Mobilenet-v2作为Encoder,并使用channel-wise Attention给Hidden Features添加权重,re-weight后的特征通过上采样-卷积-BN-ReLU套装恢复分辨率到原分辨率的1/8。Sigmoid激活后输出,作为Semantics S p S_p Sp

Detail Branch

Encoder-Decoder结构,输入包括Image以及Semantic Branch不同层的hidden features,通过上采样-卷积-BN-ReLU套装输出原分辨率的detail_alpha图

Fusion Branch

Encoder-Decoder结构,输入包括Semantic Branch以及Detail Branch的hidden features。通过上采样-卷积-BN-ReLU套装恢复到原分辨率,Sigmoid激活后输出,作为最终的 α \alpha α

四、代码分析

网络结构较为简单,不分析此部分代码。

看一下各部分的损失函数。

Semantic Branch的损失函数, G ( α g ) G(\alpha_g) G(αg)表示对gound truth alpha下采样。使用L2 Loss。
在这里插入图片描述
Detail Branch的损失函数,使用L1 Loss。 m d m_d md表示边缘区域。
在这里插入图片描述
Fusion Branch的损失函数,除了L1损失,还引入合成损失。
在这里插入图片描述
整个网络的损失函数:
在这里插入图片描述

# forward the model
pred_semantic, pred_detail, pred_matte = modnet(image, False)

# calculate the boundary mask from the trimap
boundaries = (trimap < 0.5) + (trimap > 0.5)

# calculate the semantic loss
gt_semantic = F.interpolate(gt_matte, scale_factor=1/16, mode='bilinear')
gt_semantic = blurer(gt_semantic)
semantic_loss = torch.mean(F.mse_loss(pred_semantic, gt_semantic))
semantic_loss = semantic_scale * semantic_loss

# calculate the detail loss
pred_boundary_detail = torch.where(boundaries, trimap, pred_detail)
gt_detail = torch.where(boundaries, trimap, gt_matte)
detail_loss = torch.mean(F.l1_loss(pred_boundary_detail, gt_detail))
detail_loss = detail_scale * detail_loss

# calculate the matte loss
pred_boundary_matte = torch.where(boundaries, trimap, pred_matte)
matte_l1_loss = F.l1_loss(pred_matte, gt_matte) + 4.0 * F.l1_loss(pred_boundary_matte, gt_matte)
matte_compositional_loss = F.l1_loss(image * pred_matte, image * gt_matte) \
   + 4.0 * F.l1_loss(image * pred_boundary_matte, image * gt_matte)
matte_loss = torch.mean(matte_l1_loss + matte_compositional_loss)
matte_loss = matte_scale * matte_loss

# calculate the final loss, backward the loss, and update the model 
loss = semantic_loss + detail_loss + matte_loss
loss.backward()
optimizer.step()

五、总结

MODnet结构清晰,优秀的训练数据是关键,可惜不开源。

你可能感兴趣的:(抠图Matting,深度学习,神经网络,计算机视觉,机器学习,人工智能)